fix: pad/trim clip and sync features to fixed seq_len at dataset load time

Clips from shorter videos produce fewer CLIP frames (e.g. 2s → 16 frames,
8s → 64 frames). Mixed-length datasets would cause torch.stack() to fail
during batching. Normalize to seq_cfg.clip_seq_len / sync_seq_len at load,
same as latents are already normalized to latent_seq_len.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 00:51:45 +02:00
parent a5014e49eb
commit d83632e754
2 changed files with 35 additions and 2 deletions
+18 -1
View File
@@ -288,7 +288,24 @@ def main():
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
# have fewer frames and would cause stack() to fail during batching
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")