fix: clean up dead code paths and debug artifacts in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -304,8 +304,7 @@ class AudioAutoencoder(nn.Module):
|
|||||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||||
|
|
||||||
info = {}
|
info = {}
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
if self.pretransform is not None and not skip_pretransform:
|
if self.pretransform is not None and not skip_pretransform:
|
||||||
if self.pretransform.enable_grad:
|
if self.pretransform.enable_grad:
|
||||||
if iterate_batch:
|
if iterate_batch:
|
||||||
@@ -476,11 +475,8 @@ class AudioAutoencoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# CHUNKED ENCODING
|
# CHUNKED ENCODING
|
||||||
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
samples_per_latent = self.downsampling_ratio
|
samples_per_latent = self.downsampling_ratio
|
||||||
total_size = audio.shape[2] # in samples
|
total_size = audio.shape[2] # in samples
|
||||||
print(f'audio shape: {audio.shape}')
|
|
||||||
batch_size = audio.shape[0]
|
batch_size = audio.shape[0]
|
||||||
chunk_size *= samples_per_latent # converting metric in latents to samples
|
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||||
overlap *= samples_per_latent # converting metric in latents to samples
|
overlap *= samples_per_latent # converting metric in latents to samples
|
||||||
@@ -501,12 +497,10 @@ class AudioAutoencoder(nn.Module):
|
|||||||
y_size = total_size // samples_per_latent
|
y_size = total_size // samples_per_latent
|
||||||
# Create an empty latent, we will populate it with chunks as we encode them
|
# Create an empty latent, we will populate it with chunks as we encode them
|
||||||
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||||
print(f'y_final shape: {y_final.shape}')
|
|
||||||
for i in range(num_chunks):
|
for i in range(num_chunks):
|
||||||
x_chunk = chunks[i,:]
|
x_chunk = chunks[i,:]
|
||||||
# encode the chunk
|
# encode the chunk
|
||||||
y_chunk = self.encode(x_chunk)
|
y_chunk = self.encode(x_chunk)
|
||||||
print(f'y_chunk shape: {y_chunk.shape}')
|
|
||||||
# figure out where to put the audio along the time domain
|
# figure out where to put the audio along the time domain
|
||||||
if i == num_chunks-1:
|
if i == num_chunks-1:
|
||||||
# final chunk always goes at the end
|
# final chunk always goes at the end
|
||||||
|
|||||||
@@ -213,8 +213,6 @@ class Video_Global(Conditioner):
|
|||||||
self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
|
self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
|
||||||
|
|
||||||
def forward(self, x, device: tp.Any = "cuda"):
|
def forward(self, x, device: tp.Any = "cuda"):
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
if not isinstance(x[0], torch.Tensor):
|
if not isinstance(x[0], torch.Tensor):
|
||||||
video_feats = []
|
video_feats = []
|
||||||
for path in x:
|
for path in x:
|
||||||
@@ -242,8 +240,6 @@ class Video_Sync(Conditioner):
|
|||||||
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
||||||
|
|
||||||
def forward(self, x, device: tp.Any = "cuda"):
|
def forward(self, x, device: tp.Any = "cuda"):
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
if not isinstance(x[0], torch.Tensor):
|
if not isinstance(x[0], torch.Tensor):
|
||||||
video_feats = []
|
video_feats = []
|
||||||
for path in x:
|
for path in x:
|
||||||
@@ -269,8 +265,6 @@ class Text_Linear(Conditioner):
|
|||||||
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
|
||||||
|
|
||||||
def forward(self, x, device: tp.Any = "cuda"):
|
def forward(self, x, device: tp.Any = "cuda"):
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
if not isinstance(x[0], torch.Tensor):
|
if not isinstance(x[0], torch.Tensor):
|
||||||
video_feats = []
|
video_feats = []
|
||||||
for path in x:
|
for path in x:
|
||||||
@@ -295,8 +289,6 @@ class mm_unchang(Conditioner):
|
|||||||
super().__init__(dim, output_dim)
|
super().__init__(dim, output_dim)
|
||||||
|
|
||||||
def forward(self, x, device: tp.Any = "cuda"):
|
def forward(self, x, device: tp.Any = "cuda"):
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
if not isinstance(x[0], torch.Tensor):
|
if not isinstance(x[0], torch.Tensor):
|
||||||
video_feats = []
|
video_feats = []
|
||||||
for path in x:
|
for path in x:
|
||||||
@@ -349,8 +341,6 @@ class CLIPConditioner(Conditioner):
|
|||||||
|
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
self.proj_out.to(device)
|
self.proj_out.to(device)
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
if not isinstance(images[0], torch.Tensor):
|
if not isinstance(images[0], torch.Tensor):
|
||||||
|
|||||||
@@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config():
|
|||||||
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
||||||
return create_multi_conditioner_from_conditioning_config
|
return create_multi_conditioner_from_conditioning_config
|
||||||
|
|
||||||
from time import time
|
|
||||||
|
|
||||||
class Profiler:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.ticks = [[time(), None]]
|
|
||||||
|
|
||||||
def tick(self, msg):
|
|
||||||
self.ticks.append([time(), msg])
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
rep = 80 * "=" + "\n"
|
|
||||||
for i in range(1, len(self.ticks)):
|
|
||||||
msg = self.ticks[i][1]
|
|
||||||
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
|
|
||||||
rep += msg + f": {ellapsed*1000:.2f}ms\n"
|
|
||||||
rep += 80 * "=" + "\n\n\n"
|
|
||||||
return rep
|
|
||||||
|
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module):
|
|||||||
|
|
||||||
cross_attention_input.append(cross_attn_in)
|
cross_attention_input.append(cross_attn_in)
|
||||||
cross_attention_masks.append(cross_attn_mask)
|
cross_attention_masks.append(cross_attn_mask)
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
||||||
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||||
|
|
||||||
@@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
|||||||
prepend_cond=None,
|
prepend_cond=None,
|
||||||
prepend_cond_mask=None,
|
prepend_cond_mask=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
p = Profiler()
|
|
||||||
|
|
||||||
p.tick("start")
|
|
||||||
|
|
||||||
channels_list = None
|
channels_list = None
|
||||||
if input_concat_cond is not None:
|
if input_concat_cond is not None:
|
||||||
channels_list = [input_concat_cond]
|
channels_list = [input_concat_cond]
|
||||||
@@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
|||||||
negative_embedding_mask=negative_cross_attn_mask,
|
negative_embedding_mask=negative_cross_attn_mask,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
p.tick("UNetCFG1D forward")
|
|
||||||
|
|
||||||
#print(f"Profiler: {p}")
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||||
@@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel):
|
|||||||
global_embed=global_cond,
|
global_embed=global_cond,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
class MMDiTWrapper(ConditionedDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
|
||||||
|
|
||||||
self.model = MMAudio(*args, **kwargs)
|
|
||||||
|
|
||||||
# with torch.no_grad():
|
|
||||||
# for param in self.model.parameters():
|
|
||||||
# param *= 0.5
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
clip_f,
|
|
||||||
sync_f,
|
|
||||||
text_f,
|
|
||||||
inpaint_masked_input=None,
|
|
||||||
t5_features=None,
|
|
||||||
metaclip_global_text_features=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = True,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
scale_phi: float = 0.0,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
# breakpoint()
|
|
||||||
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
|
||||||
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
|
||||||
|
|
||||||
return self.model(
|
|
||||||
latent=x,
|
|
||||||
t=t,
|
|
||||||
clip_f=clip_f,
|
|
||||||
sync_f=sync_f,
|
|
||||||
text_f=text_f,
|
|
||||||
inpaint_masked_input=inpaint_masked_input,
|
|
||||||
t5_features=t5_features,
|
|
||||||
metaclip_global_text_features=metaclip_global_text_features,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
cfg_dropout_prob=cfg_dropout_prob,
|
|
||||||
scale_phi=scale_phi,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||||
"""
|
"""
|
||||||
A diffusion model that takes in conditioning
|
A diffusion model that takes in conditioning
|
||||||
@@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||||
# breakpoint()
|
|
||||||
# print(kwargs)
|
|
||||||
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||||
|
|
||||||
def generate(self, *args, **kwargs):
|
def generate(self, *args, **kwargs):
|
||||||
@@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||||
elif diffusion_model_type == 'dit':
|
elif diffusion_model_type == 'dit':
|
||||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||||
elif diffusion_model_type == 'mmdit':
|
else:
|
||||||
diffusion_model = MMDiTWrapper(**diffusion_model_config)
|
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
||||||
|
|
||||||
io_channels = model_config.get('io_channels', None)
|
io_channels = model_config.get('io_channels', None)
|
||||||
assert io_channels is not None, "Must specify io_channels in model config"
|
assert io_channels is not None, "Must specify io_channels in model config"
|
||||||
@@ -939,13 +862,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
elif model_type == "diffusion_prior":
|
else:
|
||||||
prior_type = model_config.get("prior_type", None)
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
|
|
||||||
|
|
||||||
if prior_type == "mono_stereo":
|
|
||||||
from .diffusion_prior import MonoToStereoDiffusionPrior
|
|
||||||
wrapper_fn = MonoToStereoDiffusionPrior
|
|
||||||
|
|
||||||
return wrapper_fn(
|
return wrapper_fn(
|
||||||
diffusion_model,
|
diffusion_model,
|
||||||
|
|||||||
@@ -223,8 +223,6 @@ class DiffusionTransformer(nn.Module):
|
|||||||
|
|
||||||
# Get the batch of timestep embeddings
|
# Get the batch of timestep embeddings
|
||||||
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
||||||
# import ipdb
|
|
||||||
# ipdb.set_trace()
|
|
||||||
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||||
if self.timestep_cond_type == "global":
|
if self.timestep_cond_type == "global":
|
||||||
if global_embed is not None:
|
if global_embed is not None:
|
||||||
|
|||||||
@@ -89,25 +89,6 @@ class AutoencoderPretransform(Pretransform):
|
|||||||
def load_state_dict(self, state_dict, strict=True):
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
self.model.load_state_dict(state_dict, strict=strict)
|
self.model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
class WaveletPretransform(Pretransform):
|
|
||||||
def __init__(self, channels, levels, wavelet):
|
|
||||||
super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
|
|
||||||
|
|
||||||
from .wavelets import WaveletEncode1d, WaveletDecode1d
|
|
||||||
|
|
||||||
self.encoder = WaveletEncode1d(channels, levels, wavelet)
|
|
||||||
self.decoder = WaveletDecode1d(channels, levels, wavelet)
|
|
||||||
|
|
||||||
self.downsampling_ratio = 2 ** levels
|
|
||||||
self.io_channels = channels
|
|
||||||
self.encoded_channels = channels * self.downsampling_ratio
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
return self.encoder(x)
|
|
||||||
|
|
||||||
def decode(self, z):
|
|
||||||
return self.decoder(z)
|
|
||||||
|
|
||||||
class PQMFPretransform(Pretransform):
|
class PQMFPretransform(Pretransform):
|
||||||
def __init__(self, attenuation=100, num_bands=16):
|
def __init__(self, attenuation=100, num_bands=16):
|
||||||
# TODO: Fix PQMF to take in in-channels
|
# TODO: Fix PQMF to take in in-channels
|
||||||
|
|||||||
@@ -421,9 +421,10 @@ class Attention(nn.Module):
|
|||||||
flex_attention_score_mod = None
|
flex_attention_score_mod = None
|
||||||
|
|
||||||
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
||||||
out = flex_attention_compiled(q,k,v,
|
raise NotImplementedError(
|
||||||
block_mask = flex_attention_block_mask,
|
"FlexAttention is not available in this build. "
|
||||||
score_mod = flex_attention_score_mod)
|
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
||||||
|
)
|
||||||
elif flash_attn_available:
|
elif flash_attn_available:
|
||||||
fa_dtype_in = q.dtype
|
fa_dtype_in = q.dtype
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ def load_ckpt_state_dict(ckpt_path, prefix=None):
|
|||||||
def remove_weight_norm_from_model(model):
|
def remove_weight_norm_from_model(model):
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
print(f"Removing weight norm from {module}")
|
|
||||||
remove_weight_norm(module)
|
remove_weight_norm(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user