diff --git a/.gitignore b/.gitignore index 108609f..87b1d11 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,4 @@ fadtk/ scoreq/ fairseq/ UTMOSv2/ +/data diff --git a/run.sh b/run.sh index 29d1632..35bfdcb 100644 --- a/run.sh +++ b/run.sh @@ -1,18 +1,35 @@ -stage=2 +stage=0 # download data if [ $stage -eq 0 ]; then echo stage $stage: Prepare data - if [ ! -d data/LibriSpeech/test-clean ]; then - mkdir -p data - wget http://www.openslr.org/resources/12/test-clean.tar.gz -P ./data - (cd ./data && tar -xvzf test-clean.tar.gz) + # librispeech + if [ ! -d data//LibriSpeech/test-clean ]; then + mkdir -p data/ + wget http://www.openslr.org/resources/12/test-clean.tar.gz -P data/ + (cd data/ && tar -xvzf test-clean.tar.gz) rm data/test-clean.tar.gz fi - if [ ! -d data/LibriSpeech/test-clean/prepared ]; then - python scripts/prepare_librispeech-test-clean.py --root_dir data/LibriSpeech/test-clean + if [ ! -d data//LibriSpeech/test-clean/prepared ]; then + python scripts/prepare_librispeech-test-clean.py --root_dir data//LibriSpeech/test-clean + fi + + # musdb + if [ ! -d data/musdb/test ]; then + wget https://zenodo.org/records/3338373/files/musdb18hq.zip -P data/ + (cd data/ && unzip musdb18hq.zip -d ./musdb) + rm data/musdb18hq.zip + fi + + if [ ! -d data/musdb/prepared ]; then + python scripts/prepare_musdb.py --main_directory data/musdb/ --output_dir data/musdb/prepared --chunk_length 5.0 + fi + + # audioset + if [ ! -d data/audioset ]; then + python scripts/prepare_audioset-test.py --output_dir data/audioset fi fi @@ -20,19 +37,22 @@ fi # Evaluation pred_path=data/LibriSpeech/test-clean/prepared/ori.scp gt_path=data/LibriSpeech/test-clean/prepared/ori.scp +tag=musdb_encodec_24k_12bps +eval_sr=24000 if [ $stage -eq 1 ]; then - result_path="test_result" + result_path="test_result_${tag}" echo stage $stage: Evaluation if test -f ${result_path}; then echo ${result_path} exists else python versa/bin/scorer.py \ - --score_config egs/speech.yaml \ + --score_config egs/general.yaml \ --use_gpu True \ --gt ${gt_path} \ --pred ${pred_path} \ - --output_file ${result_path} + --output_file ${result_path} \ + --eval_sr ${eval_sr} \ # change in versa necessary! fi python scripts/average_result.py --file_path ${result_path} >> ${result_path} diff --git a/scripts/prepare_audioset-test.py b/scripts/prepare_audioset-test.py new file mode 100644 index 0000000..548c326 --- /dev/null +++ b/scripts/prepare_audioset-test.py @@ -0,0 +1,83 @@ +from huggingface_hub import HfApi +import os, argparse +import tarfile + +# Initialize the Hugging Face API +api = HfApi() + +def main(): + # Set up argument parsing + parser = argparse.ArgumentParser(description="Split 'mix.wav' files into chunks.") + parser.add_argument('--output_dir', type=str, help="Where AudioSet files will be saved.") + + # Parse the arguments + args = parser.parse_args() + + # Define the repository details + repo_id = "agkphysics/AudioSet" # Dataset repository + repo_path = "data" + local_save_dir = args.output_dir + audio_dump = os.path.join(local_save_dir, "audio_files") + + # Create the local directory if it doesn't exist + os.makedirs(local_save_dir, exist_ok=True) + + # List files in the dataset repository (specify repo_type="dataset") + repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset") + + # Filter files matching the desired pattern and path + files_to_download = [ + file for file in repo_files + if file.startswith(repo_path) and file.endswith(".tar") and "eval" in file + ] + + print(f"Files to download: {files_to_download}") + + # Base URL for the dataset files + base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + + # Download each file + for file_path in files_to_download: + file_url = base_url + file_path + local_file_path = os.path.join(local_save_dir, os.path.basename(file_path)) + if os.path.exists(local_file_path): + print(f"File {local_file_path} already exists, skipping download.") + else: + print(f"Downloading {file_url} to {local_file_path}...") + + # Download the file manually using requests + import requests + response = requests.get(file_url, stream=True) + if response.status_code == 200: + with open(local_file_path, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + else: + print(f"Failed to download {file_url}, status code: {response.status_code}") + + # Extract the .tar file + print(f"Extracting {local_file_path} to {audio_dump}...") + try: + with tarfile.open(local_file_path, "r") as tar: + tar.extractall( + path=audio_dump, + members=[ + member for member in tar.getmembers() + if member.isfile() # Only extract files, skip directories + ] + ) + except Exception as e: + print(f"Error extracting {local_file_path}: {e}") + + # Delete the .tar file after successful extraction + print(f"Deleting {local_file_path}...") + try: + os.remove(local_file_path) + except Exception as e: + print(f"Error deleting {local_file_path}: {e}") + + print("All files downloaded and extracted.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/prepare_musdb.py b/scripts/prepare_musdb.py new file mode 100644 index 0000000..2bd5d43 --- /dev/null +++ b/scripts/prepare_musdb.py @@ -0,0 +1,65 @@ +import os +import argparse +import soundfile as sf +import numpy as np +from glob import glob +from tqdm import tqdm +import librosa + +def split_audio(file_path, chunk_length, output_dir, parent_folder): + """Split the audio file into non-overlapping chunks and save them.""" + # Load the audio file using soundfile + print(f"Processing: {file_path}") + waveform, sample_rate = librosa.load(file_path, mono=True) + total_length = len(waveform) / sample_rate # Total length in seconds + + # Calculate the number of chunks + num_chunks = int(np.ceil(total_length / chunk_length)) + + # Split and save each chunk + for i in range(num_chunks): + start_time = i * chunk_length + end_time = min((i + 1) * chunk_length, total_length) + + # Find the sample indices for the chunk + start_sample = int(start_time * sample_rate) + end_sample = int(end_time * sample_rate) + + # Extract the chunk waveform + chunk_waveform = waveform[start_sample:end_sample] + + # Build the output file path + output_filename = f"{parent_folder}_{start_time}_{end_time}.wav" + output_file_path = os.path.join(output_dir, output_filename) + + # Save the chunk to the output directory + sf.write(output_file_path, chunk_waveform, sample_rate) + print(f"Saved chunks for : {parent_folder}") + +def process_directory(main_directory, chunk_length, output_dir): + """Process all subdirectories in the main directory and split 'mix.wav' files.""" + for mix_wav_path in tqdm(glob(os.path.join(main_directory, '**/mixture.wav'), recursive=True)): + # Get the parent folder name + parent_folder = mix_wav_path.split('/')[-2].replace(' ','_') + # Split the 'mix.wav' into chunks + split_audio(mix_wav_path, chunk_length, output_dir, parent_folder) + +def main(): + # Set up argument parsing + parser = argparse.ArgumentParser(description="Split 'mix.wav' files into chunks.") + parser.add_argument('--main_directory', type=str, help="Path to the main directory containing subfolders.") + parser.add_argument('--output_dir', type=str, help="Directory where the chunks will be saved.") + parser.add_argument('--chunk_length', type=float, help="Length of each chunk in seconds.") + + # Parse the arguments + args = parser.parse_args() + + # Ensure output directory exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # Process the directory + process_directory(args.main_directory, args.chunk_length, args.output_dir) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/versa/bin/scorer.py b/versa/bin/scorer.py index d217ea7..cb00baa 100644 --- a/versa/bin/scorer.py +++ b/versa/bin/scorer.py @@ -17,6 +17,11 @@ def get_parser() -> argparse.Namespace: """Get argument parser.""" parser = argparse.ArgumentParser(description="Speech Evaluation Interface") + parser.add_argument( + "--eval_sr", + type=int, + help="All wfs wil lbe resampeld to eval_sr prior to eval.", + ) parser.add_argument( "--pred", type=str, @@ -140,7 +145,6 @@ def main(): ) assert len(score_config) > 0, "no scoring function is provided" - score_info = list_scoring( gen_files, score_modules, @@ -148,6 +152,7 @@ def main(): text_info, output_file=args.output_file, io=args.io, + eval_sr=args.eval_sr ) logging.info("Summary: {}".format(load_summary(score_info))) diff --git a/versa/scorer_shared.py b/versa/scorer_shared.py index a922f08..b27cedd 100644 --- a/versa/scorer_shared.py +++ b/versa/scorer_shared.py @@ -47,7 +47,7 @@ def audio_loader_setup(audio, io): audio_files = kaldiio.load_scp(audio) elif io == "dir": audio_files = find_files(audio) - elif io == "soundfile": + else: audio_files = {} with open(audio) as f: for line in f.readlines(): @@ -362,40 +362,6 @@ def load_score_modules(score_config, use_gt=True, use_gt_text=False, use_gpu=Fal ), } logging.info("Initiate Whisper WER calculation successfully") - - elif config["name"] == "scoreq_ref": - if not use_gt: - logging.warning("Cannot use scoreq_ref because no gt audio is provided") - continue - - logging.info("Loadding scoreq metrics with reference") - from versa import scoreq_ref_setup, scoreq_ref - model = scoreq_ref_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "./scoreq_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["scoreq_ref"] = { - "module": scoreq_ref, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") - - elif config["name"] == "scoreq_nr": - logging.info("Loadding scoreq metrics without reference") - from versa import scoreq_nr_setup, scoreq_nr - model = scoreq_nr_setup( - data_domain=config.get("data_domain", "synthetic"), - cache_dir=config.get("model_cache", "./scoreq_pt-models"), - use_gpu=use_gpu, - ) - - score_modules["scoreq_nr"] = { - "module": scoreq_nr, - "model": model, - } - logging.info("Initiate scoreq (with reference) successfully") return score_modules @@ -451,14 +417,6 @@ def use_score_modules(score_modules, gen_wav, gt_wav, gen_sr, text=None): text, gen_sr, ) - elif key == "scoreq_ref": - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, gt_wav, gen_sr) - elif key == "scoreq_nr": - score = score_modules[key]["module"]( - score_modules[key]["model"], - gen_wav, gen_sr) else: raise NotImplementedError(f"Not supported {key}") @@ -474,6 +432,7 @@ def list_scoring( text_info=None, output_file=None, io="kaldi", + eval_sr=16_000 ): if output_file is not None: f = open(output_file, "w", encoding="utf-8") @@ -531,16 +490,16 @@ def list_scoring( else: text = None - if gt_sr is not None and gen_sr > gt_sr: + if gt_sr != eval_sr: logging.warning( - "Resampling the generated audio to match the ground truth audio" + "Resampling the ground truth audio to match the eval sr" ) - gen_wav = librosa.resample(gen_wav, orig_sr=gen_sr, target_sr=gt_sr) - elif gt_sr is not None and gen_sr < gt_sr: + gt_wav = librosa.resample(gt_wav, orig_sr=gt_sr, target_sr=eval_sr) + if gen_sr != eval_sr: logging.warning( - "Resampling the ground truth audio to match the generated audio" + "Resampling the generated audio to match the eval sr" ) - gt_wav = librosa.resample(gt_wav, orig_sr=gt_sr, target_sr=gen_sr) + gen_wav = librosa.resample(gen_wav, orig_sr=gen_sr, target_sr=eval_sr) utt_score = {"key": key}