fix: pad/trim clip and sync features to fixed seq_len at dataset load time
Clips from shorter videos produce fewer CLIP frames (e.g. 2s → 16 frames, 8s → 64 frames), so mixed-length datasets would cause torch.stack() to fail during batching. Normalize clip and sync features to seq_cfg.clip_seq_len / seq_cfg.sync_seq_len at load time, the same way latents are already normalized to latent_seq_len.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -383,7 +383,23 @@ class SelvaLoraTrainer:
|
|||||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||||
|
|
||||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
# Pad/trim clip and sync features to fixed seq lengths — clips from
|
||||||
|
# shorter videos have fewer frames and would cause stack() to fail
|
||||||
|
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
||||||
|
c_tgt = seq_cfg.clip_seq_len
|
||||||
|
if clip_f.shape[1] < c_tgt:
|
||||||
|
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||||
|
elif clip_f.shape[1] > c_tgt:
|
||||||
|
clip_f = clip_f[:, :c_tgt, :]
|
||||||
|
|
||||||
|
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
||||||
|
s_tgt = seq_cfg.sync_seq_len
|
||||||
|
if sync_f.shape[1] < s_tgt:
|
||||||
|
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||||
|
elif sync_f.shape[1] > s_tgt:
|
||||||
|
sync_f = sync_f[:, :s_tgt, :]
|
||||||
|
|
||||||
|
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
+18
-1
@@ -288,7 +288,24 @@ def main():
|
|||||||
elif x1.shape[1] > tgt:
|
elif x1.shape[1] > tgt:
|
||||||
x1 = x1[:, :tgt, :]
|
x1 = x1[:, :tgt, :]
|
||||||
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
||||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
|
||||||
|
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
|
||||||
|
# have fewer frames and would cause stack() to fail during batching
|
||||||
|
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
||||||
|
c_tgt = seq_cfg.clip_seq_len
|
||||||
|
if clip_f.shape[1] < c_tgt:
|
||||||
|
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||||
|
elif clip_f.shape[1] > c_tgt:
|
||||||
|
clip_f = clip_f[:, :c_tgt, :]
|
||||||
|
|
||||||
|
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
||||||
|
s_tgt = seq_cfg.sync_seq_len
|
||||||
|
if sync_f.shape[1] < s_tgt:
|
||||||
|
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||||
|
elif sync_f.shape[1] > s_tgt:
|
||||||
|
sync_f = sync_f[:, :s_tgt, :]
|
||||||
|
|
||||||
|
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
|
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user