fix: clean up dead code paths and debug artifacts in prismaudio_core/models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:49:57 +01:00
parent 84c81e0e55
commit 6e1186d5bd
7 changed files with 10 additions and 129 deletions
+5 -87
View File
@@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config():
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
return create_multi_conditioner_from_conditioning_config
from time import time
class Profiler:
def __init__(self):
self.ticks = [[time(), None]]
def tick(self, msg):
self.ticks.append([time(), msg])
def __repr__(self):
rep = 80 * "=" + "\n"
for i in range(1, len(self.ticks)):
msg = self.ticks[i][1]
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
rep += msg + f": {ellapsed*1000:.2f}ms\n"
rep += 80 * "=" + "\n\n\n"
return rep
class DiffusionModel(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module):
cross_attention_input.append(cross_attn_in)
cross_attention_masks.append(cross_attn_mask)
# import ipdb
# ipdb.set_trace()
cross_attention_input = torch.cat(cross_attention_input, dim=1)
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
@@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
prepend_cond=None,
prepend_cond_mask=None,
**kwargs):
p = Profiler()
p.tick("start")
channels_list = None
if input_concat_cond is not None:
channels_list = [input_concat_cond]
@@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
negative_embedding_mask=negative_cross_attn_mask,
**kwargs)
p.tick("UNetCFG1D forward")
#print(f"Profiler: {p}")
return outputs
class UNet1DCondWrapper(ConditionedDiffusionModel):
@@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel):
global_embed=global_cond,
**kwargs)
class MMDiTWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
self.model = MMAudio(*args, **kwargs)
# with torch.no_grad():
# for param in self.model.parameters():
# param *= 0.5
def forward(self,
x,
t,
clip_f,
sync_f,
text_f,
inpaint_masked_input=None,
t5_features=None,
metaclip_global_text_features=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = True,
rescale_cfg: bool = False,
scale_phi: float = 0.0,
**kwargs):
# breakpoint()
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
return self.model(
latent=x,
t=t,
clip_f=clip_f,
sync_f=sync_f,
text_f=text_f,
inpaint_masked_input=inpaint_masked_input,
t5_features=t5_features,
metaclip_global_text_features=metaclip_global_text_features,
cfg_scale=cfg_scale,
cfg_dropout_prob=cfg_dropout_prob,
scale_phi=scale_phi,
**kwargs)
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
"""
A diffusion model that takes in conditioning
@@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
# breakpoint()
# print(kwargs)
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
@@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
elif diffusion_model_type == 'dit':
diffusion_model = DiTWrapper(**diffusion_model_config)
elif diffusion_model_type == 'mmdit':
diffusion_model = MMDiTWrapper(**diffusion_model_config)
else:
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
io_channels = model_config.get('io_channels', None)
assert io_channels is not None, "Must specify io_channels in model config"
@@ -939,14 +862,9 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
wrapper_fn = ConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
elif model_type == "diffusion_prior":
prior_type = model_config.get("prior_type", None)
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
if prior_type == "mono_stereo":
from .diffusion_prior import MonoToStereoDiffusionPrior
wrapper_fn = MonoToStereoDiffusionPrior
return wrapper_fn(
diffusion_model,
conditioner,