From 3219ec0c3926b5f7d63c265ff2d231b527f7cd3e Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 17:59:53 +0200 Subject: [PATCH] Track starred collection aliases in static extraction --- .../test_generate_popular_node_signatures.py | 101 ++++++++++++++++++ tools/generate_popular_node_signatures.py | 58 ++++++++-- 2 files changed, 150 insertions(+), 9 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 2f60eeb..70b9535 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1439,6 +1439,31 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_starred_collection_alias_mutation_skips_node(self): + source = ''' +class StarredCollectionAliasMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + *ALIASES, = (RETURN_TYPES,) + ALIASES[0].clear() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredCollectionAliasMutatedReturnTypesNode": StarredCollectionAliasMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "starred-collection-alias-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_subscript_assignment_skips_node(self): source = ''' class AliasSubscriptMutatedReturnTypesNode: @@ -2302,6 +2327,31 @@ Alias.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_starred_collection_alias_patch_after_mapping_skips_node(self): + source = ''' +class StarredCollectionAliasPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredCollectionAliasPatchedNode": StarredCollectionAliasPatchedNode, +} +*ALIASES, = (StarredCollectionAliasPatchedNode,) +ALIASES[0].RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "starred-collection-alias-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_globals_subscript_alias_patch_after_mapping_skips_node(self): source = ''' class GlobalsSubscriptAliasPatchedNode: @@ -2480,6 +2530,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_starred_collection_alias_mutation_skips_node(self): + source = ''' +class StarredCollectionAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +*ALIASES, = (StarredCollectionAttributeAliasNode.RETURN_TYPES,) +ALIASES[0].clear() + +NODE_CLASS_MAPPINGS = { + "StarredCollectionAttributeAliasNode": StarredCollectionAttributeAliasNode, +} +''' + result = self._extract_source(source, "starred-collection-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_transitive_alias_mutation_skips_node(self): source = ''' class TransitiveAttributeAliasNode: @@ -3000,6 +3076,31 @@ ALIAS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_starred_collection_alias_mutation_invalidates_static_node_mapping(self): + source = ''' +class StarredCollectionAliasMutatedMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredCollectionAliasMutatedMappingNode": StarredCollectionAliasMutatedMappingNode, +} +*ALIASES, = (NODE_CLASS_MAPPINGS,) +ALIASES[0].clear() +''' + result = self._extract_source(source, "starred-collection-alias-mutated-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_to_node_mapping_skips_node(self): source = ''' class RhsMutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index db8986a..d3c3bc9 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -918,6 +918,8 @@ def _unpack_target_value_pairs(target, value): return () pairs = [(targets[index], values[index]) for index in range(prefix_count)] + star_stop = len(values) - suffix_count if suffix_count else len(values) + pairs.append((targets[starred_index], ast.Tuple(elts=values[prefix_count:star_stop], ctx=ast.Load()))) if suffix_count: target_suffix = targets[-suffix_count:] value_suffix = values[-suffix_count:] @@ -925,17 +927,30 @@ def _unpack_target_value_pairs(target, value): return tuple(pairs) +def _alias_target_name(target): + if isinstance(target, ast.Name): + return target.id + if isinstance(target, ast.Starred) and isinstance(target.value, ast.Name): + return target.value.id + return None + + def _class_attr_alias_sources(value, name, aliases): - return isinstance(value, ast.Name) and (value.id == name or value.id in aliases) + if isinstance(value, ast.Name): + return value.id == name or value.id in aliases + if isinstance(value, (ast.Tuple, ast.List)): + return any(_class_attr_alias_sources(item, name, aliases) for item in value.elts) + return False def _update_class_attr_aliases_from_unpack(target, value, name, aliases): found = False for target_item, value_item in _unpack_target_value_pairs(target, value): - if not isinstance(target_item, ast.Name): + target_name = _alias_target_name(target_item) + if target_name is None: continue if _class_attr_alias_sources(value_item, name, aliases): - aliases.add(target_item.id) + aliases.add(target_name) found = True return found @@ -1154,6 +1169,11 @@ def _class_alias_sources(value, class_aliases, class_bindings): if value.id in class_bindings: return {value.id} return set() + if isinstance(value, (ast.Tuple, ast.List)): + sources = set() + for item in value.elts: + sources.update(_class_alias_sources(item, class_aliases, class_bindings)) + return sources name = _namespace_subscript_name(value) or _namespace_lookup_name(value) if name in class_aliases: @@ -1165,11 +1185,12 @@ def _class_alias_sources(value, class_aliases, class_bindings): def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings): for target_item, value_item in _unpack_target_value_pairs(target, value): - if not isinstance(target_item, ast.Name): + target_name = _alias_target_name(target_item) + if target_name is None: continue sources = _class_alias_sources(value_item, class_aliases, class_bindings) if sources: - class_aliases[target_item.id] = sources + class_aliases[target_name] = sources def _update_class_aliases(stmt, class_aliases, class_bindings): @@ -1204,6 +1225,18 @@ def _expanded_class_attribute_names(names, class_aliases): def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings): if isinstance(value, ast.Name): return set(class_attribute_aliases.get(value.id, ())) + if isinstance(value, (ast.Tuple, ast.List)): + sources = set() + for item in value.elts: + sources.update( + _class_attribute_alias_sources( + item, + class_attribute_aliases, + class_aliases, + class_bindings, + ) + ) + return sources names = set() if isinstance(value, ast.Attribute) and value.attr in _CLASS_SIGNATURE_ATTRS: @@ -1230,7 +1263,8 @@ def _update_class_attribute_alias_from_unpack( class_bindings, ): for target_item, value_item in _unpack_target_value_pairs(target, value): - if not isinstance(target_item, ast.Name): + target_name = _alias_target_name(target_item) + if target_name is None: continue sources = _class_attribute_alias_sources( value_item, @@ -1239,7 +1273,7 @@ def _update_class_attribute_alias_from_unpack( class_bindings, ) if sources: - class_attribute_aliases[target_item.id] = sources + class_attribute_aliases[target_name] = sources def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases): @@ -1309,6 +1343,11 @@ def _module_dict_alias_sources(value, name, aliases): if value.id == name: return {name} return set(aliases.get(value.id, ())) + if isinstance(value, (ast.Tuple, ast.List)): + sources = set() + for item in value.elts: + sources.update(_module_dict_alias_sources(item, name, aliases)) + return sources namespace_name = _namespace_subscript_name(value) or _namespace_lookup_name(value) if namespace_name == name: @@ -1318,11 +1357,12 @@ def _module_dict_alias_sources(value, name, aliases): def _update_module_dict_alias_from_unpack(target, value, name, aliases): for target_item, value_item in _unpack_target_value_pairs(target, value): - if not isinstance(target_item, ast.Name): + target_name = _alias_target_name(target_item) + if target_name is None: continue sources = _module_dict_alias_sources(value_item, name, aliases) if sources: - aliases[target_item.id] = sources + aliases[target_name] = sources def _module_dict_alias_invalidated(stmt, aliases):