2 Commits

Author SHA1 Message Date
Ethanfel c9550ce693 experiment: add crop_rect option — rect bbox crop without squarification
crop_rect crops frames to a rectangle around the mask bounding box
(with margin applied independently on each axis) before resizing.
The model still stretches the result to 384×384 / 224×224, but only
sees the region around the target element — simpler than crop_to_mask.

- Refactored _compute_mask_bbox with square= param (True = existing
  square logic, False = rect with per-axis margin)
- Empty mask fallback for rect mode returns the full frame (no-op)
- crop_to_mask takes priority over crop_rect when both are enabled
- crop_margin is shared between both crop modes
- Hash includes crop_rect; crop_margin hashed when either crop is active
- Log line reports mode (square/rect) alongside bbox dimensions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 13:04:46 +02:00
Ethanfel f3cabcad90 experiment: crop-to-mask mode on feature extractor
Instead of squishing the full frame to a square, optionally crops a square
region around the mask bounding box (union across all frames) before resizing.
Preserves aspect ratio of the subject and gives the model a focused,
undistorted view.

New optional inputs on SelVA Feature Extractor:
- crop_to_mask (BOOLEAN, default false) — enable the crop path
- crop_margin (FLOAT 0–1, default 0.1) — padding around the bbox as a
  fraction of the bounding box side

_compute_mask_bbox: resizes mask to frame resolution, takes union over
all mask frames, expands to square + margin, shifts into frame bounds to
preserve square shape before clamping. Falls back to center square crop
if mask is empty.

