diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 8328fa2..c47b183 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1486,6 +1486,31 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_chained_alias_mutation_skips_node(self): + source = ''' +class ChainedAliasMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + A = B = RETURN_TYPES + A.append("MASK") + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ChainedAliasMutatedReturnTypesNode": ChainedAliasMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "chained-alias-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_unpacked_alias_mutation_skips_node(self): source = ''' class UnpackedAliasMutatedReturnTypesNode: @@ -2446,6 +2471,31 @@ Alias.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_chained_alias_patch_after_mapping_skips_node(self): + source = ''' +class ChainedAliasPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ChainedAliasPatchedNode": ChainedAliasPatchedNode, +} +A = B = ChainedAliasPatchedNode +A.RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "chained-alias-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_starred_alias_patch_after_mapping_skips_node(self): source = ''' class StarredAliasPatchedNode: @@ -2622,6 +2672,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_chained_alias_mutation_skips_node(self): + source = ''' +class ChainedAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +A = B = ChainedAttributeAliasNode.RETURN_TYPES +A.append("MASK") + +NODE_CLASS_MAPPINGS = { + "ChainedAttributeAliasNode": ChainedAttributeAliasNode, +} +''' + result = self._extract_source(source, "chained-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_tuple_alias_mutation_skips_node(self): source = ''' class TupleAttributeAliasNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 1c6fcea..b485400 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -1009,6 +1009,18 @@ def _class_attr(cls, name, env): value = _INVALID if isinstance(stmt, ast.Assign): target_names = _assignment_target_names(stmt) + if len(stmt.targets) > 1 and _class_attr_alias_sources(stmt.value, name, aliases): + target_aliases = [] + for target in stmt.targets: + target_name = _alias_target_name(target) + if target_name is None: + value = _INVALID + target_aliases = [] + break + target_aliases.append(target_name) + aliases.update(alias for alias in target_aliases if alias != name) + if name not in target_names: + continue if ( len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name) @@ -1290,6 +1302,13 @@ def _update_class_aliases(stmt, class_aliases, class_bindings): sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) if sources: class_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1: + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + if sources: + for target in stmt.targets: + target_name = _alias_target_name(target) + if target_name is not None: + class_aliases[target_name] = sources elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: _update_class_alias_from_unpack(stmt.targets[0], stmt.value, class_aliases, class_bindings) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: @@ -1387,6 +1406,18 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases ) if sources: class_attribute_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1: + sources = _class_attribute_alias_sources( + stmt.value, + class_attribute_aliases, + class_aliases, + class_bindings, + ) + if sources: + for target in stmt.targets: + target_name = _alias_target_name(target) + if target_name is not None: + class_attribute_aliases[target_name] = sources elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: _update_class_attribute_alias_from_unpack( stmt.targets[0],