6bc3fd6443
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
40 lines
887 B
Python
40 lines
887 B
Python
import logging
|
|
|
|
# Module-level logger; getLogger() with no name returns the root logger.
log = logging.getLogger()
|
|
|
|
|
|
def get_parameter_groups(model, cfg, print_log=False):
    """
    Build the optimizer parameter groups for ``model``.

    Every trainable parameter (``requires_grad`` is true) is collected exactly
    once, in ``model.named_parameters()`` order, into a single group that uses
    ``cfg.learning_rate`` and ``cfg.weight_decay``.

    Args:
        model: object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.
        cfg: configuration object providing the ``weight_decay`` and
            ``learning_rate`` attributes.
        print_log: accepted for interface compatibility; this function does
            not currently use it.

    Returns:
        A one-element list holding a dict with the keys ``'params'``, ``'lr'``
        and ``'weight_decay'``, which can be passed to the optimizer.
    """
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    params = []

    # inspired by detectron2
    memo = set()
    # The parameter name is not needed: all parameters share one group.
    for _, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters that are registered under several names
        if param in memo:
            continue
        memo.add(param)
        params.append(param)

    return [
        {
            'params': params,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]
|