Track unpacked and class attribute aliases

This commit is contained in:
2026-07-02 16:12:32 +02:00
parent 05fa411d47
commit 9752248ee9
2 changed files with 149 additions and 2 deletions
@@ -1307,6 +1307,31 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_return_types_unpacked_alias_mutation_skips_node(self):
source = '''
class UnpackedAliasMutatedReturnTypesNode:
RETURN_TYPES = ["IMAGE"]
ALIAS, = (RETURN_TYPES,)
ALIAS.clear()
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"UnpackedAliasMutatedReturnTypesNode": UnpackedAliasMutatedReturnTypesNode,
}
'''
result = self._extract_source(source, "unpacked-alias-mutated-return-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_return_types_alias_subscript_assignment_skips_node(self):
source = '''
class AliasSubscriptMutatedReturnTypesNode:
@@ -1916,6 +1941,57 @@ Alias.RETURN_TYPES = ("MASK",)
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_attribute_alias_mutation_before_mapping_skips_node(self):
source = '''
class PreMappingAttributeAliasNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
RET = PreMappingAttributeAliasNode.RETURN_TYPES
RET.clear()
NODE_CLASS_MAPPINGS = {
"PreMappingAttributeAliasNode": PreMappingAttributeAliasNode,
}
'''
result = self._extract_source(source, "pre-mapping-attribute-alias-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_attribute_alias_mutation_after_mapping_skips_node(self):
source = '''
class PostMappingAttributeAliasNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"PostMappingAttributeAliasNode": PostMappingAttributeAliasNode,
}
RET = PostMappingAttributeAliasNode.RETURN_TYPES
RET.clear()
'''
result = self._extract_source(source, "post-mapping-attribute-alias-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_tuple_alias_patch_after_mapping_skips_node(self):
source = '''
class TupleAliasPatchedNode:
+73 -2
View File
@@ -36,6 +36,7 @@ _MUTATING_METHODS = {
_CONTROL_FLOW_TYPES = (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)
if hasattr(ast, "TryStar"):
_CONTROL_FLOW_TYPES += (ast.TryStar,)
_CLASS_SIGNATURE_ATTRS = {"INPUT_TYPES", "RETURN_NAMES", "RETURN_TYPES"}
def _literal(node, env, allow_mutable_env=True):
@@ -589,6 +590,25 @@ def _input_types_decorators_are_supported(decorators):
return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators)
def _class_attr_alias_sources(value, name, aliases):
return isinstance(value, ast.Name) and (value.id == name or value.id in aliases)
def _update_class_attr_aliases_from_unpack(target, value, name, aliases):
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
return False
if len(target.elts) != len(value.elts):
return False
found = False
for target_item, value_item in zip(target.elts, value.elts):
if not isinstance(target_item, ast.Name):
continue
if _class_attr_alias_sources(value_item, name, aliases):
aliases.add(target_item.id)
found = True
return found
def _class_attr(cls, name, env):
value = _MISSING
aliases = set()
@@ -604,11 +624,16 @@ def _class_attr(cls, name, env):
len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and stmt.targets[0].id != name
and isinstance(stmt.value, ast.Name)
and (stmt.value.id == name or stmt.value.id in aliases)
and _class_attr_alias_sources(stmt.value, name, aliases)
):
aliases.add(stmt.targets[0].id)
continue
if (
len(stmt.targets) == 1
and name not in target_names
and _update_class_attr_aliases_from_unpack(stmt.targets[0], stmt.value, name, aliases)
):
continue
if aliases.intersection(target_names):
value = _INVALID
aliases.difference_update(target_names)
@@ -830,22 +855,68 @@ def _expanded_class_attribute_names(names, class_aliases):
return expanded
def _class_attribute_alias_sources(value, class_aliases, class_bindings):
if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS:
return set()
name = _root_name(value.value)
if name in class_aliases:
return set(class_aliases[name])
if name in class_bindings:
return {name}
return set()
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
names = (
_mutating_call_target_names(stmt)
| _assignment_target_names(stmt)
| _delete_target_names(stmt)
| _bound_names(stmt)
)
invalidated = set()
for name in names:
invalidated.update(class_attribute_aliases.get(name, ()))
return invalidated
def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings):
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for name in rebound_names:
class_attribute_aliases.pop(name, None)
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
if sources:
class_attribute_aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
if sources:
class_attribute_aliases[stmt.target.id] = sources
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING
env = {}
class_bindings = {}
class_aliases = {}
class_attribute_aliases = {}
def advance_module_state(stmt):
_invalidate_class_bindings(
class_bindings,
_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases),
)
_apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
for stmt in tree.body:
class_attr_names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
class_attr_names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
if (
value not in (_MISSING, _INVALID)
and class_attr_names