Bbox is computed once from the original-resolution mask and reused for
both CLIP (384px) and sync (224px) frame sequences. Combine with
mask_clip/mask_sync for full background suppression on top of the crop.
Cache hash includes crop_to_mask and crop_margin when mask is connected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 12:52:03 +02:00
7 changed files with 106 additions and 1683 deletions
-392
View File
@@ -1,392 +0,0 @@
# LoRA Training for SelVA
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
---
## Overview
Training is split into two steps:
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 34 GB of VRAM.
---
## Requirements
Same environment as SelVA inference. Additional Python packages:
```
torchaudio
soundfile
```
---
## Step 1 — Prepare the dataset
### 1.1 Extract visual features in ComfyUI
For each video clip you want to train on:
1. Load the video with a VHS LoadVideo node.
2. Connect it to **SelVA Feature Extractor**.
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
### Prompt guide
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
**Good prompts are specific about:**
- The sound source (what object is making the sound)
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
- The action producing the sound (if applicable)
| Sound | Weak prompt | Strong prompt |
|---|---|---|
| Dog bark | `dog` | `a large dog barking loudly` |
| Footsteps | `walking` | `heavy boots on a wooden floor` |
| Water | `water` | `water dripping into a metal bucket` |
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
| Door | `door` | `a heavy wooden door slamming shut` |
**Rules of thumb:**
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
### Masking note
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
### 1.2 Collect clean audio
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
```
dataset/my_sound/
dog_bark_001.npz ← from SelVA Feature Extractor
dog_bark_001.wav ← clean isolated audio recording
dog_bark_002.npz
dog_bark_002.wav
dog_bark_003.npz
dog_bark_003.wav
```
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
### 1.3 Optional: prompts.txt
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
```
# One line per file: filename: prompt text
dog_bark.npz: a large dog barking aggressively
dog_bark_001.npz: a dog barking in the distance
```
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
---
## Step 2 — Train
### Option A — SelVA LoRA Trainer node (ComfyUI)
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
```
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
```
### Option B — Command line
```bash
python train_lora.py \
--data_dir dataset/my_sound \
--output_dir lora_output/my_sound \
--variant large_44k \
--selva_dir /path/to/ComfyUI/models/selva \
--rank 16 \
--steps 4000 \
--batch_size 4 \
--lr 1e-4
```
The script will:
1. Load the VAE, CLIP text encoder, and generator.
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
3. Train LoRA adapters for the specified number of steps.
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
---
## CLI Reference
| Argument | Default | Description |
|---|---|---|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
| `--selva_dir` | required | Path to SelVA model weights directory |
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
| `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps |
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
| `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
| `--seed` | `42` | Random seed |
---
## Step 3 — Load the adapter in ComfyUI
Connect **SelVA LoRA Loader** between the model loader and the sampler:
```
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
```
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
| Input | Description |
|---|---|
| `model` | SELVA_MODEL from the model loader |
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
---
## Tuning Guide
### Clip length
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
**Practical recommendations:**
| Sound type | Clip strategy |
|---|---|
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 23 times per clip |
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
| Sound with a clear onset (explosion, splash) | Put the onset at ~12 s from the start, not at 0 s — gives the model context |
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
### How many clips do I need?
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
| Dataset size | Scenario | Expected result |
|---|---|---|
| **510 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
| **1530 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
| **3060 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
| **60150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
| **150300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
### Batch size
| Batch size | VRAM (large_44k) | Use case |
|---|---|---|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
| `8` | ~15 GB | Better convergence on larger datasets |
| `16` | ~20 GB | Best gradient quality when VRAM allows |
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
### Rank
| Rank | Use case |
|---|---|
| `8` | Fine details on a sound the model already knows well |
| `16` | Default — good balance of capacity and VRAM |
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
Higher rank increases VRAM usage and overfitting risk on small datasets.
### Steps
With `batch_size=4` as the default, these are rough guidelines:
| Dataset size | Recommended steps |
|---|---|
| 1020 clips | 20004000 |
| 2050 clips | 40008000 |
| 50+ clips | 600015000 |
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
### Learning rate
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
### Target layers
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
```bash
--target attn.qkv linear1
```
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
### Adapter strength at inference
| Strength | Effect |
|---|---|
| `0.50.7` | Conservative — blends adapter with base model, less noise |
| `1.0` | Full adapter strength (default) |
| `>1.0` | Exaggerated effect, may introduce artifacts |
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.60.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
### Loss interpretation
A typical loss curve:
- Starts around `0.81.0`
- Should reach `0.550.65` after convergence on a clean sound class with 1030 clips
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
- Below `0.1` on a small dataset means overfitting
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
### Precision
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
---
## Output files
```
lora_output/my_sound/
adapter_step00500.pt ← step checkpoint (includes optimizer state for resume)
adapter_step01000.pt
...
adapter_final.pt ← final adapter with embedded metadata (inference only)
meta.json ← human-readable metadata
sample_step00500.wav ← quick eval sample at each checkpoint
loss_raw.png ← raw loss curve
loss_smoothed.png ← EMA-smoothed loss curve
```
`adapter_final.pt` format:
```python
{
"state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
"meta": {
"variant": "large_44k",
"rank": 16,
"alpha": 16.0,
"target": ["attn.qkv"],
"steps": 2000
}
}
```
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
---
## Troubleshooting
**`No layers matched target=...`**
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
**`No .npz files found in ...`**
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
**`No audio file found for clip.npz`**
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
**The sound is audible but there is white noise on top**
Lower the adapter strength to `0.60.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
**LoRA appears to have no effect**
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
**Loss does not decrease**
- Increase `batch_size` for more stable gradients.
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
- Check that the audio files are clean and actually contain the target sound.
- Check that the `.npz` features were extracted with a relevant prompt.
**Loss explodes or NaN**
- Lower the learning rate (`5e-5`).
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
**Loss plateaus early (above 0.7)**
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
---
## Observations (work in progress)
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
### Precision and batch size
| Config | Smoothed loss at step 2000 | Notes |
|---|---|---|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 60008000 at ~0.59 |
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
| fp32 batch 32 | ~0.58 | Matches bf16 batch 16 at step 6000 already at step 2000 |
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
### logit_normal vs uniform
logit_normal consistently reaches a lower loss floor than uniform. However perceptual improvement is dataset-dependent — on 10 clips the difference is marginal. May be more impactful with larger datasets. No conclusion yet.
### White noise
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
- Too few clips for the model to confidently predict the target sound
- Imprecise extraction prompts producing unfocused sync features
- Missing mask when multiple objects are in frame
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.03.5 or adapter strength to 0.60.7 helps at inference.
-2
View File
@@ -5,8 +5,6 @@ _NODES = {
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
}
for key, (module_path, class_name, display_name) in _NODES.items():
+101 -25
View File
@@ -35,6 +35,66 @@ def _resize_frames(frames, size):
return x.clamp(0.0, 1.0) # [N, C, H, W]
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
"""
Compute a bounding box around the union of all mask frames.
mask: [M, H', W'] float [0,1]
square: if True, expand bbox to a square and shift into frame bounds;
if False, apply margin independently on each axis (rect crop).
Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
"""
if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
m = F.interpolate(
mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
).squeeze(1)
else:
m = mask.float()
union = (m > 0.5).max(dim=0).values # [H, W] bool
if not union.any():
if square:
# Empty mask — center square crop
side = min(frame_h, frame_w)
cy, cx = frame_h // 2, frame_w // 2
y0 = max(0, cy - side // 2)
x0 = max(0, cx - side // 2)
return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
else:
# Empty mask — return full frame (no meaningful rect to crop to)
return 0, 0, frame_h, frame_w
ys = union.any(dim=1).nonzero(as_tuple=True)[0]
xs = union.any(dim=0).nonzero(as_tuple=True)[0]
y0, y1 = int(ys[0]), int(ys[-1]) + 1
x0, x1 = int(xs[0]), int(xs[-1]) + 1
if square:
side = max(y1 - y0, x1 - x0)
pad = int(side * margin)
side += 2 * pad
cy = (y0 + y1) // 2
cx = (x0 + x1) // 2
y0n = cy - side // 2
x0n = cx - side // 2
y1n = y0n + side
x1n = x0n + side
# Shift into frame bounds to preserve square shape
if y0n < 0: y1n -= y0n; y0n = 0
if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
if x0n < 0: x1n -= x0n; x0n = 0
if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w
return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
else:
pad_y = int(max(1, y1 - y0) * margin)
pad_x = int(max(1, x1 - x0) * margin)
return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
"""
Apply a ComfyUI MASK to resized frames.
@@ -68,20 +128,9 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
return frames * alpha + 0.5 * (1.0 - alpha)
def _resolve_named_path(cache_dir: str, name: str) -> str:
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
# Sanitize: replace path separators so the name stays inside cache_dir
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
i = 1
while True:
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
if not os.path.exists(p):
return p
i += 1
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
mask_strength=1.0, mask_clip=True, mask_sync=True,
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
h = hashlib.sha256()
raw = video_tensor.cpu().numpy().tobytes()
n = len(raw)
@@ -99,6 +148,10 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
h.update(str(round(mask_strength, 4)).encode())
h.update(str(mask_clip).encode())
h.update(str(mask_sync).encode())
h.update(str(crop_to_mask).encode())
h.update(str(crop_rect).encode())
if crop_to_mask or crop_rect:
h.update(str(round(crop_margin, 4)).encode())
h.update(prompt.encode())
h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
@@ -128,8 +181,6 @@ class SelvaFeatureExtractor:
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
"cache_dir": ("STRING", {"default": "",
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
"name": ("STRING", {"default": "",
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
"mask": ("MASK", {
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
}),
@@ -145,6 +196,18 @@ class SelvaFeatureExtractor:
"default": True,
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
}),
"crop_to_mask": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
}),
"crop_rect": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
}),
"crop_margin": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
}),
},
}
@@ -155,14 +218,14 @@ class SelvaFeatureExtractor:
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
)
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
FUNCTION = "extract_features"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir="", name="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
duration=0.0, cache_dir="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True,
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
if video_info is not None:
fps = video_info["loaded_fps"]
@@ -178,15 +241,11 @@ class SelvaFeatureExtractor:
if not cache_dir:
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
os.makedirs(cache_dir, exist_ok=True)
if name.strip():
# Named mode: always extract and save to an incremented filename
cached_path = _resolve_named_path(cache_dir, name.strip())
else:
# Hash mode: skip extraction if identical inputs were already processed
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path):
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
cached = _load_cached(cached_path)
@@ -206,10 +265,24 @@ class SelvaFeatureExtractor:
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3)
# Pre-compute crop bbox once from the original-resolution mask
crop_bbox = None
if mask is not None and (crop_to_mask or crop_rect):
H_vid, W_vid = video.shape[1], video.shape[2]
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
cy0, cx0, cy1, cx1 = crop_bbox
_mode = "square" if _square else "rect"
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
try:
with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None and mask_clip:
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
@@ -222,6 +295,9 @@ class SelvaFeatureExtractor:
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None and mask_sync:
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
-102
View File
@@ -1,102 +0,0 @@
import copy
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from selva_core.model.lora import apply_lora, load_lora
class SelvaLoraLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"adapter_path": ("STRING", {
"default": "",
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
}),
"strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to all LoRA contributions. "
"1.0 = full adapter strength. "
"0.0 = effectively disables the adapter. "
"Values above 1.0 exaggerate the effect.",
}),
},
}
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
"The base model is not modified — a shallow copy of the model bundle is returned."
)
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
if not adapter_path.strip():
raise ValueError("[SelVA LoRA] adapter_path is empty.")
# Resolve path: allow absolute or relative to ComfyUI base
from pathlib import Path
p = Path(adapter_path)
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
# Support both raw state_dict and {state_dict, meta} formats
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
meta = checkpoint.get("meta", {})
else:
state_dict = checkpoint
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
flush=True)
# Shallow-copy the model bundle so the original generator is not mutated
patched = {**model}
generator = copy.deepcopy(model["generator"])
n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
if n == 0:
raise RuntimeError(
f"[SelVA LoRA] No layers matched target={target}. "
"Check that the adapter was trained with the same target suffixes."
)
load_lora(generator, state_dict)
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
norms = [p.norm().item() for name, p in generator.named_parameters()
if "lora_A" in name]
if norms:
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
else:
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
# Apply strength scaling: multiply all lora_B params by strength
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
if strength != 1.0:
with torch.no_grad():
for name, param in generator.named_parameters():
if "lora_B" in name:
param.mul_(strength)
generator.to(model["generator"].parameters().__next__().device)
patched["generator"] = generator
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
return (patched,)
-617
View File
@@ -1,617 +0,0 @@
import copy
import json
import random
import traceback
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
# ---------------------------------------------------------------------------
# Data helpers (mirror train_lora.py)
# ---------------------------------------------------------------------------
def _load_prompts(data_dir: Path) -> dict:
p = data_dir / "prompts.txt"
if not p.exists():
return {}
mapping = {}
for line in p.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if ":" in line:
fname, prompt = line.split(":", 1)
mapping[fname.strip()] = prompt.strip()
return mapping
def _find_audio(npz_path: Path) -> Path | None:
for ext in _AUDIO_EXTS:
c = npz_path.with_suffix(ext)
if c.exists():
return c
return None
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
try:
waveform, sr = torchaudio.load(str(path))
except RuntimeError as e:
if "torchcodec" not in str(e).lower() and "libtorchcodec" not in str(e).lower():
raise
# torchcodec unavailable (FFmpeg shared libs missing) — fall back to soundfile
import soundfile as sf
data, sr = sf.read(str(path), always_2d=True) # [frames, channels]
waveform = torch.from_numpy(data.T).float() # [channels, frames]
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def _load_npz(path: Path) -> dict:
data = np.load(str(path), allow_pickle=False)
bundle = {
"clip_features": torch.from_numpy(data["clip_features"]),
"sync_features": torch.from_numpy(data["sync_features"]),
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
return bundle
# ---------------------------------------------------------------------------
# Eval sample
# ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
num_steps: int = 8):
"""Run a quick no-CFG inference pass on a random training clip.
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
Uses fewer ODE steps than inference (8 vs 25) for speed.
"""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_clip,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
orig_device = next(feature_utils_orig.parameters()).device
if orig_device != device:
feature_utils_orig.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
if orig_device != device:
feature_utils_orig.to(orig_device)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1)
return audio.squeeze(0), seq_cfg.sampling_rate # [1, L]
except Exception as e:
print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Loss curve rendering
# ---------------------------------------------------------------------------
def _smooth_losses(losses: list[float], beta: float = 0.9) -> list[float]:
"""Exponential moving average smoothing."""
smoothed, ema = [], None
for v in losses:
ema = v if ema is None else beta * ema + (1 - beta) * v
smoothed.append(ema)
return smoothed
def _draw_loss_curve(losses: list[float], log_interval: int,
start_step: int = 0, smoothed: list[float] | None = None) -> Image.Image:
"""Render a loss curve as a PIL Image."""
W, H = 800, 380
pl, pr, pt, pb = 70, 20, 25, 45
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
if len(losses) >= 2:
lo, hi = min(losses), max(losses)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))
# Raw loss line
n = len(losses)
pts = []
for i, v in enumerate(losses):
x = pl + int(i * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=(200, 220, 255), width=1)
# Smoothed overlay
if smoothed is not None and len(smoothed) >= 2:
spts = []
for i, v in enumerate(smoothed):
x = pl + int(i * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
spts.append((x, y))
draw.line(spts, fill=(66, 133, 244), width=2)
# x-axis step labels — account for start_step so resumed runs are correct
first_step = start_step + log_interval
last_step = start_step + n * log_interval
for i in range(5):
x = pl + int(i * pw / 4)
step = int(first_step + i * (last_step - first_step) / 4)
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
return img
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
arr = np.array(img).astype(np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaLoraTrainer:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files and paired audio files.",
}),
"output_dir": ("STRING", {
"default": "lora_output",
"tooltip": "Where to save adapter checkpoints.",
}),
"steps": ("INT", {
"default": 2000, "min": 100, "max": 100000,
"tooltip": "Total training steps.",
}),
"rank": ("INT", {
"default": 16, "min": 1, "max": 128,
"tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
}),
"lr": ("FLOAT", {
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
"tooltip": "Learning rate.",
}),
},
"optional": {
"alpha": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
"tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
}),
"target": ("STRING", {
"default": "attn.qkv",
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
}),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"resume_path": ("STRING", {
"default": "",
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
}),
"seed": ("INT", {"default": 42}),
},
}
RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
RETURN_NAMES = ("model", "adapter_path", "loss_curve")
OUTPUT_TOOLTIPS = (
"Model with trained LoRA adapter applied — connect directly to Sampler.",
"Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
"Training loss curve.",
)
FUNCTION = "train"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
"Blocks the queue for the duration of training. "
"Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
"and pair each .npz with a clean audio file of the same stem."
)
def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=1, save_every=500, resume_path="", seed=42):
torch.manual_seed(seed)
random.seed(seed)
device = get_device()
dtype = model["dtype"]
variant = model["variant"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
data_dir = Path(data_dir.strip())
_out_str = output_dir.strip()
_out_p = Path(_out_str)
# On Windows a Unix-style path like "/lora_output" is technically absolute
# (drive-relative) but the user almost certainly meant a subfolder of the
# ComfyUI output directory. Treat any non-absolute path AND any path whose
# only "absolute" anchor is a leading slash (no drive letter) as relative to
# the ComfyUI output folder.
import sys as _sys
_unix_style_on_windows = (
_sys.platform == "win32"
and _out_p.is_absolute()
and not _out_p.drive # e.g. Path("/foo").drive == "" on Windows
)
if not _out_p.is_absolute() or _unix_style_on_windows:
_out_p = Path(folder_paths.get_output_directory()) / _out_p.relative_to(_out_p.anchor)
print(f"[LoRA Trainer] output_dir resolved to: {_out_p}", flush=True)
output_dir = _out_p
output_dir.mkdir(parents=True, exist_ok=True)
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
target_suffixes = tuple(target.strip().split())
# --- Load VAE encoder (not present in inference model) ---
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device).eval()
# --- Pre-load dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
prompt_map = _load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
pbar_load = comfy.utils.ProgressBar(len(npz_files))
dataset = []
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
pbar_load.update(1)
continue
bundle = _load_npz(npz_path)
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
x1 = dist.mode().clone().transpose(1, 2).cpu()
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
# Pad/trim clip and sync features to fixed seq lengths — clips from
# shorter videos have fewer frames and would cause stack() to fail
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1)
# VAE no longer needed — free memory
del vae_utils
soft_empty_cache()
if not dataset:
raise ValueError("[LoRA Trainer] No clips could be loaded.")
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
# can't participate in autograd even with enable_grad — disable inference
# mode entirely so deepcopy, apply_lora, and the training loop all run
# with a clean autograd context.
with torch.inference_mode(False), torch.enable_grad():
return self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
)
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
):
# --- Prepare generator copy with LoRA ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
target_suffixes=target_suffixes)
if n_lora == 0:
raise RuntimeError(
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
"Check the 'target' field."
)
print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)
for name, p in generator.named_parameters():
p.requires_grad_("lora_" in name)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Optimizer + scheduler ---
lora_params = [p for p in generator.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)
def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Resume ---
start_step = 0
if resume_path.strip():
ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
if "step" not in ckpt:
raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
start_step = ckpt["step"]
if start_step >= steps:
raise ValueError(
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
)
load_lora(generator, ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()
log_interval = 50
remaining = steps - start_step
pbar_train = comfy.utils.ProgressBar(remaining)
loss_history = []
running_loss = 0.0
meta = {
"variant": variant,
"rank": rank,
"alpha": alpha_val,
"target": list(target_suffixes),
"steps": steps,
}
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps}, batch_size={batch_size})\n", flush=True)
last_step = start_step
completed = False
try:
for step in range(start_step + 1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
loss.backward()
running_loss += loss.item() * grad_accum
if step % grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % log_interval == 0:
avg = running_loss / log_interval
loss_history.append(avg)
lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True)
running_loss = 0.0
# Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=_smooth_losses(loss_history))
pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 800)
)
if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
"state_dict": get_lora_state_dict(generator),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"meta": meta,
}, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample next to the checkpoint
wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype)
if wav is not None:
wav_path = output_dir / f"sample_step{step:05d}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
last_step = step
pbar_train.update(1)
completed = True
finally:
# Save adapter and loss curves whether training completed or was cancelled.
# Skip if we never completed a single step (nothing useful to save).
if loss_history:
if completed:
# Normal completion — use adapter_final.pt (increment if exists)
final_path = output_dir / "adapter_final.pt"
if final_path.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final_path = output_dir / f"adapter_final_{i:03d}.pt"
label = "Done"
else:
# Cancelled — include the step number so the file is useful for resume
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
label = f"Cancelled at step {last_step}"
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
smoothed = _smooth_losses(loss_history)
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=smoothed)
raw_img.save(str(output_dir / "loss_raw.png"))
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
# Reached only on normal completion (exception re-raises past this point)
generator.eval()
generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator}
loss_curve = _pil_to_tensor(smoothed_img)
return (patched, str(final_path), loss_curve)
-118
View File
@@ -1,118 +0,0 @@
"""
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
Usage:
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
n = apply_lora(net_generator, rank=16, alpha=16.0)
print(f"Wrapped {n} linear layers with LoRA")
# ... train only LoRA params ...
torch.save(get_lora_state_dict(net_generator), "adapter.pt")
# Later, at inference:
apply_lora(net_generator, rank=16, alpha=16.0)
load_lora(net_generator, torch.load("adapter.pt"))
"""
import math
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)
A is initialised with Kaiming uniform; B is initialised to zero so the
adapter contribution starts at zero and does not disturb pretrained behaviour.
"""
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
super().__init__()
in_f = linear.in_features
out_f = linear.out_features
self.linear = linear
linear.weight.requires_grad_(False)
if linear.bias is not None:
linear.bias.requires_grad_(False)
ref_dtype = linear.weight.dtype
ref_device = linear.weight.device
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
self.scale = alpha / rank
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
def extra_repr(self) -> str:
rank = self.lora_A.shape[0]
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
f"rank={rank}, scale={self.scale:.4f}")
def apply_lora(
model: nn.Module,
rank: int = 16,
alpha: float = None,
target_suffixes: tuple = ("attn.qkv",),
) -> int:
"""Replace matching nn.Linear layers with LoRALinear in-place.
Args:
model: The module to modify (typically net_generator).
rank: LoRA rank.
alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0).
target_suffixes: Tuple of module name suffixes to wrap. Default is
("attn.qkv",) which targets all SelfAttention QKV
projections in the MM-DiT generator.
Add "linear1" to also wrap post-attention output projections.
Returns:
Number of linear layers wrapped.
"""
if alpha is None:
alpha = float(rank)
count = 0
for name, module in list(model.named_modules()):
if not any(name.endswith(s) for s in target_suffixes):
continue
if not isinstance(module, nn.Linear):
continue
parts = name.split(".")
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], LoRALinear(module, rank, alpha))
count += 1
return count
def get_lora_state_dict(model: nn.Module) -> dict:
"""Return a state dict containing only LoRA parameters (lora_A and lora_B)."""
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
def load_lora(model: nn.Module, state_dict: dict) -> None:
"""Load LoRA weights into a model that has already had apply_lora() called.
Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model
parameters are not modified.
"""
missing, unexpected = model.load_state_dict(state_dict, strict=False)
bad = [k for k in unexpected if "lora_" not in k]
if bad:
print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}")
lora_missing = [k for k in missing if "lora_" in k]
if lora_missing:
print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
-422
View File
@@ -1,422 +0,0 @@
#!/usr/bin/env python3
"""
LoRA fine-tuning for SelVA / MMAudio generator.
Teaches the model new or partially-known sound classes from custom video+audio pairs.
Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model).
Data layout:
data/my_sound/
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage:
python train_lora.py \\
--data_dir data/my_sound \\
--output_dir lora_output \\
--variant large_44k \\
--selva_dir /path/to/ComfyUI/models/selva \\
--rank 16 --steps 2000 --lr 1e-4
"""
import argparse
import os
import sys
import random
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k"),
"small_44k": ("generator_small_44k_sup_5.pth", "44k"),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
"large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict:
"""Load filename → prompt overrides from prompts.txt."""
p = data_dir / "prompts.txt"
if not p.exists():
return {}
mapping = {}
for line in p.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if ":" in line:
fname, prompt = line.split(":", 1)
mapping[fname.strip()] = prompt.strip()
return mapping
def find_audio_for_npz(npz_path: Path) -> Path | None:
"""Find a paired audio file with the same stem as the .npz."""
for ext in _AUDIO_EXTS:
candidate = npz_path.with_suffix(ext)
if candidate.exists():
return candidate
return None
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
"""Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
waveform, sr = torchaudio.load(str(path))
# Stereo → mono
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
# Resample
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr
).squeeze(0)
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def load_npz(path: Path) -> dict:
"""Load a feature bundle produced by SelvaFeatureExtractor."""
data = np.load(str(path), allow_pickle=False)
bundle = {
"clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024]
"sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768]
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
if "variant" in data:
bundle["variant"] = str(data["variant"])
return bundle
# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
tokens = tokenizer(text).to(device)
with torch.inference_mode():
return clip_model.encode_text(tokens, normalize=True)
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
"""Encode a waveform to the generator's latent space via the VAE.
encode_audio is @inference_mode — .clone() is required before the autograd path.
"""
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
dist = feature_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
return dist.mode().clone().transpose(1, 2).cpu() # [1, seq_len, latent_dim]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
parser.add_argument("--output_dir", default="lora_output")
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)")
parser.add_argument("--target", nargs="+", default=["attn.qkv"],
help="Module name suffixes to wrap with LoRA. Also try 'linear1'.")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--steps", type=int, default=2000)
parser.add_argument("--warmup_steps",type=int, default=100)
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--save_every", type=int, default=500)
parser.add_argument("--resume", default=None,
help="Path to a step checkpoint (.pt) to resume training from.")
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
torch.manual_seed(args.seed)
random.seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
print("[LoRA] bf16 not supported on this GPU — falling back to fp16")
args.precision = "fp16"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)
selva_dir = Path(args.selva_dir)
output_dir.mkdir(parents=True, exist_ok=True)
gen_filename, mode = _VARIANTS[args.variant]
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
duration = seq_cfg.duration
sample_rate = seq_cfg.sampling_rate
# --- Weight paths ---
def w(name): return str(selva_dir / name)
def wext(name): return str(selva_dir / "ext" / name)
vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
gen_weight = w(gen_filename)
for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
if not Path(path).exists():
print(f"[LoRA] Missing weight: {path} ({label})")
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
sys.exit(1)
# --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
print("[LoRA] Loading CLIP text encoder...")
clip_model = create_model_from_pretrained(
'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
).to(device, dtype).eval()
clip_model = patch_clip(clip_model)
tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
# --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
print("[LoRA] Loading VAE encoder...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=vae_weight,
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device, dtype).eval()
# --- Load generator ---
print(f"[LoRA] Loading generator ({args.variant})...")
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
net_generator.load_weights(
torch.load(gen_weight, map_location="cpu", weights_only=False)
)
# --- Apply LoRA ---
n_lora = apply_lora(
net_generator,
rank=args.rank,
alpha=args.alpha,
target_suffixes=tuple(args.target),
)
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})")
if n_lora == 0:
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
sys.exit(1)
# Freeze everything except LoRA params
for name, p in net_generator.named_parameters():
p.requires_grad_("lora_" in name)
trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad)
total = sum(p.numel() for p in net_generator.parameters())
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
f"({100 * trainable / total:.2f}%)")
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
print(f"[LoRA] No .npz files found in {data_dir}")
sys.exit(1)
prompt_map = load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
dataset = []
for npz_path in npz_files:
audio_path = find_audio_for_npz(npz_path)
if audio_path is None:
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
continue
bundle = load_npz(npz_path)
# Prompt priority: prompts.txt override > embedded in .npz > directory name
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
try:
audio = load_audio(audio_path, sample_rate, duration)
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
# have fewer frames and would cause stack() to fail during batching
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
if not dataset:
print("[LoRA] No clips could be loaded.")
sys.exit(1)
print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler ---
lora_params = [p for p in net_generator.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2)
def lr_lambda(step):
if step < args.warmup_steps:
return step / max(1, args.warmup_steps)
return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Resume ---
start_step = 0
if args.resume:
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
if "step" not in ckpt:
print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
sys.exit(1)
start_step = ckpt["step"]
if start_step >= args.steps:
print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
sys.exit(0)
net_generator.load_state_dict(ckpt["state_dict"], strict=False)
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step}{args.steps})")
# --- Training loop ---
net_generator.train()
optimizer.zero_grad()
remaining = args.steps - start_step
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), "
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
total_loss = 0.0
for step in range(start_step + 1, args.steps + 1):
batch = random.choices(dataset, k=args.batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
net_generator.normalize(x1)
t = torch.rand(args.batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
loss.backward()
total_loss += loss.item() * args.grad_accum
if step % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 50 == 0:
avg = total_loss / 50
lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}")
total_loss = 0.0
if step % args.save_every == 0 or step == args.steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
"state_dict": get_lora_state_dict(net_generator),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"meta": {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
},
}, ckpt_path)
print(f"[LoRA] Saved {ckpt_path}")
# Save final adapter with embedded metadata
# Increment filename if a previous final already exists (resume case)
final = output_dir / "adapter_final.pt"
if final.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final = output_dir / f"adapter_final_{i:03d}.pt"
meta = {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
}
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA] Training complete. Adapter saved to {final}")
if __name__ == "__main__":
main()