chore: default timestep_mode back to uniform
logit_normal reaches lower loss but perceptual improvement over uniform is dataset-dependent. Keeping uniform as default to match original MMAudio training behavior; logit_normal remains available as an option. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -271,11 +271,11 @@ class SelvaLoraTrainer:
|
|||||||
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||||||
}),
|
}),
|
||||||
"seed": ("INT", {"default": 42}),
|
"seed": ("INT", {"default": 42}),
|
||||||
"timestep_mode": (["logit_normal", "uniform"], {
|
"timestep_mode": (["uniform", "logit_normal"], {
|
||||||
"default": "logit_normal",
|
"default": "uniform",
|
||||||
"tooltip": "How to sample training timesteps. "
|
"tooltip": "How to sample training timesteps. "
|
||||||
"logit_normal concentrates steps near t=0.5 (recommended — reduces white noise artifacts). "
|
"uniform samples all timesteps equally (default, matches original MMAudio training). "
|
||||||
"uniform samples all timesteps equally (original behavior).",
|
"logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.",
|
||||||
}),
|
}),
|
||||||
"logit_normal_sigma": ("FLOAT", {
|
"logit_normal_sigma": ("FLOAT", {
|
||||||
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
||||||
@@ -305,7 +305,7 @@ class SelvaLoraTrainer:
|
|||||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
timestep_mode="logit_normal", logit_normal_sigma=1.0):
|
timestep_mode="uniform", logit_normal_sigma=1.0):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -451,7 +451,7 @@ class SelvaLoraTrainer:
|
|||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode="logit_normal", logit_normal_sigma=1.0,
|
timestep_mode="uniform", logit_normal_sigma=1.0,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|||||||
+2
-2
@@ -167,8 +167,8 @@ def main():
|
|||||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--timestep_mode", default="logit_normal", choices=["logit_normal", "uniform"],
|
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"],
|
||||||
help="Timestep sampling distribution. logit_normal reduces white noise artifacts.")
|
help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.")
|
||||||
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
||||||
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user