diff --git a/nodes/selva_lora_evaluator.py b/nodes/selva_lora_evaluator.py index ccd0f22..42694e7 100644 --- a/nodes/selva_lora_evaluator.py +++ b/nodes/selva_lora_evaluator.py @@ -78,8 +78,7 @@ def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: "#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6", ] - n = len(adapter_ids) - fig = Figure(figsize=(12, max(4, n * 0.6 + 2)), dpi=110, tight_layout=True) + fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True) axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)] for ax, (key, title) in zip(axes, METRICS): @@ -212,8 +211,8 @@ class SelvaLoraEvaluator: try: ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) ref_wav = ref_wav.unsqueeze(0) # [1, L] - ref_out = output_dir / "reference.wav" import shutil + ref_out = output_dir / f"reference{audio_path.suffix}" shutil.copy2(str(audio_path), str(ref_out)) ref_record["wav_path"] = str(ref_out) ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)