NOTE(review): this patch was recovered from a copy whose newlines had been
collapsed. Indentation, inline-comment spacing and blank context lines are
reconstructed by convention; if a hunk fails to match, retry with
`git apply --ignore-space-change`. Hunk line counts were recomputed for the
expanded _load_audio docstring; the `index` blob ids are the original ones
and are informational only.

diff --git a/experiments/vocoder_finetune.json b/experiments/vocoder_finetune.json
new file mode 100644
index 0000000..c85d5ba
--- /dev/null
+++ b/experiments/vocoder_finetune.json
@@ -0,0 +1,30 @@
+{
+    "name": "vocoder_finetune",
+    "description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
+    "data_dir": "/media/unraid/davinci/Selva/BJ/features",
+    "output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
+    "base": {
+        "steps": 10000,
+        "rank": 128,
+        "alpha": 0.0,
+        "lr": 3e-4,
+        "batch_size": 16,
+        "warmup_steps": 200,
+        "grad_accum": 1,
+        "save_every": 2000,
+        "seed": 42,
+        "target": "attn.qkv",
+        "timestep_mode": "uniform",
+        "logit_normal_sigma": 1.0,
+        "curriculum_switch": 0.6,
+        "lora_dropout": 0.0,
+        "lora_plus_ratio": 1.0,
+        "lr_schedule": "constant"
+    },
+    "experiments": [
+        {
+            "id": "r128_lr_3e4_bj_vocoder",
+            "description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
+        }
+    ]
+}
diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py
index 86772f2..f71434d 100644
--- a/nodes/selva_dataset_pipeline.py
+++ b/nodes/selva_dataset_pipeline.py
@@ -24,6 +24,27 @@ from .utils import SELVA_CATEGORY
 AUDIO_DATASET = "AUDIO_DATASET"
 
 _AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
+_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"}  # handled natively without FFmpeg
+
+
+def _load_audio(path: Path):
+    """Load an audio file as a float32 tensor shaped [1, C, L].
+
+    WAV/FLAC/OGG are decoded with soundfile to avoid torchcodec/FFmpeg
+    issues; every other extension falls back to torchaudio.
+
+    Returns:
+        (waveform, sample_rate) where waveform is [1, C, L].
+    """
+    if path.suffix.lower() in _SOUNDFILE_EXTS:
+        import soundfile as sf
+        wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True)  # [L, C]
+        # .T is only a strided view; make it contiguous like the torchaudio branch.
+        wav = torch.from_numpy(wav_np).T.contiguous().unsqueeze(0)  # [1, C, L]
+    else:
+        wav, sr = torchaudio.load(str(path))  # [C, L]
+        wav = wav.unsqueeze(0).float()  # [1, C, L]
+    return wav, sr
 
 
 class SelvaDatasetLoader:
@@ -58,8 +79,7 @@ class SelvaDatasetLoader:
         dataset = []
         for f in sorted(files):
             try:
-                wav, sr = torchaudio.load(str(f))  # [C, L]
-                wav = wav.unsqueeze(0).float()     # [1, C, L]
+                wav, sr = _load_audio(f)
                 dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
             except Exception as e:
                 print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)