diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index b1a8253..bce709f 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -271,11 +271,11 @@ class SelvaLoraTrainer: "tooltip": "Path to a step checkpoint (.pt) to resume training from.", }), "seed": ("INT", {"default": 42}), - "timestep_mode": (["logit_normal", "uniform"], { - "default": "logit_normal", + "timestep_mode": (["uniform", "logit_normal"], { + "default": "uniform", "tooltip": "How to sample training timesteps. " - "logit_normal concentrates steps near t=0.5 (recommended — reduces white noise artifacts). " - "uniform samples all timesteps equally (original behavior).", + "uniform samples all timesteps equally (default, matches original MMAudio training). " + "logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.", }), "logit_normal_sigma": ("FLOAT", { "default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1, @@ -305,7 +305,7 @@ class SelvaLoraTrainer: def train(self, model, data_dir, output_dir, steps, rank, lr, alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, grad_accum=1, save_every=500, resume_path="", seed=42, - timestep_mode="logit_normal", logit_normal_sigma=1.0): + timestep_mode="uniform", logit_normal_sigma=1.0): torch.manual_seed(seed) random.seed(seed) @@ -451,7 +451,7 @@ class SelvaLoraTrainer: data_dir, output_dir, steps, rank, lr, alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, - timestep_mode="logit_normal", logit_normal_sigma=1.0, + timestep_mode="uniform", logit_normal_sigma=1.0, ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) diff --git a/train_lora.py b/train_lora.py index a603470..8203d67 100644 --- a/train_lora.py +++ b/train_lora.py @@ -167,8 +167,8 @@ def main(): help="Path to a step checkpoint (.pt) to resume training from.") parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--timestep_mode", default="logit_normal", choices=["logit_normal", "uniform"], - help="Timestep sampling distribution. logit_normal reduces white noise artifacts.") + parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"], + help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.") parser.add_argument("--logit_normal_sigma", type=float, default=1.0, help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).") args = parser.parse_args()