# Origin: nodes/text_only.py as added by commit
#   618e7de64b5b4652383027cb0f515c127d971e65
#   From: Ethanfel, Fri, 27 Mar 2026 18:09:11 +0100
#   Subject: feat: PrismAudioTextOnly node with correct T5-Gemma encoding
#   Co-Authored-By: Claude Sonnet 4.6
"""Text-only audio generation node for PrismAudio.

Encodes a text prompt with the T5-Gemma encoder, feeds it (together with
placeholder video/sync features) through the diffusion conditioner, samples
latents with the discrete Euler sampler and decodes them to a waveform.
"""

import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features


class PrismAudioTextOnly:
    """ComfyUI node that generates audio from a text prompt alone."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "steps": ("INT", {"default": 24, "min": 1, "max": 100}),
                "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        """Generate audio for ``text_prompt`` and return it as a ComfyUI AUDIO.

        Args:
            model: Mapping providing ``"dtype"``, ``"strategy"`` and
                ``"model"`` (the diffusion wrapper exposing ``model``,
                ``conditioner``, ``pretransform`` and
                ``get_conditioning_inputs``).
            text_prompt: Text description to condition on.
            duration: Requested length in seconds; converted to a latent
                length via ``SAMPLE_RATE`` and ``DOWNSAMPLING_RATIO``.
            steps: Number of sampler steps (also the progress-bar total).
            cfg_scale: Classifier-free guidance scale passed to the sampler.
            seed: Seed for the initial noise.

        Returns:
            A 1-tuple holding ``{"waveform": <CPU float tensor>,
            "sample_rate": SAMPLE_RATE}``.
        """
        device = get_device()
        dtype = model["dtype"]
        offload = model["strategy"] == "offload_to_cpu"
        diffusion = model["model"]

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
        # These will be substituted with learned empty embeddings after conditioning
        sample_meta = {
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        }
        metadata = (sample_meta,)

        # try/finally: when offloading, the diffusion model and conditioner
        # must go back to the offload device even if conditioning or sampling
        # raises (OOM, an exception from the progress callback, ...).
        try:
            if offload:
                diffusion.model.to(device)
                diffusion.conditioner.to(device)
                soft_empty_cache()

            # Everything here is inference only; disable autograd explicitly
            # rather than relying on the caller having done so.
            with torch.no_grad():
                with torch.amp.autocast(device_type=device.type, dtype=dtype):
                    conditioning = diffusion.conditioner(metadata, device)

                # Substitute empty features for video/sync
                # NOTE(review): indentation was lost in the patch this module
                # was recovered from; this call is assumed to sit outside the
                # autocast scope — confirm against the repository.
                _substitute_empty_features(diffusion, conditioning, device, dtype)

                cond_inputs = diffusion.get_conditioning_inputs(conditioning)

                # Generate noise from seed (MPS doesn't support torch.Generator)
                gen_device = "cpu" if device.type == "mps" else device
                generator = torch.Generator(device=gen_device).manual_seed(seed)
                noise = torch.randn(
                    [1, IO_CHANNELS, latent_length],
                    generator=generator,
                    device=gen_device,
                ).to(device=device, dtype=dtype)

                pbar = comfy.utils.ProgressBar(steps)

                from prismaudio_core.inference.sampling import sample_discrete_euler

                def on_step(info):
                    pbar.update(1)

                fakes = sample_discrete_euler(
                    diffusion.model,
                    noise,
                    steps,
                    callback=on_step,
                    **cond_inputs,
                    cfg_scale=cfg_scale,
                    batch_cfg=True,
                )
        finally:
            if offload:
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()

        # Same guarantee for the pretransform during decode.
        try:
            if offload:
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes.float())
        finally:
            if offload:
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).

    The tokenizer and encoder are loaded once on first use and cached in
    module globals. The encoder is moved to ``device``/``dtype`` for the
    forward pass and always moved back to CPU afterwards, including when
    tokenization or the forward pass raises.

    Args:
        text: Prompt to encode.
        device: Device to run the encoder on.
        dtype: Dtype the encoder is cast to for the forward pass.

    Returns:
        ``last_hidden_state`` with the leading dimension squeezed,
        i.e. ``[seq_len, dim]``.
    """
    global _t5_model, _t5_tokenizer

    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        encoder = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        encoder.eval()
        # Publish both singletons together, only after both loads succeeded.
        _t5_tokenizer, _t5_model = tokenizer, encoder

    try:
        _t5_model.to(device, dtype=dtype)

        tokens = _t5_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            outputs = _t5_model(**tokens)

        hidden = outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
    finally:
        # Move T5 off GPU after encoding to save VRAM — also on failure
        _t5_model.to("cpu")
        soft_empty_cache()

    return hidden