c0b7ccbcee
Zero features through bias-free Cond_MLP produce near-zero activations, not the learned null signal the model was trained with. Use empty_clip_feat (the learned null video embedding) just like empty_sync_feat for sync. Also improve text_prompt tooltip to encourage detailed CoT descriptions. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
161 lines
6.1 KiB
Python
161 lines
6.1 KiB
Python
import torch
|
|
import comfy.model_management as mm
|
|
import comfy.utils
|
|
|
|
from .utils import (
|
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
|
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
|
|
)
|
|
from .sampler import _substitute_empty_features
|
|
|
|
|
|
class PrismAudioTextOnly:
    """ComfyUI node that generates audio from a text prompt alone.

    No video input is taken: the video and sync conditioning slots are
    filled with placeholder tensors, which ``_substitute_empty_features``
    is then expected to replace after the conditioner has run.
    """

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (all inputs are required)."""
        prompt_options = {
            "default": "",
            "multiline": True,
            "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality.",
        }
        required = {
            "model": ("PRISMAUDIO_MODEL",),
            "text_prompt": ("STRING", prompt_options),
            "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
            "steps": ("INT", {"default": 24, "min": 1, "max": 100}),
            "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
        }
        return {"required": required}

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        """Sample latents conditioned on ``text_prompt`` and decode them to audio.

        Args:
            model: Mapping with the keys ``"model"`` (the diffusion wrapper),
                ``"dtype"`` and ``"strategy"``.
            text_prompt: Text that is encoded with T5-Gemma.
            duration: Requested length; converted to a latent length using
                ``SAMPLE_RATE`` and ``DOWNSAMPLING_RATIO`` (presumably
                seconds - confirm against the loader).
            steps: Number of sampler steps; also sizes the progress bar.
            cfg_scale: Guidance scale forwarded to the sampler.
            seed: Seed for the initial-noise generator.

        Returns:
            A one-element tuple holding a dict with ``"waveform"`` (the
            peak-normalized audio on CPU) and ``"sample_rate"``.
        """
        device = get_device()
        diffusion = model["model"]
        strategy = model["strategy"]
        dtype = model["dtype"]
        offloading = strategy == "offload_to_cpu"

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Metadata is a tuple holding one dict per sample.
        # Video/sync use zero tensors rather than None (Cond_MLP crashes on
        # None via pad_sequence), and Sync_MLP requires a length divisible by
        # 8 (segments of 8 frames), hence the minimum [8, 768] shape. These
        # placeholders are substituted with learned empty embeddings after
        # conditioning.
        metadata = (
            {
                "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
                "text_features": text_features.to(device, dtype=dtype),
                "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
                "video_exist": torch.tensor(False),
            },
        )

        if offloading:
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        # NOTE(review): block extents in this method were reconstructed from a
        # listing that had lost its indentation - confirm the scope of this
        # `with` block and of the offload branches against the original file.
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute empty features for video/sync
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Seeded initial noise (MPS doesn't support torch.Generator, so
            # the noise is drawn on CPU there and moved afterwards).
            if device.type == "mps":
                gen_device = "cpu"
            else:
                gen_device = device
            generator = torch.Generator(device=gen_device)
            generator.manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            )
            noise = noise.to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                # Advance the progress bar once per sampler callback.
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            latents = fakes.float()
            print(f"[PrismAudio] latent stats: shape={tuple(latents.shape)} mean={latents.mean():.4f} std={latents.std():.4f} min={latents.min():.4f} max={latents.max():.4f}", flush=True)

            if offloading:
                # Swap the sampler modules out and the decoder in.
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(latents)

        if offloading:
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp; the floor on the peak avoids dividing by
        # zero for an all-silent signal.
        audio = audio.float()
        pre_norm_std = audio.std().item()
        abs_peak = audio.abs().max()
        pre_norm_peak = abs_peak.item()
        audio = (audio / abs_peak.clamp(min=1e-8)).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
        print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
|
|
|
|
|
# T5-Gemma encoder singleton: loaded on first use by _encode_text_t5 and
# reused by every later call.
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).

    The encoder is moved to ``device`` only for the duration of the call and
    is always moved back to CPU afterwards, including when encoding fails,
    so a failed call does not leave the encoder occupying accelerator memory.

    Args:
        text: Prompt to encode.
        device: Device on which the encoder is run.
        dtype: Dtype the encoder is cast to for the forward pass.

    Returns:
        The encoder's last hidden state with the leading batch dimension
        squeezed away, i.e. ``[seq_len, dim]`` for a single prompt.

    Raises:
        Whatever the tokenizer/model loading or the forward pass raises; the
        exception propagates after the encoder has been moved back to CPU.
    """
    global _t5_model, _t5_tokenizer

    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        encoder = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        encoder.eval()
        # Publish the singletons only once both pieces loaded successfully,
        # so a failed load never leaves half-initialised module state.
        _t5_tokenizer = tokenizer
        _t5_model = encoder

    try:
        _t5_model.to(device, dtype=dtype)

        tokens = _t5_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            outputs = _t5_model(**tokens)
    finally:
        # Move T5 off GPU after encoding to save VRAM - also on failure
        # (e.g. out-of-memory), otherwise the weights would stay resident.
        _t5_model.to("cpu")
        soft_empty_cache()

    return outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
|