feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not frozen base weights), 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610) - LoRA+: separate optimizer param groups for lora_A and lora_B with configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354) - Curriculum mode: logit_normal for first N% of steps then uniform, directly addresses early convergence + fine-detail degradation at boundaries (arXiv:2603.12517) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+54
-13
@@ -271,17 +271,33 @@ class SelvaLoraTrainer:
|
|||||||
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||||||
}),
|
}),
|
||||||
"seed": ("INT", {"default": 42}),
|
"seed": ("INT", {"default": 42}),
|
||||||
"timestep_mode": (["uniform", "logit_normal"], {
|
"timestep_mode": (["uniform", "logit_normal", "curriculum"], {
|
||||||
"default": "uniform",
|
"default": "uniform",
|
||||||
"tooltip": "How to sample training timesteps. "
|
"tooltip": "How to sample training timesteps. "
|
||||||
"uniform samples all timesteps equally (default, matches original MMAudio training). "
|
"uniform: all timesteps equally (matches original MMAudio). "
|
||||||
"logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.",
|
"logit_normal: concentrates near t=0.5. "
|
||||||
|
"curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).",
|
||||||
}),
|
}),
|
||||||
"logit_normal_sigma": ("FLOAT", {
|
"logit_normal_sigma": ("FLOAT", {
|
||||||
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
||||||
"tooltip": "Spread of the logit-normal distribution. "
|
"tooltip": "Spread of the logit-normal distribution. "
|
||||||
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
|
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
|
||||||
"Only used when timestep_mode=logit_normal.",
|
"Used with logit_normal and curriculum modes.",
|
||||||
|
}),
|
||||||
|
"curriculum_switch": ("FLOAT", {
|
||||||
|
"default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05,
|
||||||
|
"tooltip": "Fraction of steps to run logit_normal before switching to uniform. "
|
||||||
|
"0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.",
|
||||||
|
}),
|
||||||
|
"lora_dropout": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
||||||
|
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
||||||
|
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).",
|
||||||
|
}),
|
||||||
|
"lora_plus_ratio": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
||||||
|
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||||||
|
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -305,7 +321,8 @@ class SelvaLoraTrainer:
|
|||||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0):
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
|
lora_dropout=0.0, lora_plus_ratio=1.0):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -442,7 +459,8 @@ class SelvaLoraTrainer:
|
|||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode, logit_normal_sigma,
|
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||||
|
lora_dropout, lora_plus_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_inner(
|
def _train_inner(
|
||||||
@@ -451,19 +469,21 @@ class SelvaLoraTrainer:
|
|||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
|
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|
||||||
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||||||
target_suffixes=target_suffixes)
|
target_suffixes=target_suffixes, dropout=lora_dropout)
|
||||||
if n_lora == 0:
|
if n_lora == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||||||
"Check the 'target' field."
|
"Check the 'target' field."
|
||||||
)
|
)
|
||||||
print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)
|
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
||||||
|
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
|
||||||
|
|
||||||
for name, p in generator.named_parameters():
|
for name, p in generator.named_parameters():
|
||||||
p.requires_grad_("lora_" in name)
|
p.requires_grad_("lora_" in name)
|
||||||
@@ -475,8 +495,16 @@ class SelvaLoraTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# --- Optimizer + scheduler ---
|
# --- Optimizer + scheduler ---
|
||||||
lora_params = [p for p in generator.parameters() if p.requires_grad]
|
# LoRA+: split A and B into separate param groups with different LRs.
|
||||||
optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)
|
# ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+.
|
||||||
|
lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad]
|
||||||
|
lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad]
|
||||||
|
optimizer = torch.optim.AdamW([
|
||||||
|
{"params": lora_A_params, "lr": lr},
|
||||||
|
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
||||||
|
], weight_decay=1e-2)
|
||||||
|
if lora_plus_ratio != 1.0:
|
||||||
|
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||||||
|
|
||||||
def lr_lambda(s):
|
def lr_lambda(s):
|
||||||
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
||||||
@@ -518,8 +546,15 @@ class SelvaLoraTrainer:
|
|||||||
"steps": steps,
|
"steps": steps,
|
||||||
"timestep_mode": timestep_mode,
|
"timestep_mode": timestep_mode,
|
||||||
"logit_normal_sigma": logit_normal_sigma,
|
"logit_normal_sigma": logit_normal_sigma,
|
||||||
|
"curriculum_switch": curriculum_switch,
|
||||||
|
"lora_dropout": lora_dropout,
|
||||||
|
"lora_plus_ratio": lora_plus_ratio,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||||
|
curriculum_switch_step = start_step + int((steps - start_step) * curriculum_switch)
|
||||||
|
_curriculum_switched = False
|
||||||
|
|
||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||||
f"timestep_mode={timestep_mode})\n", flush=True)
|
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||||
@@ -538,11 +573,17 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
generator.normalize(x1)
|
generator.normalize(x1)
|
||||||
|
|
||||||
if timestep_mode == "logit_normal":
|
if timestep_mode == "logit_normal" or (
|
||||||
|
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||||
|
):
|
||||||
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
|
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
|
||||||
t = torch.sigmoid(u)
|
t = torch.sigmoid(u)
|
||||||
else:
|
else:
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
|
||||||
|
print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True)
|
||||||
|
_curriculum_switched = True
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
@@ -552,7 +593,7 @@ class SelvaLoraTrainer:
|
|||||||
running_loss += loss.item() * grad_accum
|
running_loss += loss.item() * grad_accum
|
||||||
|
|
||||||
if step % grad_accum == 0:
|
if step % grad_accum == 0:
|
||||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|||||||
@@ -25,13 +25,14 @@ import torch.nn as nn
|
|||||||
class LoRALinear(nn.Module):
|
class LoRALinear(nn.Module):
|
||||||
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
||||||
|
|
||||||
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)
|
Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank)
|
||||||
|
|
||||||
A is initialised with Kaiming uniform; B is initialised to zero so the
|
A is initialised with Kaiming uniform; B is initialised to zero so the
|
||||||
adapter contribution starts at zero and does not disturb pretrained behaviour.
|
adapter contribution starts at zero and does not disturb pretrained behaviour.
|
||||||
|
Dropout is applied only to the LoRA path, not the base linear.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
in_f = linear.in_features
|
in_f = linear.in_features
|
||||||
out_f = linear.out_features
|
out_f = linear.out_features
|
||||||
@@ -46,16 +47,18 @@ class LoRALinear(nn.Module):
|
|||||||
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
|
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
|
||||||
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
|
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
|
||||||
self.scale = alpha / rank
|
self.scale = alpha / rank
|
||||||
|
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
|
return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
rank = self.lora_A.shape[0]
|
rank = self.lora_A.shape[0]
|
||||||
|
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
|
||||||
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
|
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
|
||||||
f"rank={rank}, scale={self.scale:.4f}")
|
f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
|
||||||
|
|
||||||
|
|
||||||
def apply_lora(
|
def apply_lora(
|
||||||
@@ -63,6 +66,7 @@ def apply_lora(
|
|||||||
rank: int = 16,
|
rank: int = 16,
|
||||||
alpha: float = None,
|
alpha: float = None,
|
||||||
target_suffixes: tuple = ("attn.qkv",),
|
target_suffixes: tuple = ("attn.qkv",),
|
||||||
|
dropout: float = 0.0,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
||||||
|
|
||||||
@@ -74,6 +78,8 @@ def apply_lora(
|
|||||||
("attn.qkv",) which targets all SelfAttention QKV
|
("attn.qkv",) which targets all SelfAttention QKV
|
||||||
projections in the MM-DiT generator.
|
projections in the MM-DiT generator.
|
||||||
Add "linear1" to also wrap post-attention output projections.
|
Add "linear1" to also wrap post-attention output projections.
|
||||||
|
dropout: Dropout probability on the LoRA path (not the base linear).
|
||||||
|
0.05–0.1 helps regularize on small datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of linear layers wrapped.
|
Number of linear layers wrapped.
|
||||||
@@ -92,7 +98,7 @@ def apply_lora(
|
|||||||
parent = model
|
parent = model
|
||||||
for part in parts[:-1]:
|
for part in parts[:-1]:
|
||||||
parent = getattr(parent, part)
|
parent = getattr(parent, part)
|
||||||
setattr(parent, parts[-1], LoRALinear(module, rank, alpha))
|
setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout))
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|||||||
+39
-8
@@ -167,10 +167,16 @@ def main():
|
|||||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"],
|
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
|
||||||
help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.")
|
help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
|
||||||
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
||||||
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
help="Spread of logit-normal distribution.")
|
||||||
|
parser.add_argument("--curriculum_switch", type=float, default=0.6,
|
||||||
|
help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
|
||||||
|
parser.add_argument("--lora_dropout", type=float, default=0.0,
|
||||||
|
help="Dropout on the LoRA path only. 0.05–0.1 helps on small datasets.")
|
||||||
|
parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
|
||||||
|
help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
@@ -234,8 +240,9 @@ def main():
|
|||||||
rank=args.rank,
|
rank=args.rank,
|
||||||
alpha=args.alpha,
|
alpha=args.alpha,
|
||||||
target_suffixes=tuple(args.target),
|
target_suffixes=tuple(args.target),
|
||||||
|
dropout=args.lora_dropout,
|
||||||
)
|
)
|
||||||
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})")
|
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
|
||||||
if n_lora == 0:
|
if n_lora == 0:
|
||||||
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
|
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -315,8 +322,16 @@ def main():
|
|||||||
print(f"[LoRA] {len(dataset)} clip(s) ready.")
|
print(f"[LoRA] {len(dataset)} clip(s) ready.")
|
||||||
|
|
||||||
# --- Optimizer + LR scheduler ---
|
# --- Optimizer + LR scheduler ---
|
||||||
lora_params = [p for p in net_generator.parameters() if p.requires_grad]
|
# LoRA+: separate param groups for A and B with different LRs.
|
||||||
optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2)
|
# ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
|
||||||
|
lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
|
||||||
|
lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
|
||||||
|
optimizer = torch.optim.AdamW([
|
||||||
|
{"params": lora_A_params, "lr": args.lr},
|
||||||
|
{"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
|
||||||
|
], weight_decay=1e-2)
|
||||||
|
if args.lora_plus_ratio != 1.0:
|
||||||
|
print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")
|
||||||
|
|
||||||
def lr_lambda(step):
|
def lr_lambda(step):
|
||||||
if step < args.warmup_steps:
|
if step < args.warmup_steps:
|
||||||
@@ -351,6 +366,9 @@ def main():
|
|||||||
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
|
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
|
||||||
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
||||||
|
|
||||||
|
curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
|
||||||
|
_curriculum_switched = False
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
for step in range(start_step + 1, args.steps + 1):
|
for step in range(start_step + 1, args.steps + 1):
|
||||||
batch = random.choices(dataset, k=args.batch_size)
|
batch = random.choices(dataset, k=args.batch_size)
|
||||||
@@ -363,11 +381,18 @@ def main():
|
|||||||
|
|
||||||
net_generator.normalize(x1)
|
net_generator.normalize(x1)
|
||||||
|
|
||||||
if args.timestep_mode == "logit_normal":
|
if args.timestep_mode == "logit_normal" or (
|
||||||
|
args.timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||||
|
):
|
||||||
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
|
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
|
||||||
t = torch.sigmoid(u)
|
t = torch.sigmoid(u)
|
||||||
else:
|
else:
|
||||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
|
||||||
|
print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
|
||||||
|
_curriculum_switched = True
|
||||||
|
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
@@ -378,7 +403,7 @@ def main():
|
|||||||
total_loss += loss.item() * args.grad_accum
|
total_loss += loss.item() * args.grad_accum
|
||||||
|
|
||||||
if step % args.grad_accum == 0:
|
if step % args.grad_accum == 0:
|
||||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -404,6 +429,9 @@ def main():
|
|||||||
"steps": args.steps,
|
"steps": args.steps,
|
||||||
"timestep_mode": args.timestep_mode,
|
"timestep_mode": args.timestep_mode,
|
||||||
"logit_normal_sigma": args.logit_normal_sigma,
|
"logit_normal_sigma": args.logit_normal_sigma,
|
||||||
|
"curriculum_switch": args.curriculum_switch,
|
||||||
|
"lora_dropout": args.lora_dropout,
|
||||||
|
"lora_plus_ratio": args.lora_plus_ratio,
|
||||||
},
|
},
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f"[LoRA] Saved {ckpt_path}")
|
print(f"[LoRA] Saved {ckpt_path}")
|
||||||
@@ -424,6 +452,9 @@ def main():
|
|||||||
"steps": args.steps,
|
"steps": args.steps,
|
||||||
"timestep_mode": args.timestep_mode,
|
"timestep_mode": args.timestep_mode,
|
||||||
"logit_normal_sigma": args.logit_normal_sigma,
|
"logit_normal_sigma": args.logit_normal_sigma,
|
||||||
|
"curriculum_switch": args.curriculum_switch,
|
||||||
|
"lora_dropout": args.lora_dropout,
|
||||||
|
"lora_plus_ratio": args.lora_plus_ratio,
|
||||||
}
|
}
|
||||||
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
||||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
|
|||||||
Reference in New Issue
Block a user