diff --git a/sms_wsj/database/write_files.py b/sms_wsj/database/write_files.py index b3e1f4f..2afb4d1 100644 --- a/sms_wsj/database/write_files.py +++ b/sms_wsj/database/write_files.py @@ -63,7 +63,11 @@ def audio_read(example): with soundfile.SoundFile(wav_file, mode='r') as f: audio_data.append(f.read().T) - example['audio_data'][audio_key] = np.array(audio_data) + try: + example['audio_data'][audio_key] = np.array(audio_data) + except ValueError: + example['audio_data'][audio_key] = np.array(audio_data, dtype="object") + return example