From d1f49e7c95f9a10818dce29921797b45843fbda9 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 16:47:39 +0200 Subject: [PATCH] Track mapping and class attribute aliases --- .../test_generate_popular_node_signatures.py | 78 ++++++++++++++ tools/generate_popular_node_signatures.py | 102 +++++++++++++++++- 2 files changed, 177 insertions(+), 3 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 92abbc4..d098d96 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2047,6 +2047,59 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_tuple_alias_mutation_skips_node(self): + source = ''' +class TupleAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +RET, = (TupleAttributeAliasNode.RETURN_TYPES,) +RET.clear() + +NODE_CLASS_MAPPINGS = { + "TupleAttributeAliasNode": TupleAttributeAliasNode, +} +''' + result = self._extract_source(source, "tuple-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_module_class_attribute_transitive_alias_mutation_skips_node(self): + source = ''' +class TransitiveAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +RET = TransitiveAttributeAliasNode.RETURN_TYPES +ALIAS = RET +ALIAS.clear() + +NODE_CLASS_MAPPINGS = { + "TransitiveAttributeAliasNode": TransitiveAttributeAliasNode, +} +''' + result = self._extract_source(source, "transitive-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_alias_mutation_after_mapping_skips_node(self): source = ''' class PostMappingAttributeAliasNode: @@ -2172,6 +2225,31 @@ NODE_CLASS_MAPPINGS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unpacked_alias_mutation_invalidates_static_node_mapping(self): + source = ''' +class UnpackedAliasMutatedMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "UnpackedAliasMutatedMappingNode": UnpackedAliasMutatedMappingNode, +} +ALIAS, = (NODE_CLASS_MAPPINGS,) +ALIAS.clear() +''' + result = self._extract_source(source, "unpacked-alias-mutated-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_to_node_mapping_skips_node(self): source = ''' class RhsMutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 7a2ffb1..80fb47e 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -894,7 +894,9 @@ def _expanded_class_attribute_names(names, class_aliases): return expanded -def _class_attribute_alias_sources(value, class_aliases, class_bindings): +def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings): + if isinstance(value, ast.Name): + return set(class_attribute_aliases.get(value.id, ())) if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS: return set() name = _root_name(value.value) @@ -905,6 +907,30 @@ def _class_attribute_alias_sources(value, class_aliases, class_bindings): return set() +def _update_class_attribute_alias_from_unpack( + target, + value, + class_attribute_aliases, + class_aliases, + class_bindings, +): + if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): + return + if len(target.elts) != len(value.elts): + return + for target_item, value_item in zip(target.elts, value.elts): + if not isinstance(target_item, ast.Name): + continue + sources = _class_attribute_alias_sources( + value_item, + class_attribute_aliases, + class_aliases, + class_bindings, + ) + if sources: + class_attribute_aliases[target_item.id] = sources + + def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases): names = ( _mutating_call_target_names(stmt) @@ -924,11 +950,29 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases.pop(name, None) if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): - sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings) + sources = _class_attribute_alias_sources( + stmt.value, + class_attribute_aliases, + class_aliases, + class_bindings, + ) if sources: class_attribute_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: + _update_class_attribute_alias_from_unpack( + stmt.targets[0], + stmt.value, + class_attribute_aliases, + class_aliases, + class_bindings, + ) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: - sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings) + sources = _class_attribute_alias_sources( + stmt.value, + class_attribute_aliases, + class_aliases, + class_bindings, + ) if sources: class_attribute_aliases[stmt.target.id] = sources @@ -942,6 +986,54 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu return names +def _module_dict_alias_sources(value, name, aliases): + if not isinstance(value, ast.Name): + return set() + if value.id == name: + return {name} + return set(aliases.get(value.id, ())) + + +def _update_module_dict_alias_from_unpack(target, value, name, aliases): + if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): + return + if len(target.elts) != len(value.elts): + return + for target_item, value_item in zip(target.elts, value.elts): + if not isinstance(target_item, ast.Name): + continue + sources = _module_dict_alias_sources(value_item, name, aliases) + if sources: + aliases[target_item.id] = sources + + +def _module_dict_alias_invalidated(stmt, aliases): + names = ( + _mutating_call_target_names(stmt) + | _assignment_target_names(stmt) + | _delete_target_names(stmt) + | _bound_names(stmt) + ) + return any(name in aliases for name in names) + + +def _update_module_dict_aliases(stmt, name, aliases): + rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) + for rebound_name in rebound_names: + aliases.pop(rebound_name, None) + + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + sources = _module_dict_alias_sources(stmt.value, name, aliases) + if sources: + aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: + _update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases) + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: + sources = _module_dict_alias_sources(stmt.value, name, aliases) + if sources: + aliases[stmt.target.id] = sources + + def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False): value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING @@ -949,6 +1041,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N class_bindings = {} class_aliases = {} class_attribute_aliases = {} + module_dict_aliases = {} def advance_module_state(stmt): _invalidate_class_bindings( @@ -958,6 +1051,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N _apply_module_stmt_to_env(stmt, env, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings) _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) + _update_module_dict_aliases(stmt, name, module_dict_aliases) for stmt in tree.body: class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) @@ -969,6 +1063,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N value = _INVALID if name in _mutating_call_target_names(stmt): value = _INVALID + if _module_dict_alias_invalidated(stmt, module_dict_aliases): + value = _INVALID if isinstance(stmt, ast.Assign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: