feat: add BigVGAN vocoder sweep scheduler node

Runs a series of BigVGAN fine-tuning experiments from a JSON sweep file.
Audio clips loaded once, vocoder deep-copied per experiment, results
collected in experiment_summary.json with comparison loss curves.
Resume-aware — skips completed experiments on re-run.

Includes overnight sweep config (8 experiments): snake alpha steps,
GAFilter ablation, phase loss weight, discriminator FM, all_params.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 02:39:56 +02:00
parent 4226297735
commit 30127c13ca
3 changed files with 691 additions and 0 deletions
+65
View File
@@ -0,0 +1,65 @@
{
"name": "bigvgan_overnight",
"description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
"base": {
"train_mode": "snake_alpha_only",
"steps": 3000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_3k_baseline",
"description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
},
{
"id": "snake_5k",
"description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
"steps": 5000
},
{
"id": "snake_no_gafilter",
"description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
"use_gafilter": false
},
{
"id": "snake_no_phase",
"description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
"lambda_phase": 0.0
},
{
"id": "snake_phase_2",
"description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
"lambda_phase": 2.0
},
{
"id": "snake_lr5e-5",
"description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
"lr": 5e-5,
"steps": 5000
},
{
"id": "snake_disc_fm",
"description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "all_2k_l2sp1e-2",
"description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
"train_mode": "all_params",
"steps": 2000,
"lr": 1e-5,
"lambda_l2sp": 1e-2
}
]
}
+1
View File
@@ -21,6 +21,7 @@ _NODES = {
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
"SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
+625
View File
@@ -0,0 +1,625 @@
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
Each experiment inherits from a shared `base` config and overrides specific keys.
Audio clips are loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image.
JSON format:
{
"name": "bigvgan_sweep",
"description": "optional note",
"data_dir": "/path/to/audio/clips",
"output_root": "/path/to/output",
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
...
]
}
"""
import copy
import csv
import json
import threading
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_bigvgan_trainer import (
_do_train,
_pregenerate_lora_mels,
_load_wav,
)
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
from .selva_lora_scheduler import (
_get_system_info,
_resolve_path,
_draw_comparison_curves,
)
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"train_mode": "snake_alpha_only",
"steps": 2000,
"lr": 1e-4,
"batch_size": 4,
"segment_seconds": 2.0,
"lambda_l2sp": 1e-3,
"use_gafilter": True,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 500,
"seed": 42,
"discriminator_path": "",
"lora_adapter": "",
}
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge param defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _parse_training_log(log_path: Path) -> list:
"""Parse BigVGAN training CSV → list of total_loss values."""
losses = []
if not log_path.exists():
return losses
try:
with open(log_path) as f:
reader = csv.DictReader(f)
for row in reader:
losses.append(float(row["total_loss"]))
except Exception:
pass
return losses
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
"""Build {step: loss} at each save_every boundary.
Uses round-to-nearest to handle log_interval that doesn't divide
save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
"""
result = {}
for target in range(save_every, total_steps + 1, save_every):
# loss_history[i] = loss at step (i+1)*log_interval
idx = round(target / log_interval) - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
class SelvaBigvganScheduler:
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
Audio clips are loaded once and reused across all experiments. Each experiment
deep-copies the vocoder and trains independently. Results are written to
`experiment_summary.json` after every completed run so partial results are
preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
"Audio clips are loaded once and reused across all experiments. "
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "bigvgan_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
mode = model["mode"]
dtype = model["dtype"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
if mode == "16k":
original_vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
original_vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[BigVGAN Scheduler] {description}", flush=True)
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load audio clips once
# ------------------------------------------------------------------
# Find minimum segment length across all experiments so we load enough
min_segment_seconds = float("inf")
for exp in spec["experiments"]:
cfg = _merge_config(base_cfg, exp)
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
min_segment_samples = int(min_segment_seconds * sample_rate)
audio_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
audio_files.extend(data_dir.rglob(ext))
if not audio_files:
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
clips = []
for af in audio_files:
try:
wav, sr = _load_wav(af)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0) # [L]
if wav.shape[0] >= min_segment_samples:
clips.append(wav.cpu())
else:
print(f" [BigVGAN Scheduler] Skip {af.name}: "
f"shorter than {min_segment_seconds}s", flush=True)
except Exception as e:
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
if not clips:
raise RuntimeError(
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
)
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
# ------------------------------------------------------------------
# 4. Offload unused components to free VRAM
# ------------------------------------------------------------------
comfy.model_management.unload_all_models()
feature_utils.to("cpu")
if "generator" in model:
model["generator"].to("cpu")
if "video_enc" in model:
model["video_enc"].to("cpu")
soft_empty_cache()
# ------------------------------------------------------------------
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
# ------------------------------------------------------------------
text_clip_cache = {}
any_lora = any(
_merge_config(base_cfg, exp).get("lora_adapter", "")
for exp in spec["experiments"]
)
if any_lora:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_model.to(device)
try:
for npz_path in npz_files:
data = dict(np.load(str(npz_path), allow_pickle=False))
prompt = prompt_map.get(
npz_path.name, data.get("prompt", default_prompt)
)
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
tc = feature_utils.encode_text_clip([prompt])
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
finally:
if clip_model is not None:
clip_model.to("cpu")
soft_empty_cache()
if device.type == "cuda":
torch.cuda.empty_cache()
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
f"CLIP embeddings", flush=True)
# ------------------------------------------------------------------
# 6. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 100),
"start_step": 0,
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[BigVGAN Scheduler] Resuming — skipping "
f"{len(completed_ids)} completed: "
f"{sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[BigVGAN Scheduler] Could not read existing summary "
f"({e}) — starting fresh", flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": len(clips),
"experiments": [],
}
def _write_summary():
summary_path.write_text(
json.dumps(summary, indent=2), encoding="utf-8"
)
_write_summary()
# ------------------------------------------------------------------
# 7. Compute total steps for progress bar
# ------------------------------------------------------------------
total_steps = 0
for exp in spec["experiments"]:
if exp["id"] not in completed_ids:
cfg = _merge_config(base_cfg, exp)
total_steps += int(cfg.get("steps", 2000))
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
# ------------------------------------------------------------------
# 8. Run experiments in a worker thread
# ------------------------------------------------------------------
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
# inside torch.inference_mode(). inference_mode is thread-local — a
# new thread starts with it OFF, so all tensor operations produce
# normal autograd-compatible tensors.
_exc = [None]
def _worker():
try:
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
f"(already completed)", flush=True)
continue
cfg = _merge_config(base_cfg, exp)
# ── Extract experiment parameters ────────────────────
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
exp_steps = int(cfg.get("steps", 2000))
exp_lr = float(cfg.get("lr", 1e-4))
exp_bs = int(cfg.get("batch_size", 4))
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
exp_gafilter = bool(cfg.get("use_gafilter", True))
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
exp_phase = float(cfg.get("lambda_phase", 1.0))
exp_save = int(cfg.get("save_every", 500))
exp_seed = int(cfg.get("seed", 42))
exp_disc = str(cfg.get("discriminator_path", ""))
exp_lora = str(cfg.get("lora_adapter", ""))
segment_samples = int(exp_seg_s * sample_rate)
# Filter clips long enough for this experiment
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
if not exp_clips:
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
f"no clips >= {exp_seg_s}s", flush=True)
summary["experiments"].append({
"id": exp_id, "description": exp_desc,
"config": dict(cfg),
"results": {
"status": "failed",
"error": f"No clips >= {exp_seg_s}s",
"duration_seconds": 0,
},
"checkpoint_path": None,
"output_dir": str(output_root / exp_id),
})
_write_summary()
continue
# ── Resolve discriminator path ───────────────────────
disc_path = None
if exp_disc:
disc_path = Path(exp_disc.strip())
if not disc_path.is_absolute():
disc_path = (
Path(folder_paths.get_output_directory()) / disc_path
)
if not disc_path.exists():
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"discriminator not found: {disc_path}",
flush=True)
disc_path = None
# ── Pre-generate LoRA mels (disk-cached) ─────────────
lora_mel_pairs = None
if exp_lora:
lora_path = Path(exp_lora.strip())
if not lora_path.is_absolute():
lora_path = Path(folder_paths.base_path) / lora_path
if lora_path.exists():
seq_cfg = model["seq_cfg"]
lora_mel_pairs = _pregenerate_lora_mels(
model, data_dir, str(lora_path),
device, dtype, sample_rate,
seq_cfg.duration, seed=exp_seed,
cache_dir=str(output_root),
text_clip_cache=text_clip_cache,
)
if not lora_mel_pairs:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"no LoRA mel pairs generated",
flush=True)
lora_mel_pairs = None
if device.type == "cuda":
torch.cuda.empty_cache()
else:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"LoRA adapter not found: {lora_path}",
flush=True)
# ── Output dir ───────────────────────────────────────
exp_dir = output_root / exp_id
exp_dir.mkdir(parents=True, exist_ok=True)
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
flush=True)
if exp_desc:
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
print(f"[BigVGAN Scheduler] mode={train_mode} "
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"train_mode": train_mode, "steps": exp_steps,
"lr": exp_lr, "batch_size": exp_bs,
"segment_seconds": exp_seg_s,
"lambda_l2sp": exp_l2sp,
"use_gafilter": exp_gafilter,
"gafilter_kernel_size": exp_gaf_ks,
"lambda_phase": exp_phase,
"save_every": exp_save, "seed": exp_seed,
"discriminator_path": exp_disc,
"lora_adapter": exp_lora,
},
"results": {"status": "running"},
"checkpoint_path": None,
"output_dir": str(exp_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
# Ensure mel_converter is on device for this experiment
mel_converter.to(device)
# Fresh vocoder copy — _do_train modifies it in-place
vocoder_copy = copy.deepcopy(original_vocoder)
checkpoint_path = _do_train(
vocoder_copy, mel_converter, exp_clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, exp_steps, exp_lr, exp_bs,
exp_l2sp, exp_gafilter, exp_gaf_ks,
exp_phase, exp_save, exp_seed,
out_path, disc_path, pbar,
lora_mel_pairs,
)
duration = time.monotonic() - t_start
# Parse training CSV for loss history
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
loss_history = _parse_training_log(log_path)
log_interval = max(1, exp_steps // 20)
smoothed = (
_smooth_losses(loss_history)
if loss_history else []
)
final_loss = (
round(smoothed[-1], 6) if smoothed else None
)
min_loss = (
round(min(smoothed), 6) if smoothed else None
)
min_idx = (
smoothed.index(min(smoothed))
if smoothed else None
)
min_loss_step = (
(min_idx + 1) * log_interval
if min_idx is not None else None
)
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std = round(
float(np.std(loss_history[-quarter:])), 6
)
else:
loss_std = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval,
exp_save, exp_steps,
),
"loss_history": [
round(v, 6) for v in loss_history
],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["checkpoint_path"] = checkpoint_path
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
f"failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
finally:
# Clean up vocoder copy to free VRAM
soft_empty_cache()
_write_summary()
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if _exc[0] is not None:
raise _exc[0]
# ------------------------------------------------------------------
# 9. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[BigVGAN Scheduler] Sweep complete. "
f"Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 10. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)