feat(ti-trainer): add loss curve IMAGE output
Reuses _draw_loss_curve + _smooth_losses + _pil_to_tensor from the LoRA trainer — raw loss in light blue, smoothed overlay in blue, matches the LoRA trainer's visual style. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,9 @@ from .selva_lora_trainer import (
|
|||||||
_prepare_dataset,
|
_prepare_dataset,
|
||||||
_spectral_metrics,
|
_spectral_metrics,
|
||||||
_save_spectrogram,
|
_save_spectrogram,
|
||||||
|
_smooth_losses,
|
||||||
|
_draw_loss_curve,
|
||||||
|
_pil_to_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -113,9 +116,12 @@ class SelvaTextualInversionTrainer:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
CATEGORY = SELVA_CATEGORY
|
CATEGORY = SELVA_CATEGORY
|
||||||
FUNCTION = "train"
|
FUNCTION = "train"
|
||||||
RETURN_TYPES = ("STRING",)
|
RETURN_TYPES = ("STRING", "IMAGE")
|
||||||
RETURN_NAMES = ("embeddings_path",)
|
RETURN_NAMES = ("embeddings_path", "loss_curve")
|
||||||
OUTPUT_TOOLTIPS = ("Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",)
|
OUTPUT_TOOLTIPS = (
|
||||||
|
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
|
||||||
|
"Smoothed training loss curve.",
|
||||||
|
)
|
||||||
DESCRIPTION = (
|
DESCRIPTION = (
|
||||||
"Trains K learnable CLIP token embeddings against your audio dataset "
|
"Trains K learnable CLIP token embeddings against your audio dataset "
|
||||||
"with all model weights frozen. The tokens are then injected into the "
|
"with all model weights frozen. The tokens are then injected into the "
|
||||||
@@ -208,7 +214,9 @@ class SelvaTextualInversionTrainer:
|
|||||||
n_tokens, steps, lr, batch_size,
|
n_tokens, steps, lr, batch_size,
|
||||||
warmup_steps, seed, save_every, init_text,
|
warmup_steps, seed, save_every, init_text,
|
||||||
)
|
)
|
||||||
return (r["embeddings_path"],)
|
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
|
||||||
|
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
|
||||||
|
return (r["embeddings_path"], _pil_to_tensor(curve_img))
|
||||||
|
|
||||||
def _train_inner(
|
def _train_inner(
|
||||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
|||||||
Reference in New Issue
Block a user