From 79d9921ba696d280eb348acd81fb565471f1c331 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 17:37:06 +0200 Subject: [PATCH] Track starred unpack aliases in static extraction --- .../test_generate_popular_node_signatures.py | 127 ++++++++++++++++++ tools/generate_popular_node_signatures.py | 53 +++++--- 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index df0f433..61331f6 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1414,6 +1414,31 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_starred_unpacked_alias_mutation_skips_node(self): + source = ''' +class StarredUnpackedAliasMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + ALIAS, *REST = (RETURN_TYPES, [], []) + ALIAS.clear() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredUnpackedAliasMutatedReturnTypesNode": StarredUnpackedAliasMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "starred-unpacked-alias-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_subscript_assignment_skips_node(self): source = ''' class AliasSubscriptMutatedReturnTypesNode: @@ -1490,6 +1515,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_names_starred_unpacked_alias_mutation_skips_node(self): + source = ''' +class StarredUnpackedAliasMutatedReturnNamesNode: + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ["image"] + ALIAS, *REST = (RETURN_NAMES, [], []) + ALIAS.clear() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredUnpackedAliasMutatedReturnNamesNode": StarredUnpackedAliasMutatedReturnNamesNode, +} +''' + result = self._extract_source(source, "starred-unpacked-alias-mutated-return-names-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_transitive_alias_mutation_skips_node(self): source = ''' class TransitiveAliasMutatedReturnTypesNode: @@ -2151,6 +2202,31 @@ Alias.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_starred_alias_patch_after_mapping_skips_node(self): + source = ''' +class StarredAliasPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredAliasPatchedNode": StarredAliasPatchedNode, +} +Alias, *REST = (StarredAliasPatchedNode, object(), object()) +Alias.RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "starred-alias-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_globals_subscript_alias_patch_after_mapping_skips_node(self): source = ''' class GlobalsSubscriptAliasPatchedNode: @@ -2303,6 +2379,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_starred_alias_mutation_skips_node(self): + source = ''' +class StarredAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +RET, *REST = (StarredAttributeAliasNode.RETURN_TYPES, [], []) +RET.clear() + +NODE_CLASS_MAPPINGS = { + "StarredAttributeAliasNode": StarredAttributeAliasNode, +} +''' + result = self._extract_source(source, "starred-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_transitive_alias_mutation_skips_node(self): source = ''' class TransitiveAttributeAliasNode: @@ -2700,6 +2802,31 @@ ALIAS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_starred_unpacked_alias_mutation_invalidates_static_node_mapping(self): + source = ''' +class StarredUnpackedAliasMutatedMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StarredUnpackedAliasMutatedMappingNode": StarredUnpackedAliasMutatedMappingNode, +} +ALIAS, *REST = (NODE_CLASS_MAPPINGS, {}, {}) +ALIAS.clear() +''' + result = self._extract_source(source, "starred-unpacked-alias-mutated-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_to_node_mapping_skips_node(self): source = ''' class RhsMutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 704ed0b..bc36354 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -759,17 +759,42 @@ def _input_types_decorators_are_supported(decorators, classmethod_shadowed): return True +def _unpack_target_value_pairs(target, value): + if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): + return () + + targets = target.elts + values = value.elts + starred_indices = [index for index, item in enumerate(targets) if isinstance(item, ast.Starred)] + if not starred_indices: + if len(targets) != len(values): + return () + return tuple(zip(targets, values)) + + if len(starred_indices) != 1: + return () + + starred_index = starred_indices[0] + prefix_count = starred_index + suffix_count = len(targets) - starred_index - 1 + if len(values) < prefix_count + suffix_count: + return () + + pairs = [(targets[index], values[index]) for index in range(prefix_count)] + if suffix_count: + target_suffix = targets[-suffix_count:] + value_suffix = values[-suffix_count:] + pairs.extend(zip(target_suffix, value_suffix)) + return tuple(pairs) + + def _class_attr_alias_sources(value, name, aliases): return isinstance(value, ast.Name) and (value.id == name or value.id in aliases) def _update_class_attr_aliases_from_unpack(target, value, name, aliases): - if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): - return False - if len(target.elts) != len(value.elts): - return False found = False - for target_item, value_item in zip(target.elts, value.elts): + for target_item, value_item in _unpack_target_value_pairs(target, value): if not isinstance(target_item, ast.Name): continue if _class_attr_alias_sources(value_item, name, aliases): @@ -997,11 +1022,7 @@ def _class_alias_sources(value, class_aliases, class_bindings): def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings): - if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): - return - if len(target.elts) != len(value.elts): - return - for target_item, value_item in zip(target.elts, value.elts): + for target_item, value_item in _unpack_target_value_pairs(target, value): if not isinstance(target_item, ast.Name): continue sources = _class_alias_sources(value_item, class_aliases, class_bindings) @@ -1066,11 +1087,7 @@ def _update_class_attribute_alias_from_unpack( class_aliases, class_bindings, ): - if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): - return - if len(target.elts) != len(value.elts): - return - for target_item, value_item in zip(target.elts, value.elts): + for target_item, value_item in _unpack_target_value_pairs(target, value): if not isinstance(target_item, ast.Name): continue sources = _class_attribute_alias_sources( @@ -1151,11 +1168,7 @@ def _module_dict_alias_sources(value, name, aliases): def _update_module_dict_alias_from_unpack(target, value, name, aliases): - if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): - return - if len(target.elts) != len(value.elts): - return - for target_item, value_item in zip(target.elts, value.elts): + for target_item, value_item in _unpack_target_value_pairs(target, value): if not isinstance(target_item, ast.Name): continue sources = _module_dict_alias_sources(value_item, name, aliases)