diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index dddbf0f..64bc525 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -370,7 +370,23 @@ class SelvaLoraTrainer: # Text → CLIP features (reuse already-loaded CLIP from inference model) text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu() - dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) + # Pad/trim clip and sync features to fixed seq lengths — clips from + # shorter videos have fewer frames and would cause stack() to fail + clip_f = bundle["clip_features"] # [1, N_clip, 1024] + c_tgt = seq_cfg.clip_seq_len + if clip_f.shape[1] < c_tgt: + clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1])) + elif clip_f.shape[1] > c_tgt: + clip_f = clip_f[:, :c_tgt, :] + + sync_f = bundle["sync_features"] # [1, N_sync, 768] + s_tgt = seq_cfg.sync_seq_len + if sync_f.shape[1] < s_tgt: + sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1])) + elif sync_f.shape[1] > s_tgt: + sync_f = sync_f[:, :s_tgt, :] + + dataset.append((x1, clip_f, sync_f, text_clip)) except Exception as e: print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True) traceback.print_exc() diff --git a/train_lora.py b/train_lora.py index 4a5595e..3f1642a 100644 --- a/train_lora.py +++ b/train_lora.py @@ -284,7 +284,24 @@ def main(): elif x1.shape[1] > tgt: x1 = x1[:, :tgt, :] text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu() - dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) + + # Pad/trim clip and sync features to fixed seq lengths — shorter clips + # have fewer frames and would cause stack() to fail during batching + clip_f = bundle["clip_features"] # [1, N_clip, 1024] + c_tgt = seq_cfg.clip_seq_len + if clip_f.shape[1] < c_tgt: + clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1])) + elif clip_f.shape[1] > c_tgt: + clip_f = clip_f[:, :c_tgt, :] + + sync_f = bundle["sync_features"] # [1, N_sync, 768] + s_tgt = seq_cfg.sync_seq_len + if sync_f.shape[1] < s_tgt: + sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1])) + elif sync_f.shape[1] > s_tgt: + sync_f = sync_f[:, :s_tgt, :] + + dataset.append((x1, clip_f, sync_f, text_clip)) except Exception as e: print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")