fix: clean up dead code paths and debug artifacts in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -304,8 +304,7 @@ class AudioAutoencoder(nn.Module):
|
||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||
|
||||
info = {}
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
|
||||
if self.pretransform is not None and not skip_pretransform:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
@@ -476,11 +475,8 @@ class AudioAutoencoder(nn.Module):
|
||||
else:
|
||||
# CHUNKED ENCODING
|
||||
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
samples_per_latent = self.downsampling_ratio
|
||||
total_size = audio.shape[2] # in samples
|
||||
print(f'audio shape: {audio.shape}')
|
||||
batch_size = audio.shape[0]
|
||||
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||
overlap *= samples_per_latent # converting metric in latents to samples
|
||||
@@ -501,12 +497,10 @@ class AudioAutoencoder(nn.Module):
|
||||
y_size = total_size // samples_per_latent
|
||||
# Create an empty latent, we will populate it with chunks as we encode them
|
||||
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||
print(f'y_final shape: {y_final.shape}')
|
||||
for i in range(num_chunks):
|
||||
x_chunk = chunks[i,:]
|
||||
# encode the chunk
|
||||
y_chunk = self.encode(x_chunk)
|
||||
print(f'y_chunk shape: {y_chunk.shape}')
|
||||
# figure out where to put the audio along the time domain
|
||||
if i == num_chunks-1:
|
||||
# final chunk always goes at the end
|
||||
|
||||
@@ -213,8 +213,6 @@ class Video_Global(Conditioner):
|
||||
self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
|
||||
|
||||
def forward(self, x, device: tp.Any = "cuda"):
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
if not isinstance(x[0], torch.Tensor):
|
||||
video_feats = []
|
||||
for path in x:
|
||||
@@ -242,8 +240,6 @@ class Video_Sync(Conditioner):
|
||||
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
||||
|
||||
def forward(self, x, device: tp.Any = "cuda"):
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
if not isinstance(x[0], torch.Tensor):
|
||||
video_feats = []
|
||||
for path in x:
|
||||
@@ -269,8 +265,6 @@ class Text_Linear(Conditioner):
|
||||
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
||||
|
||||
def forward(self, x, device: tp.Any = "cuda"):
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
if not isinstance(x[0], torch.Tensor):
|
||||
video_feats = []
|
||||
for path in x:
|
||||
@@ -295,8 +289,6 @@ class mm_unchang(Conditioner):
|
||||
super().__init__(dim, output_dim)
|
||||
|
||||
def forward(self, x, device: tp.Any = "cuda"):
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
if not isinstance(x[0], torch.Tensor):
|
||||
video_feats = []
|
||||
for path in x:
|
||||
@@ -349,8 +341,6 @@ class CLIPConditioner(Conditioner):
|
||||
|
||||
self.model.to(device)
|
||||
self.proj_out.to(device)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
|
||||
self.model.eval()
|
||||
if not isinstance(images[0], torch.Tensor):
|
||||
|
||||
@@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config():
|
||||
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
||||
return create_multi_conditioner_from_conditioning_config
|
||||
|
||||
from time import time
|
||||
|
||||
class Profiler:
|
||||
|
||||
def __init__(self):
|
||||
self.ticks = [[time(), None]]
|
||||
|
||||
def tick(self, msg):
|
||||
self.ticks.append([time(), msg])
|
||||
|
||||
def __repr__(self):
|
||||
rep = 80 * "=" + "\n"
|
||||
for i in range(1, len(self.ticks)):
|
||||
msg = self.ticks[i][1]
|
||||
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
|
||||
rep += msg + f": {ellapsed*1000:.2f}ms\n"
|
||||
rep += 80 * "=" + "\n\n\n"
|
||||
return rep
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module):
|
||||
|
||||
cross_attention_input.append(cross_attn_in)
|
||||
cross_attention_masks.append(cross_attn_mask)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
|
||||
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
||||
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||
|
||||
@@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
**kwargs):
|
||||
p = Profiler()
|
||||
|
||||
p.tick("start")
|
||||
|
||||
channels_list = None
|
||||
if input_concat_cond is not None:
|
||||
channels_list = [input_concat_cond]
|
||||
@@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||
negative_embedding_mask=negative_cross_attn_mask,
|
||||
**kwargs)
|
||||
|
||||
p.tick("UNetCFG1D forward")
|
||||
|
||||
#print(f"Profiler: {p}")
|
||||
return outputs
|
||||
|
||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||
@@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel):
|
||||
global_embed=global_cond,
|
||||
**kwargs)
|
||||
|
||||
class MMDiTWrapper(ConditionedDiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
||||
|
||||
self.model = MMAudio(*args, **kwargs)
|
||||
|
||||
# with torch.no_grad():
|
||||
# for param in self.model.parameters():
|
||||
# param *= 0.5
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
clip_f,
|
||||
sync_f,
|
||||
text_f,
|
||||
inpaint_masked_input=None,
|
||||
t5_features=None,
|
||||
metaclip_global_text_features=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = True,
|
||||
rescale_cfg: bool = False,
|
||||
scale_phi: float = 0.0,
|
||||
**kwargs):
|
||||
|
||||
# breakpoint()
|
||||
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
||||
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
||||
|
||||
return self.model(
|
||||
latent=x,
|
||||
t=t,
|
||||
clip_f=clip_f,
|
||||
sync_f=sync_f,
|
||||
text_f=text_f,
|
||||
inpaint_masked_input=inpaint_masked_input,
|
||||
t5_features=t5_features,
|
||||
metaclip_global_text_features=metaclip_global_text_features,
|
||||
cfg_scale=cfg_scale,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
scale_phi=scale_phi,
|
||||
**kwargs)
|
||||
|
||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
"""
|
||||
A diffusion model that takes in conditioning
|
||||
@@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
}
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||
# breakpoint()
|
||||
# print(kwargs)
|
||||
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
@@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'dit':
|
||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'mmdit':
|
||||
diffusion_model = MMDiTWrapper(**diffusion_model_config)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
||||
|
||||
io_channels = model_config.get('io_channels', None)
|
||||
assert io_channels is not None, "Must specify io_channels in model config"
|
||||
@@ -939,13 +862,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||
|
||||
elif model_type == "diffusion_prior":
|
||||
prior_type = model_config.get("prior_type", None)
|
||||
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
|
||||
|
||||
if prior_type == "mono_stereo":
|
||||
from .diffusion_prior import MonoToStereoDiffusionPrior
|
||||
wrapper_fn = MonoToStereoDiffusionPrior
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
return wrapper_fn(
|
||||
diffusion_model,
|
||||
|
||||
@@ -223,8 +223,6 @@ class DiffusionTransformer(nn.Module):
|
||||
|
||||
# Get the batch of timestep embeddings
|
||||
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||
if self.timestep_cond_type == "global":
|
||||
if global_embed is not None:
|
||||
|
||||
@@ -89,25 +89,6 @@ class AutoencoderPretransform(Pretransform):
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
self.model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
class WaveletPretransform(Pretransform):
|
||||
def __init__(self, channels, levels, wavelet):
|
||||
super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
|
||||
|
||||
from .wavelets import WaveletEncode1d, WaveletDecode1d
|
||||
|
||||
self.encoder = WaveletEncode1d(channels, levels, wavelet)
|
||||
self.decoder = WaveletDecode1d(channels, levels, wavelet)
|
||||
|
||||
self.downsampling_ratio = 2 ** levels
|
||||
self.io_channels = channels
|
||||
self.encoded_channels = channels * self.downsampling_ratio
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, z):
|
||||
return self.decoder(z)
|
||||
|
||||
class PQMFPretransform(Pretransform):
|
||||
def __init__(self, attenuation=100, num_bands=16):
|
||||
# TODO: Fix PQMF to take in in-channels
|
||||
|
||||
@@ -421,9 +421,10 @@ class Attention(nn.Module):
|
||||
flex_attention_score_mod = None
|
||||
|
||||
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
||||
out = flex_attention_compiled(q,k,v,
|
||||
block_mask = flex_attention_block_mask,
|
||||
score_mod = flex_attention_score_mod)
|
||||
raise NotImplementedError(
|
||||
"FlexAttention is not available in this build. "
|
||||
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
||||
)
|
||||
elif flash_attn_available:
|
||||
fa_dtype_in = q.dtype
|
||||
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
||||
|
||||
@@ -18,7 +18,6 @@ def load_ckpt_state_dict(ckpt_path, prefix=None):
|
||||
def remove_weight_norm_from_model(model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "weight"):
|
||||
print(f"Removing weight norm from {module}")
|
||||
remove_weight_norm(module)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user