fix: clean up dead code paths and debug artifacts in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config():
|
||||
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
||||
return create_multi_conditioner_from_conditioning_config
|
||||
|
||||
from time import time
|
||||
|
||||
class Profiler:
    """Minimal wall-clock profiler that records labelled checkpoints.

    ``ticks`` is a list of ``[timestamp, label]`` pairs. The first entry is
    created on construction with a ``None`` label and only serves as the
    reference point for the first real checkpoint.
    """

    def __init__(self):
        # Seed entry: the baseline timestamp that the first tick is measured against.
        self.ticks = [[time(), None]]

    def tick(self, msg):
        """Record a checkpoint labelled ``msg`` at the current time."""
        self.ticks.append([time(), msg])

    def __repr__(self):
        """Return a report with one ``"<label>: <elapsed>ms"`` line per checkpoint.

        The elapsed value is the time since the previous entry in ``ticks``,
        in milliseconds with two decimals. The report is framed by two rules
        of 80 ``=`` characters and ends with three newlines.
        """
        rule = 80 * "="
        # Pair every entry with its successor; the seed entry is only ever a "previous".
        lines = [
            # Format the label via the f-string so non-string labels do not raise.
            f"{label}: {(stamp - previous_stamp) * 1000:.2f}ms\n"
            for (previous_stamp, _), (stamp, label) in zip(self.ticks, self.ticks[1:])
        ]
        return rule + "\n" + "".join(lines) + rule + "\n\n\n"
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module):
|
||||
|
||||
cross_attention_input.append(cross_attn_in)
|
||||
cross_attention_masks.append(cross_attn_mask)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
|
||||
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
||||
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||
|
||||
@@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
**kwargs):
|
||||
p = Profiler()
|
||||
|
||||
p.tick("start")
|
||||
|
||||
channels_list = None
|
||||
if input_concat_cond is not None:
|
||||
channels_list = [input_concat_cond]
|
||||
@@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||
negative_embedding_mask=negative_cross_attn_mask,
|
||||
**kwargs)
|
||||
|
||||
p.tick("UNetCFG1D forward")
|
||||
|
||||
#print(f"Profiler: {p}")
|
||||
return outputs
|
||||
|
||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||
@@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel):
|
||||
global_embed=global_cond,
|
||||
**kwargs)
|
||||
|
||||
class MMDiTWrapper(ConditionedDiffusionModel):
    """Wrap an ``MMAudio`` network behind the ``ConditionedDiffusionModel`` interface.

    The wrapper declares support for cross-attention conditioning only
    (no global conditioning, no input concatenation) and delegates the
    forward pass to the wrapped ``MMAudio`` instance stored in ``self.model``.
    """

    def __init__(self, *args, **kwargs):
        """Build the wrapped model.

        All positional and keyword arguments are passed unchanged to ``MMAudio``.
        """
        super().__init__(
            supports_cross_attention=True,
            supports_global_cond=False,
            supports_input_concat=False,
        )

        self.model = MMAudio(*args, **kwargs)

    def forward(self,
                x,
                t,
                clip_f,
                sync_f,
                text_f,
                inpaint_masked_input=None,
                t5_features=None,
                metaclip_global_text_features=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """Run the wrapped model and return its output unchanged.

        Args:
            x: Passed to the wrapped model as ``latent``.
            t: Passed to the wrapped model as ``t``.
            clip_f, sync_f, text_f: Conditioning features, forwarded by name.
            inpaint_masked_input, t5_features, metaclip_global_text_features:
                Optional conditioning inputs, forwarded by name.
            cfg_scale, cfg_dropout_prob, scale_phi: Forwarded by name.
            batch_cfg: Must be truthy; it is only validated here and is not
                forwarded to the wrapped model.
            rescale_cfg: Accepted for signature compatibility; it is not
                forwarded to the wrapped model.
            **kwargs: Extra keyword arguments forwarded to the wrapped model.

        Raises:
            AssertionError: If ``batch_cfg`` is falsy.
        """
        assert batch_cfg, "batch_cfg must be True for MMDiTWrapper"

        return self.model(
            latent=x,
            t=t,
            clip_f=clip_f,
            sync_f=sync_f,
            text_f=text_f,
            inpaint_masked_input=inpaint_masked_input,
            t5_features=t5_features,
            metaclip_global_text_features=metaclip_global_text_features,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            **kwargs)
|
||||
|
||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
"""
|
||||
A diffusion model that takes in conditioning
|
||||
@@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
}
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
    """Run the wrapped diffusion model on ``x`` at timestep ``t``.

    ``cond`` is converted to keyword arguments by
    ``self.get_conditioning_inputs`` and passed to ``self.model`` together
    with any extra ``**kwargs``. The model's result is returned unchanged.
    """
    conditioning_inputs = self.get_conditioning_inputs(cond)
    return self.model(x=x, t=t, **conditioning_inputs, **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
@@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'dit':
|
||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'mmdit':
|
||||
diffusion_model = MMDiTWrapper(**diffusion_model_config)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
||||
|
||||
io_channels = model_config.get('io_channels', None)
|
||||
assert io_channels is not None, "Must specify io_channels in model config"
|
||||
@@ -939,14 +862,9 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||
|
||||
elif model_type == "diffusion_prior":
|
||||
prior_type = model_config.get("prior_type", None)
|
||||
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
if prior_type == "mono_stereo":
|
||||
from .diffusion_prior import MonoToStereoDiffusionPrior
|
||||
wrapper_fn = MonoToStereoDiffusionPrior
|
||||
|
||||
return wrapper_fn(
|
||||
diffusion_model,
|
||||
conditioner,
|
||||
|
||||
Reference in New Issue
Block a user