fix: pad/trim clip and sync features to fixed seq_len at dataset load time

Clips from shorter videos produce fewer CLIP frames (e.g. 2s → 16 frames,
8s → 64 frames). Mixed-length datasets would cause torch.stack() to fail
during batching. Normalize to seq_cfg.clip_seq_len / sync_seq_len at load,
same as latents are already normalized to latent_seq_len.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 00:51:45 +02:00
parent 8ae0ba3c7d
commit 8338560600
2 changed files with 35 additions and 2 deletions
+17 -1
View File
@@ -370,7 +370,23 @@ class SelvaLoraTrainer:
# Text → CLIP features (reuse already-loaded CLIP from inference model) # Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu() text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) # Pad/trim clip and sync features to fixed seq lengths — clips from
# shorter videos have fewer frames and would cause stack() to fail
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e: except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True) print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc() traceback.print_exc()
+18 -1
View File
@@ -284,7 +284,24 @@ def main():
elif x1.shape[1] > tgt: elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :] x1 = x1[:, :tgt, :]
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu() text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
# have fewer frames and would cause stack() to fail during batching
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e: except Exception as e:
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}") print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")