feat: add inject_mode (suffix/prefix) to TI pipeline
Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps), token_norm grew linearly without plateau — generator likely ignores last-K CLIP positions (EOS/padding zone) where suffix injects. Fix: add inject_mode parameter throughout the pipeline: - "suffix": replace last K positions (original behavior, model may ignore) - "prefix": replace positions 1:1+K right after BOS — highest attention weight in CLIP, much stronger gradient signal expected Changes: - selva_textual_inversion_trainer.py: _inject_tokens() helper centralises the torch.cat construction for both modes; used in training loop and eval; inject_mode stored in checkpoint files - selva_textual_inversion_loader.py: reads inject_mode from checkpoint, includes in TEXTUAL_INVERSION bundle - selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field - selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and _train_inner call - ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm); n4_baseline marked completed; suffix experiments retained for comparison Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+29
-20
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "ti_sweep_1",
|
"name": "ti_sweep_1",
|
||||||
"description": "First TI sweep: token count, learning rate, and warm init. All generator weights frozen throughout. Baseline = n_tokens=4, lr=1e-3, random init. Primary goal: find a working (n_tokens, lr) pair before optimising further.",
|
"description": "First TI sweep: inject position, token count, learning rate, and warm init. n4_baseline already completed (suffix, loss barely moved — model likely ignores last-K positions). Priority: prefix injection group.",
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
|
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
|
||||||
"base": {
|
"base": {
|
||||||
@@ -11,52 +11,61 @@
|
|||||||
"seed": 42,
|
"seed": 42,
|
||||||
"init_text": "",
|
"init_text": "",
|
||||||
"lr": 1e-3,
|
"lr": 1e-3,
|
||||||
"n_tokens": 4
|
"n_tokens": 4,
|
||||||
|
"inject_mode": "suffix"
|
||||||
},
|
},
|
||||||
"experiments": [
|
"experiments": [
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "n4_baseline",
|
"id": "n4_baseline",
|
||||||
"group": "token_count",
|
"group": "suffix_token_count",
|
||||||
"description": "4 tokens, lr=1e-3, random init. Primary reference point — all other experiments are measured against this."
|
"description": "4 tokens, suffix, lr=1e-3, random init. COMPLETED — loss 1.025→0.965, nearly flat. Token norm grew linearly to 3.2 with no plateau. Model appears to ignore last-K positions."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "n8",
|
"id": "n8",
|
||||||
"group": "token_count",
|
"group": "suffix_token_count",
|
||||||
"description": "8 tokens, lr=1e-3, random init. Double the capacity — does it capture more style or just overfit faster?",
|
"description": "8 tokens, suffix, lr=1e-3. More capacity — does it do better than n4_baseline?",
|
||||||
"n_tokens": 8
|
"n_tokens": 8
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "n16",
|
"id": "n4_prefix",
|
||||||
"group": "token_count",
|
"group": "prefix_inject",
|
||||||
"description": "16 tokens, lr=1e-3, random init. Maximum expressiveness — worth the extra convergence difficulty?",
|
"description": "4 tokens at positions 1:5 (after BOS). Prefix positions carry the highest attention weight in CLIP — should produce much stronger loss signal than suffix.",
|
||||||
"n_tokens": 16
|
"inject_mode": "prefix"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "n8_prefix",
|
||||||
|
"group": "prefix_inject",
|
||||||
|
"description": "8 tokens at prefix positions. More capacity + high-attention positions.",
|
||||||
|
"n_tokens": 8,
|
||||||
|
"inject_mode": "prefix"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "n4_prefix_warm",
|
||||||
|
"group": "prefix_inject",
|
||||||
|
"description": "4 tokens, prefix, warm-started from 'mechanical impact sound design'. Best of both: semantically meaningful start + strong gradient signal.",
|
||||||
|
"inject_mode": "prefix",
|
||||||
|
"init_text": "mechanical impact sound design"
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "lr_5e4",
|
"id": "lr_5e4",
|
||||||
"group": "learning_rate",
|
"group": "learning_rate",
|
||||||
"description": "n_tokens=4, lr=5e-4. Half the default LR — smoother convergence, possibly better generalisation.",
|
"description": "4 tokens, suffix, lr=5e-4. Slower convergence — mainly a baseline comparison for the prefix group.",
|
||||||
"lr": 5e-4
|
"lr": 5e-4
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "lr_2e3",
|
"id": "lr_2e3",
|
||||||
"group": "learning_rate",
|
"group": "learning_rate",
|
||||||
"description": "n_tokens=4, lr=2e-3. Double the default LR — faster early convergence, risk of oscillation.",
|
"description": "4 tokens, suffix, lr=2e-3. Faster early movement — does token norm plateau earlier?",
|
||||||
"lr": 2e-3
|
"lr": 2e-3
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "n4_warm",
|
"id": "n4_warm",
|
||||||
"group": "warm_init",
|
"group": "warm_init",
|
||||||
"description": "4 tokens warm-started from 'mechanical impact sound design'. CLIP embedding initialises tokens in a semantically relevant region of the space — may converge faster and to a better style representation.",
|
"description": "4 tokens, suffix, warm-started from 'mechanical impact sound design'.",
|
||||||
"init_text": "mechanical impact sound design"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "n8_warm",
|
|
||||||
"group": "warm_init",
|
|
||||||
"description": "8 tokens warm-started from 'mechanical impact sound design'. Combines the warm-init advantage with more expressive capacity.",
|
|
||||||
"n_tokens": 8,
|
|
||||||
"init_text": "mechanical impact sound design"
|
"init_text": "mechanical impact sound design"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import comfy.utils
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
|
from .selva_textual_inversion_trainer import _inject_tokens
|
||||||
|
|
||||||
|
|
||||||
class SelvaSampler:
|
class SelvaSampler:
|
||||||
@@ -118,16 +119,15 @@ class SelvaSampler:
|
|||||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||||
if negative_prompt.strip() else None
|
if negative_prompt.strip() else None
|
||||||
|
|
||||||
# Inject textual inversion tokens into last K positions of CLIP embedding
|
# Inject textual inversion tokens into CLIP conditioning
|
||||||
if textual_inversion is not None:
|
if textual_inversion is not None:
|
||||||
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
|
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
|
||||||
K = emb.shape[0]
|
K = emb.shape[0]
|
||||||
text_clip = text_clip.clone()
|
inject_mode = textual_inversion.get("inject_mode", "suffix")
|
||||||
text_clip[:, -K:, :] = emb.unsqueeze(0)
|
text_clip = _inject_tokens(text_clip, emb, K, inject_mode)
|
||||||
if neg_text_clip is not None:
|
if neg_text_clip is not None:
|
||||||
neg_text_clip = neg_text_clip.clone()
|
neg_text_clip = _inject_tokens(neg_text_clip, emb, K, inject_mode)
|
||||||
neg_text_clip[:, -K:, :] = emb.unsqueeze(0)
|
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode}",
|
||||||
print(f"[SelVA] Textual inversion: injected {K} tokens into CLIP conditioning",
|
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
|||||||
@@ -57,10 +57,14 @@ class SelvaTextualInversionLoader:
|
|||||||
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
|
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
|
||||||
f"lr={data.get('lr', '?')}", flush=True)
|
f"lr={data.get('lr', '?')}", flush=True)
|
||||||
|
|
||||||
|
inject_mode = data.get("inject_mode", "suffix")
|
||||||
|
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
|
||||||
|
|
||||||
bundle = {
|
bundle = {
|
||||||
"embeddings": embeddings, # [K, 1024] float32 on CPU
|
"embeddings": embeddings, # [K, 1024] float32 on CPU
|
||||||
"n_tokens": n_tokens,
|
"n_tokens": n_tokens,
|
||||||
"path": str(p),
|
"inject_mode": inject_mode,
|
||||||
"init_text": data.get("init_text", ""),
|
"path": str(p),
|
||||||
|
"init_text": data.get("init_text", ""),
|
||||||
}
|
}
|
||||||
return (bundle,)
|
return (bundle,)
|
||||||
|
|||||||
@@ -42,7 +42,29 @@ from .selva_lora_trainer import (
|
|||||||
# Eval helper with token injection
|
# Eval helper with token injection
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _eval_sample_ti(generator, learned_tokens, n_tokens,
|
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
|
||||||
|
n_tokens: int, inject_mode: str) -> torch.Tensor:
|
||||||
|
"""Build a text_clip tensor with learned tokens injected.
|
||||||
|
|
||||||
|
inject_mode:
|
||||||
|
"suffix" — replace last n_tokens positions (EOS/padding zone)
|
||||||
|
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
|
||||||
|
|
||||||
|
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
|
||||||
|
Works for both training (tokens is a Parameter) and eval (tokens is detached).
|
||||||
|
"""
|
||||||
|
if inject_mode == "prefix":
|
||||||
|
bos = text_clip[:, :1, :].detach() # [B, 1, D]
|
||||||
|
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
||||||
|
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D]
|
||||||
|
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
|
||||||
|
else: # suffix (default)
|
||||||
|
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
|
||||||
|
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
||||||
|
return torch.cat([front, toks], dim=1) # [B, 77, D]
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
feature_utils_orig, dataset, seq_cfg,
|
||||||
device, dtype, num_steps=25, seed=42, clip_idx=0):
|
device, dtype, num_steps=25, seed=42, clip_idx=0):
|
||||||
"""Inference pass with learned tokens injected into text conditioning."""
|
"""Inference pass with learned tokens injected into text conditioning."""
|
||||||
@@ -53,7 +75,8 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens,
|
|||||||
sync_f = sync_f_cpu.to(device, dtype)
|
sync_f = sync_f_cpu.to(device, dtype)
|
||||||
text_clip = text_clip_cpu.to(device, dtype).clone()
|
text_clip = text_clip_cpu.to(device, dtype).clone()
|
||||||
|
|
||||||
text_clip[:, -n_tokens:, :] = learned_tokens.detach().unsqueeze(0).to(device, dtype)
|
emb = learned_tokens.detach().to(device, dtype)
|
||||||
|
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
|
||||||
|
|
||||||
rng = torch.Generator(device=device).manual_seed(seed)
|
rng = torch.Generator(device=device).manual_seed(seed)
|
||||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||||
@@ -62,7 +85,7 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens,
|
|||||||
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||||
|
|
||||||
def velocity_fn(t, x):
|
def velocity_fn(t, x):
|
||||||
return generator.forward(x, clip_f, sync_f, text_clip,
|
return generator.forward(x, clip_f, sync_f, text_input,
|
||||||
t.reshape(1).to(device, dtype))
|
t.reshape(1).to(device, dtype))
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -165,6 +188,14 @@ class SelvaTextualInversionTrainer:
|
|||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
|
"inject_mode": (["suffix", "prefix"], {
|
||||||
|
"default": "suffix",
|
||||||
|
"tooltip": (
|
||||||
|
"Where to inject the learned tokens in the 77-token CLIP sequence. "
|
||||||
|
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
|
||||||
|
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
|
||||||
|
),
|
||||||
|
}),
|
||||||
"init_text": ("STRING", {
|
"init_text": ("STRING", {
|
||||||
"default": "",
|
"default": "",
|
||||||
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
|
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
|
||||||
@@ -178,7 +209,7 @@ class SelvaTextualInversionTrainer:
|
|||||||
|
|
||||||
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
|
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
|
||||||
batch_size, seed, save_every,
|
batch_size, seed, save_every,
|
||||||
init_text="", warmup_steps=100):
|
inject_mode="suffix", init_text="", warmup_steps=100):
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
dtype = model["dtype"]
|
dtype = model["dtype"]
|
||||||
@@ -212,7 +243,7 @@ class SelvaTextualInversionTrainer:
|
|||||||
device, dtype, mode,
|
device, dtype, mode,
|
||||||
data_dir, out_path,
|
data_dir, out_path,
|
||||||
n_tokens, steps, lr, batch_size,
|
n_tokens, steps, lr, batch_size,
|
||||||
warmup_steps, seed, save_every, init_text,
|
warmup_steps, seed, save_every, init_text, inject_mode,
|
||||||
)
|
)
|
||||||
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
|
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
|
||||||
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
|
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
|
||||||
@@ -223,7 +254,7 @@ class SelvaTextualInversionTrainer:
|
|||||||
device, dtype, mode,
|
device, dtype, mode,
|
||||||
data_dir, out_path,
|
data_dir, out_path,
|
||||||
n_tokens, steps, lr, batch_size,
|
n_tokens, steps, lr, batch_size,
|
||||||
warmup_steps, seed, save_every, init_text,
|
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
|
||||||
):
|
):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
@@ -289,12 +320,8 @@ class SelvaTextualInversionTrainer:
|
|||||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
|
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
|
||||||
|
|
||||||
# Inject learned tokens into last n_tokens positions.
|
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
|
||||||
# Must use torch.cat (not in-place assignment) so the computation graph
|
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
|
||||||
# links text_input → learned_tokens and gradients flow correctly.
|
|
||||||
text_front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D], no grad
|
|
||||||
tokens_expanded = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) # [B, K, D]
|
|
||||||
text_input = torch.cat([text_front, tokens_expanded], dim=1) # [B, 77, D] with grad
|
|
||||||
|
|
||||||
x1 = generator.normalize(x1)
|
x1 = generator.normalize(x1)
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
@@ -326,12 +353,13 @@ class SelvaTextualInversionTrainer:
|
|||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
# Save checkpoint
|
# Save checkpoint
|
||||||
ckpt = {
|
ckpt = {
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
"embeddings": learned_tokens.detach().cpu(),
|
||||||
"n_tokens": n_tokens,
|
"n_tokens": n_tokens,
|
||||||
"step": step,
|
"inject_mode": inject_mode,
|
||||||
"init_text": init_text,
|
"step": step,
|
||||||
"lr": lr,
|
"init_text": init_text,
|
||||||
"steps": steps,
|
"lr": lr,
|
||||||
|
"steps": steps,
|
||||||
"loss_history": loss_history,
|
"loss_history": loss_history,
|
||||||
}
|
}
|
||||||
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
|
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
|
||||||
@@ -339,7 +367,7 @@ class SelvaTextualInversionTrainer:
|
|||||||
|
|
||||||
# Eval sample
|
# Eval sample
|
||||||
wav, sr = _eval_sample_ti(
|
wav, sr = _eval_sample_ti(
|
||||||
generator, learned_tokens, n_tokens,
|
generator, learned_tokens, n_tokens, inject_mode,
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
feature_utils_orig, dataset, seq_cfg,
|
||||||
device, dtype, seed=seed,
|
device, dtype, seed=seed,
|
||||||
)
|
)
|
||||||
@@ -365,12 +393,13 @@ class SelvaTextualInversionTrainer:
|
|||||||
|
|
||||||
# --- Final save ---
|
# --- Final save ---
|
||||||
final = {
|
final = {
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
"embeddings": learned_tokens.detach().cpu(),
|
||||||
"n_tokens": n_tokens,
|
"n_tokens": n_tokens,
|
||||||
"step": steps,
|
"inject_mode": inject_mode,
|
||||||
"init_text": init_text,
|
"step": steps,
|
||||||
"lr": lr,
|
"init_text": init_text,
|
||||||
"steps": steps,
|
"lr": lr,
|
||||||
|
"steps": steps,
|
||||||
"loss_history": loss_history,
|
"loss_history": loss_history,
|
||||||
}
|
}
|
||||||
torch.save(final, str(out_path))
|
torch.save(final, str(out_path))
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ _PARAM_DEFAULTS = {
|
|||||||
"seed": 42,
|
"seed": 42,
|
||||||
"save_every": 1000,
|
"save_every": 1000,
|
||||||
"init_text": "",
|
"init_text": "",
|
||||||
|
"inject_mode": "suffix",
|
||||||
}
|
}
|
||||||
|
|
||||||
_PALETTE = [
|
_PALETTE = [
|
||||||
@@ -356,6 +357,7 @@ class SelvaTiScheduler:
|
|||||||
seed = int(cfg["seed"])
|
seed = int(cfg["seed"])
|
||||||
save_every = int(cfg["save_every"])
|
save_every = int(cfg["save_every"])
|
||||||
init_text = str(cfg["init_text"])
|
init_text = str(cfg["init_text"])
|
||||||
|
inject_mode = str(cfg["inject_mode"])
|
||||||
|
|
||||||
output_dir = output_root / exp_id
|
output_dir = output_root / exp_id
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -365,7 +367,8 @@ class SelvaTiScheduler:
|
|||||||
if exp_desc:
|
if exp_desc:
|
||||||
print(f"[TI Scheduler] {exp_desc}", flush=True)
|
print(f"[TI Scheduler] {exp_desc}", flush=True)
|
||||||
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
|
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
|
||||||
f"batch_size={batch_size} warmup={warmup} seed={seed}", flush=True)
|
f"batch_size={batch_size} warmup={warmup} seed={seed} "
|
||||||
|
f"inject_mode={inject_mode}", flush=True)
|
||||||
if init_text:
|
if init_text:
|
||||||
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
|
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
|
||||||
|
|
||||||
@@ -381,6 +384,7 @@ class SelvaTiScheduler:
|
|||||||
"seed": seed,
|
"seed": seed,
|
||||||
"save_every": save_every,
|
"save_every": save_every,
|
||||||
"init_text": init_text,
|
"init_text": init_text,
|
||||||
|
"inject_mode": inject_mode,
|
||||||
},
|
},
|
||||||
"results": {"status": "running"},
|
"results": {"status": "running"},
|
||||||
"embeddings_path": None,
|
"embeddings_path": None,
|
||||||
@@ -397,7 +401,7 @@ class SelvaTiScheduler:
|
|||||||
device, dtype, mode,
|
device, dtype, mode,
|
||||||
data_dir, out_path,
|
data_dir, out_path,
|
||||||
n_tokens, steps, lr, batch_size,
|
n_tokens, steps, lr, batch_size,
|
||||||
warmup, seed, save_every, init_text,
|
warmup, seed, save_every, init_text, inject_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
duration = time.monotonic() - t_start
|
||||||
|
|||||||
Reference in New Issue
Block a user