diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 932243b..e9d1781 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1307,6 +1307,31 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_unpacked_alias_mutation_skips_node(self): + source = ''' +class UnpackedAliasMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + ALIAS, = (RETURN_TYPES,) + ALIAS.clear() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "UnpackedAliasMutatedReturnTypesNode": UnpackedAliasMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "unpacked-alias-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_subscript_assignment_skips_node(self): source = ''' class AliasSubscriptMutatedReturnTypesNode: @@ -1916,6 +1941,57 @@ Alias.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_alias_mutation_before_mapping_skips_node(self): + source = ''' +class PreMappingAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +RET = PreMappingAttributeAliasNode.RETURN_TYPES +RET.clear() + +NODE_CLASS_MAPPINGS = { + "PreMappingAttributeAliasNode": PreMappingAttributeAliasNode, +} +''' + result = self._extract_source(source, "pre-mapping-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_module_class_attribute_alias_mutation_after_mapping_skips_node(self): + source = ''' +class PostMappingAttributeAliasNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "PostMappingAttributeAliasNode": PostMappingAttributeAliasNode, +} +RET = PostMappingAttributeAliasNode.RETURN_TYPES +RET.clear() +''' + result = self._extract_source(source, "post-mapping-attribute-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_tuple_alias_patch_after_mapping_skips_node(self): source = ''' class TupleAliasPatchedNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index af5fb7f..60f66d4 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -36,6 +36,7 @@ _MUTATING_METHODS = { _CONTROL_FLOW_TYPES = (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match) if hasattr(ast, "TryStar"): _CONTROL_FLOW_TYPES += (ast.TryStar,) +_CLASS_SIGNATURE_ATTRS = {"INPUT_TYPES", "RETURN_NAMES", "RETURN_TYPES"} def _literal(node, env, allow_mutable_env=True): @@ -589,6 +590,25 @@ def _input_types_decorators_are_supported(decorators): return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators) +def _class_attr_alias_sources(value, name, aliases): + return isinstance(value, ast.Name) and (value.id == name or value.id in aliases) + + +def _update_class_attr_aliases_from_unpack(target, value, name, aliases): + if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): + return False + if len(target.elts) != len(value.elts): + return False + found = False + for target_item, value_item in zip(target.elts, value.elts): + if not isinstance(target_item, ast.Name): + continue + if _class_attr_alias_sources(value_item, name, aliases): + aliases.add(target_item.id) + found = True + return found + + def _class_attr(cls, name, env): value = _MISSING aliases = set() @@ -604,11 +624,16 @@ def _class_attr(cls, name, env): len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name) and stmt.targets[0].id != name - and isinstance(stmt.value, ast.Name) - and (stmt.value.id == name or stmt.value.id in aliases) + and _class_attr_alias_sources(stmt.value, name, aliases) ): aliases.add(stmt.targets[0].id) continue + if ( + len(stmt.targets) == 1 + and name not in target_names + and _update_class_attr_aliases_from_unpack(stmt.targets[0], stmt.value, name, aliases) + ): + continue if aliases.intersection(target_names): value = _INVALID aliases.difference_update(target_names) @@ -830,22 +855,68 @@ def _expanded_class_attribute_names(names, class_aliases): return expanded +def _class_attribute_alias_sources(value, class_aliases, class_bindings): + if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS: + return set() + name = _root_name(value.value) + if name in class_aliases: + return set(class_aliases[name]) + if name in class_bindings: + return {name} + return set() + + +def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases): + names = ( + _mutating_call_target_names(stmt) + | _assignment_target_names(stmt) + | _delete_target_names(stmt) + | _bound_names(stmt) + ) + invalidated = set() + for name in names: + invalidated.update(class_attribute_aliases.get(name, ())) + return invalidated + + +def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings): + rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) + for name in rebound_names: + class_attribute_aliases.pop(name, None) + + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings) + if sources: + class_attribute_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: + sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings) + if sources: + class_attribute_aliases[stmt.target.id] = sources + + def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None): value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING env = {} class_bindings = {} class_aliases = {} + class_attribute_aliases = {} def advance_module_state(stmt): + _invalidate_class_bindings( + class_bindings, + _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases), + ) _apply_module_stmt_to_env(stmt, env, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings) + _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) for stmt in tree.body: class_attr_names = _expanded_class_attribute_names( _class_attribute_mutation_target_names(stmt), class_aliases, ) + class_attr_names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases)) if ( value not in (_MISSING, _INVALID) and class_attr_names