fix: clean up dead code paths and debug artifacts in prismaudio_core/models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:49:57 +01:00
parent 84c81e0e55
commit 6e1186d5bd
7 changed files with 10 additions and 129 deletions
+1 -7
View File
@@ -304,8 +304,7 @@ class AudioAutoencoder(nn.Module):
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
info = {}
# import ipdb
# ipdb.set_trace()
if self.pretransform is not None and not skip_pretransform:
if self.pretransform.enable_grad:
if iterate_batch:
@@ -476,11 +475,8 @@ class AudioAutoencoder(nn.Module):
else:
# CHUNKED ENCODING
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
# import ipdb
# ipdb.set_trace()
samples_per_latent = self.downsampling_ratio
total_size = audio.shape[2] # in samples
print(f'audio shape: {audio.shape}')
batch_size = audio.shape[0]
chunk_size *= samples_per_latent # converting metric in latents to samples
overlap *= samples_per_latent # converting metric in latents to samples
@@ -501,12 +497,10 @@ class AudioAutoencoder(nn.Module):
y_size = total_size // samples_per_latent
# Create an empty latent, we will populate it with chunks as we encode them
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
print(f'y_final shape: {y_final.shape}')
for i in range(num_chunks):
x_chunk = chunks[i,:]
# encode the chunk
y_chunk = self.encode(x_chunk)
print(f'y_chunk shape: {y_chunk.shape}')
# figure out where to put the audio along the time domain
if i == num_chunks-1:
# final chunk always goes at the end
-10
View File
@@ -213,8 +213,6 @@ class Video_Global(Conditioner):
self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
def forward(self, x, device: tp.Any = "cuda"):
# import ipdb
# ipdb.set_trace()
if not isinstance(x[0], torch.Tensor):
video_feats = []
for path in x:
@@ -242,8 +240,6 @@ class Video_Sync(Conditioner):
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
def forward(self, x, device: tp.Any = "cuda"):
# import ipdb
# ipdb.set_trace()
if not isinstance(x[0], torch.Tensor):
video_feats = []
for path in x:
@@ -269,8 +265,6 @@ class Text_Linear(Conditioner):
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
def forward(self, x, device: tp.Any = "cuda"):
# import ipdb
# ipdb.set_trace()
if not isinstance(x[0], torch.Tensor):
video_feats = []
for path in x:
@@ -295,8 +289,6 @@ class mm_unchang(Conditioner):
super().__init__(dim, output_dim)
def forward(self, x, device: tp.Any = "cuda"):
# import ipdb
# ipdb.set_trace()
if not isinstance(x[0], torch.Tensor):
video_feats = []
for path in x:
@@ -349,8 +341,6 @@ class CLIPConditioner(Conditioner):
self.model.to(device)
self.proj_out.to(device)
# import ipdb
# ipdb.set_trace()
self.model.eval()
if not isinstance(images[0], torch.Tensor):
+5 -87
View File
@@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config():
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
return create_multi_conditioner_from_conditioning_config
from time import time
class Profiler:
def __init__(self):
self.ticks = [[time(), None]]
def tick(self, msg):
self.ticks.append([time(), msg])
def __repr__(self):
rep = 80 * "=" + "\n"
for i in range(1, len(self.ticks)):
msg = self.ticks[i][1]
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
rep += msg + f": {ellapsed*1000:.2f}ms\n"
rep += 80 * "=" + "\n\n\n"
return rep
class DiffusionModel(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module):
cross_attention_input.append(cross_attn_in)
cross_attention_masks.append(cross_attn_mask)
# import ipdb
# ipdb.set_trace()
cross_attention_input = torch.cat(cross_attention_input, dim=1)
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
@@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
prepend_cond=None,
prepend_cond_mask=None,
**kwargs):
p = Profiler()
p.tick("start")
channels_list = None
if input_concat_cond is not None:
channels_list = [input_concat_cond]
@@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
negative_embedding_mask=negative_cross_attn_mask,
**kwargs)
p.tick("UNetCFG1D forward")
#print(f"Profiler: {p}")
return outputs
class UNet1DCondWrapper(ConditionedDiffusionModel):
@@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel):
global_embed=global_cond,
**kwargs)
class MMDiTWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
self.model = MMAudio(*args, **kwargs)
# with torch.no_grad():
# for param in self.model.parameters():
# param *= 0.5
def forward(self,
x,
t,
clip_f,
sync_f,
text_f,
inpaint_masked_input=None,
t5_features=None,
metaclip_global_text_features=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = True,
rescale_cfg: bool = False,
scale_phi: float = 0.0,
**kwargs):
# breakpoint()
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
return self.model(
latent=x,
t=t,
clip_f=clip_f,
sync_f=sync_f,
text_f=text_f,
inpaint_masked_input=inpaint_masked_input,
t5_features=t5_features,
metaclip_global_text_features=metaclip_global_text_features,
cfg_scale=cfg_scale,
cfg_dropout_prob=cfg_dropout_prob,
scale_phi=scale_phi,
**kwargs)
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
"""
A diffusion model that takes in conditioning
@@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
# breakpoint()
# print(kwargs)
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
@@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
elif diffusion_model_type == 'dit':
diffusion_model = DiTWrapper(**diffusion_model_config)
elif diffusion_model_type == 'mmdit':
diffusion_model = MMDiTWrapper(**diffusion_model_config)
else:
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
io_channels = model_config.get('io_channels', None)
assert io_channels is not None, "Must specify io_channels in model config"
@@ -939,13 +862,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
wrapper_fn = ConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
elif model_type == "diffusion_prior":
prior_type = model_config.get("prior_type", None)
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
if prior_type == "mono_stereo":
from .diffusion_prior import MonoToStereoDiffusionPrior
wrapper_fn = MonoToStereoDiffusionPrior
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return wrapper_fn(
diffusion_model,
-2
View File
@@ -223,8 +223,6 @@ class DiffusionTransformer(nn.Module):
# Get the batch of timestep embeddings
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
# import ipdb
# ipdb.set_trace()
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if self.timestep_cond_type == "global":
if global_embed is not None:
-19
View File
@@ -89,25 +89,6 @@ class AutoencoderPretransform(Pretransform):
def load_state_dict(self, state_dict, strict=True):
self.model.load_state_dict(state_dict, strict=strict)
class WaveletPretransform(Pretransform):
def __init__(self, channels, levels, wavelet):
super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
from .wavelets import WaveletEncode1d, WaveletDecode1d
self.encoder = WaveletEncode1d(channels, levels, wavelet)
self.decoder = WaveletDecode1d(channels, levels, wavelet)
self.downsampling_ratio = 2 ** levels
self.io_channels = channels
self.encoded_channels = channels * self.downsampling_ratio
def encode(self, x):
return self.encoder(x)
def decode(self, z):
return self.decoder(z)
class PQMFPretransform(Pretransform):
def __init__(self, attenuation=100, num_bands=16):
# TODO: Fix PQMF to take in in-channels
+4 -3
View File
@@ -421,9 +421,10 @@ class Attention(nn.Module):
flex_attention_score_mod = None
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
out = flex_attention_compiled(q,k,v,
block_mask = flex_attention_block_mask,
score_mod = flex_attention_score_mod)
raise NotImplementedError(
"FlexAttention is not available in this build. "
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
)
elif flash_attn_available:
fa_dtype_in = q.dtype
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
-1
View File
@@ -18,7 +18,6 @@ def load_ckpt_state_dict(ckpt_path, prefix=None):
def remove_weight_norm_from_model(model):
for module in model.modules():
if hasattr(module, "weight"):
print(f"Removing weight norm from {module}")
remove_weight_norm(module)
return model