feat: add dataset pipeline nodes + latent augmentation for LoRA trainer
New dataset pipeline nodes:
- SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution
- SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility
- SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking

Improvements:
- Inspector: silence detection (max_silence_fraction param)
- Saver: origin_name lookup for augmented clips' npz pairing
- LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization
- LoRA trainer: one-time SR mismatch warning in _load_audio

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,9 @@ _NODES = {
|
|||||||
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
|
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
|
||||||
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
|
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
|
||||||
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
|
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
|
||||||
|
"SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
|
||||||
|
"SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
|
||||||
|
"SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ Typical chain:
|
|||||||
↓ AUDIO_DATASET
|
↓ AUDIO_DATASET
|
||||||
SelvaDatasetCompressor (optional)
|
SelvaDatasetCompressor (optional)
|
||||||
↓ AUDIO_DATASET
|
↓ AUDIO_DATASET
|
||||||
|
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
|
||||||
|
↓ AUDIO_DATASET
|
||||||
|
SelvaDatasetHfSmoother (optional — batch HF attenuation)
|
||||||
|
↓ AUDIO_DATASET
|
||||||
|
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
|
||||||
|
↓ AUDIO_DATASET
|
||||||
SelvaDatasetInspector (optional)
|
SelvaDatasetInspector (optional)
|
||||||
↓ AUDIO_DATASET + STRING report
|
↓ AUDIO_DATASET + STRING report
|
||||||
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
|
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
|
||||||
@@ -342,6 +348,11 @@ class SelvaDatasetInspector:
|
|||||||
"default": True,
|
"default": True,
|
||||||
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
|
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
|
||||||
}),
|
}),
|
||||||
|
"max_silence_fraction": ("FLOAT", {
|
||||||
|
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||||
|
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
|
||||||
|
"(< -60 dBFS). Set to 0 to disable silence detection.",
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +366,8 @@ class SelvaDatasetInspector:
|
|||||||
"Connect report to a ShowText node to preview in the UI."
|
"Connect report to a ShowText node to preview in the UI."
|
||||||
)
|
)
|
||||||
|
|
||||||
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
|
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
|
||||||
|
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
|
||||||
clean = []
|
clean = []
|
||||||
flagged = []
|
flagged = []
|
||||||
lines = ["SelVA Dataset Inspector Report", "=" * 40]
|
lines = ["SelVA Dataset Inspector Report", "=" * 40]
|
||||||
@@ -380,6 +392,16 @@ class SelvaDatasetInspector:
|
|||||||
if check_codec_artifacts and _check_hf_shelf(wav, sr):
|
if check_codec_artifacts and _check_hf_shelf(wav, sr):
|
||||||
issues.append("codec artifact (HF shelf > 15 kHz)")
|
issues.append("codec artifact (HF shelf > 15 kHz)")
|
||||||
|
|
||||||
|
# Silence detection
|
||||||
|
if max_silence_fraction > 0:
|
||||||
|
mono = wav[0].mean(0)
|
||||||
|
if mono.shape[0] >= 2048:
|
||||||
|
frames = mono.unfold(0, 2048, 512)
|
||||||
|
rms = frames.pow(2).mean(-1).sqrt()
|
||||||
|
silent_frac = (rms < 1e-3).float().mean().item()
|
||||||
|
if silent_frac > max_silence_fraction:
|
||||||
|
issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)")
|
||||||
|
|
||||||
if issues:
|
if issues:
|
||||||
flagged.append(name)
|
flagged.append(name)
|
||||||
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
|
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
|
||||||
@@ -507,7 +529,9 @@ class SelvaDatasetSaver:
|
|||||||
saved += 1
|
saved += 1
|
||||||
|
|
||||||
if npz_src is not None:
|
if npz_src is not None:
|
||||||
npz_path = npz_src / f"{name}.npz"
|
# Augmented clips store their origin name — use it to find the .npz
|
||||||
|
lookup = item.get("origin_name", name)
|
||||||
|
npz_path = npz_src / f"{lookup}.npz"
|
||||||
if npz_path.exists():
|
if npz_path.exists():
|
||||||
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
|
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
|
||||||
npz_copied += 1
|
npz_copied += 1
|
||||||
@@ -525,3 +549,240 @@ class SelvaDatasetSaver:
|
|||||||
report = "\n".join(lines)
|
report = "\n".join(lines)
|
||||||
print(report, flush=True)
|
print(report, flush=True)
|
||||||
return (report,)
|
return (report,)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
|
||||||
|
|
||||||
|
class SelvaDatasetSpectralMatcher:
    """Batch wrapper that runs SelvaSpectralMatcher over an AUDIO_DATASET.

    Every item of the dataset is handed to the single-clip node as an AUDIO
    dict (``waveform`` + ``sample_rate``). The processed waveform and sample
    rate are written onto a shallow copy of the item, so all other keys
    (e.g. ``name``, ``origin_name``) are carried through untouched.
    """

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "process"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Apply adaptive spectral matching to every clip in a dataset. "
        "Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
        "VAE's expected distribution."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the dataset plus the matcher parameters."""
        mode_spec = (["44k", "16k"], {
            "tooltip": "Must match the SelVA model you are training. "
                       "44k = large model, 16k = small model.",
        })
        strength_spec = ("FLOAT", {
            "default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
            "tooltip": "0 = no correction, 1 = full match to VAE distribution.",
        })
        max_gain_spec = ("FLOAT", {
            "default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
            "tooltip": "Clamps per-band gain to ±dB.",
        })
        # Key order is kept as dataset, mode, strength, max_gain_db.
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "mode": mode_spec,
                "strength": strength_spec,
                "max_gain_db": max_gain_spec,
            }
        }

    def process(self, dataset, mode: str, strength: float, max_gain_db: float):
        """Run the spectral matcher on each clip and return the new dataset.

        Returns a 1-tuple holding a list with one processed item per input
        item, in the same order.
        """
        # Imported lazily, inside the call, exactly like the single-clip node
        # is only needed when this node actually executes.
        from .selva_audio_preprocessors import SelvaSpectralMatcher

        node = SelvaSpectralMatcher()

        def _match(clip):
            # Only waveform and sample_rate are passed to the wrapped node.
            as_audio = {"waveform": clip["waveform"], "sample_rate": clip["sample_rate"]}
            (processed,) = node.process(as_audio, mode, strength, max_gain_db)
            # Copy first so origin_name and any extra keys survive.
            return {
                **clip,
                "waveform": processed["waveform"],
                "sample_rate": processed["sample_rate"],
            }

        out = [_match(clip) for clip in dataset]
        print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
              f"mode={mode} strength={strength}", flush=True)
        return (out,)
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaDatasetHfSmoother:
    """Batch wrapper that runs SelvaHfSmoother over an AUDIO_DATASET.

    Every item of the dataset is handed to the single-clip node as an AUDIO
    dict (``waveform`` + ``sample_rate``). The processed waveform and sample
    rate are written onto a shallow copy of the item, so all other keys
    (e.g. ``name``, ``origin_name``) are carried through untouched.
    """

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "process"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Apply soft HF attenuation to every clip in a dataset. "
        "Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
        "with the original to tame extreme HF content."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the dataset plus the smoother parameters."""
        cutoff_spec = ("FLOAT", {
            "default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
            "tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
        })
        blend_spec = ("FLOAT", {
            "default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
            "tooltip": "0 = original, 1 = fully filtered.",
        })
        # Key order is kept as dataset, cutoff_hz, blend.
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "cutoff_hz": cutoff_spec,
                "blend": blend_spec,
            }
        }

    def process(self, dataset, cutoff_hz: float, blend: float):
        """Run the HF smoother on each clip and return the new dataset.

        Returns a 1-tuple holding a list with one processed item per input
        item, in the same order.
        """
        # Lazy import: the single-clip node is only needed at execution time.
        from .selva_audio_preprocessors import SelvaHfSmoother

        node = SelvaHfSmoother()

        def _smooth(clip):
            # Only waveform and sample_rate are passed to the wrapped node.
            as_audio = {"waveform": clip["waveform"], "sample_rate": clip["sample_rate"]}
            (processed,) = node.process(as_audio, cutoff_hz, blend)
            # Copy first so origin_name and any extra keys survive.
            return {
                **clip,
                "waveform": processed["waveform"],
                "sample_rate": processed["sample_rate"],
            }

        out = [_smooth(clip) for clip in dataset]
        print(f"[DatasetHfSmoother] {len(out)} clips processed "
              f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
        return (out,)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dataset augmenter ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class SelvaDatasetAugmenter:
    """Create augmented variants of each clip to expand a small dataset.

    Supports gain variation (always available) and optionally pitch shift
    and time stretch via audiomentations.  Install audiomentations for the
    full feature set: ``pip install audiomentations``

    Each original clip produces ``variants_per_clip`` augmented copies.
    Originals are kept by default (toggle ``keep_originals``).

    Every variant is a copy of its source item with ``waveform`` and ``name``
    replaced and ``origin_name`` set, so extra keys on the item are preserved
    (matching the other batch dataset nodes).

    NOTE(review): ``seed`` drives the gain draw through a NumPy RandomState.
    The audiomentations transforms presumably draw their own random
    parameters from a different generator, so pitch/stretch may not be
    reproducible from ``seed`` alone — confirm against audiomentations.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs (required gain/seed, optional pitch/stretch)."""
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "variants_per_clip": ("INT", {
                    "default": 2, "min": 1, "max": 20,
                    "tooltip": "Number of augmented copies per original clip.",
                }),
                "gain_range_db": ("FLOAT", {
                    "default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
                    "tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
                }),
                "seed": ("INT", {"default": 42}),
            },
            "optional": {
                "pitch_range_semitones": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
                    "tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
                }),
                "time_stretch_range": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
                    "tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). "
                               "Requires audiomentations. 0 = disabled.",
                }),
                "keep_originals": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Include the original unaugmented clips in the output.",
                }),
            },
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "augment"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Create augmented variants of each clip (gain, pitch, time stretch) "
        "to expand small training datasets. Gain is always available; pitch and "
        "time stretch require audiomentations (pip install audiomentations)."
    )

    @staticmethod
    def _build_audiomentations(pitch_range_semitones: float, time_stretch_range: float):
        """Return an audiomentations Compose for pitch/stretch, or None.

        None is returned when both ranges are 0 or when audiomentations is
        not installed (a hint is printed in the latter case).
        """
        if not (pitch_range_semitones > 0 or time_stretch_range > 0):
            return None
        try:
            import audiomentations as am
        except ImportError:
            print("[DatasetAugmenter] audiomentations not installed — "
                  "pitch_shift and time_stretch disabled. "
                  "Install: pip install audiomentations", flush=True)
            return None

        transforms = []
        if pitch_range_semitones > 0:
            transforms.append(am.PitchShift(
                min_semitones=-pitch_range_semitones,
                max_semitones=pitch_range_semitones,
                p=0.5,
            ))
        if time_stretch_range > 0:
            transforms.append(am.TimeStretch(
                min_rate=1.0 - time_stretch_range,
                max_rate=1.0 + time_stretch_range,
                leave_length_unchanged=True,
                p=0.5,
            ))
        return am.Compose(transforms)

    def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
                seed: int, pitch_range_semitones: float = 0.0,
                time_stretch_range: float = 0.0, keep_originals: bool = True):
        """Build ``variants_per_clip`` augmented copies of every dataset item.

        Args:
            dataset: list of item dicts with ``waveform``, ``sample_rate``
                and ``name`` keys (``origin_name`` is optional).
            variants_per_clip: number of variants generated per item.
            gain_range_db: half-width of the uniform gain draw in dB; 0 keeps
                the gain at unity and consumes no random numbers.
            seed: seed for the gain random stream.
            pitch_range_semitones: ± pitch range; 0 disables pitch shift.
            time_stretch_range: ± stretch fraction; 0 disables time stretch.
            keep_originals: prepend the untouched input items to the output.

        Returns:
            1-tuple holding the output list (originals first when kept,
            followed by the variants grouped by source item).
        """
        # RandomState only accepts seeds in [0, 2**32); wrap anything outside
        # so large or negative UI seeds do not raise. In-range seeds are
        # unchanged by the modulo.
        rng = np.random.RandomState(seed % (2 ** 32))

        am_compose = self._build_audiomentations(pitch_range_semitones, time_stretch_range)

        out = list(dataset) if keep_originals else []

        for item in dataset:
            wav = item["waveform"]   # [1, C, L]
            sr = item["sample_rate"]
            name = item["name"]
            # If the clip is itself an augmented variant, keep pointing at the
            # first-generation source so the Saver's origin_name lookup still
            # finds the original .npz.
            origin = item.get("origin_name", name)

            for v in range(variants_per_clip):
                # Gain variation (always applied)
                gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
                gain_lin = 10.0 ** (gain_db / 20.0)
                wav_aug = wav * gain_lin

                # Pitch/stretch via audiomentations
                if am_compose is not None:
                    # detach + cpu: .numpy() is only valid for CPU tensors
                    # that do not require grad.
                    wav_np = wav_aug[0].detach().cpu().numpy()   # [C, L]
                    if wav_np.shape[0] == 1:
                        wav_np = wav_np[0]                       # [L] mono for audiomentations
                    wav_np = am_compose(samples=wav_np, sample_rate=sr)
                    if wav_np.ndim == 1:
                        wav_np = wav_np[np.newaxis, :]           # back to [1, L]
                    # Restore the source tensor's device and dtype.
                    wav_aug = torch.from_numpy(wav_np).unsqueeze(0).to(
                        device=wav.device, dtype=wav.dtype)      # [1, C, L]

                # Prevent clipping (skip empty clips: max() of an empty
                # tensor raises).
                if wav_aug.numel() > 0:
                    peak = wav_aug.abs().max()
                    if peak > 1.0:
                        wav_aug = wav_aug / peak

                variant = dict(item)  # preserve any extra keys
                variant["waveform"] = wav_aug
                variant["sample_rate"] = sr
                variant["name"] = f"{name}_aug{v:02d}"
                variant["origin_name"] = origin
                out.append(variant)

        print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
              f"gain=±{gain_range_db:.1f}dB"
              + (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
              + (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
              flush=True)
        return (out,)
|
||||||
|
|||||||
@@ -72,6 +72,10 @@ def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
|
|||||||
waveform = waveform.mean(0, keepdim=True)
|
waveform = waveform.mean(0, keepdim=True)
|
||||||
waveform = waveform.squeeze(0).float()
|
waveform = waveform.squeeze(0).float()
|
||||||
if sr != target_sr:
|
if sr != target_sr:
|
||||||
|
if not getattr(_load_audio, "_sr_warned", False):
|
||||||
|
print(f"[LoRA Trainer] WARNING: audio sr={sr} ≠ target {target_sr}, resampling on-the-fly. "
|
||||||
|
f"Pre-resample with SelVA Dataset Resampler for faster loading.", flush=True)
|
||||||
|
_load_audio._sr_warned = True
|
||||||
waveform = torchaudio.functional.resample(
|
waveform = torchaudio.functional.resample(
|
||||||
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
|
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
|
||||||
target_len = int(duration * target_sr)
|
target_len = int(duration * target_sr)
|
||||||
@@ -557,6 +561,17 @@ class SelvaLoraTrainer:
|
|||||||
"cosine: decay from lr to ~0 following a cosine curve — "
|
"cosine: decay from lr to ~0 following a cosine curve — "
|
||||||
"prevents oscillation when LR is slightly too high.",
|
"prevents oscillation when LR is slightly too high.",
|
||||||
}),
|
}),
|
||||||
|
"latent_mixup_alpha": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||||
|
"tooltip": "Beta distribution alpha for latent mixup (MusicLDM, arXiv:2308.01546). "
|
||||||
|
"0 = disabled. 0.4 recommended. Mixes pairs of training latents "
|
||||||
|
"to reduce memorization on small datasets.",
|
||||||
|
}),
|
||||||
|
"latent_noise_sigma": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 0.1, "step": 0.005,
|
||||||
|
"tooltip": "Additive Gaussian noise on training latents, scaled by x1.std(). "
|
||||||
|
"0 = disabled. 0.01–0.03 adds mild regularization against overfitting.",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -581,7 +596,8 @@ class SelvaLoraTrainer:
|
|||||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0,
|
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||||
init_mode="pissa", use_rslora=True, lr_schedule="constant"):
|
init_mode="pissa", use_rslora=True, lr_schedule="constant",
|
||||||
|
latent_mixup_alpha=0.0, latent_noise_sigma=0.0):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -633,6 +649,7 @@ class SelvaLoraTrainer:
|
|||||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||||
lora_dropout, lora_plus_ratio, lr_schedule,
|
lora_dropout, lora_plus_ratio, lr_schedule,
|
||||||
init_mode, use_rslora,
|
init_mode, use_rslora,
|
||||||
|
latent_mixup_alpha, latent_noise_sigma,
|
||||||
)
|
)
|
||||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||||
|
|
||||||
@@ -645,6 +662,7 @@ class SelvaLoraTrainer:
|
|||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||||||
init_mode="pissa", use_rslora=True,
|
init_mode="pissa", use_rslora=True,
|
||||||
|
latent_mixup_alpha=0.0, latent_noise_sigma=0.0,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
@@ -748,6 +766,8 @@ class SelvaLoraTrainer:
|
|||||||
"lr_schedule": lr_schedule,
|
"lr_schedule": lr_schedule,
|
||||||
"init_mode": init_mode,
|
"init_mode": init_mode,
|
||||||
"use_rslora": use_rslora,
|
"use_rslora": use_rslora,
|
||||||
|
"latent_mixup_alpha": latent_mixup_alpha,
|
||||||
|
"latent_noise_sigma": latent_noise_sigma,
|
||||||
}
|
}
|
||||||
|
|
||||||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||||
@@ -775,6 +795,18 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
x1 = generator.normalize(x1)
|
x1 = generator.normalize(x1)
|
||||||
|
|
||||||
|
# Latent mixup (MusicLDM, arXiv:2308.01546)
|
||||||
|
if latent_mixup_alpha > 0 and x1.shape[0] > 1:
|
||||||
|
lam = torch.distributions.Beta(
|
||||||
|
latent_mixup_alpha, latent_mixup_alpha
|
||||||
|
).sample().to(device)
|
||||||
|
idx = torch.randperm(x1.shape[0], device=device)
|
||||||
|
x1 = lam * x1 + (1 - lam) * x1[idx]
|
||||||
|
|
||||||
|
# Latent noise injection
|
||||||
|
if latent_noise_sigma > 0:
|
||||||
|
x1 = x1 + torch.randn_like(x1) * latent_noise_sigma * x1.std()
|
||||||
|
|
||||||
if timestep_mode == "logit_normal" or (
|
if timestep_mode == "logit_normal" or (
|
||||||
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user