diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 5216ed5..f1d81dd 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1621,6 +1621,34 @@ NODE_CLASS_MAPPINGS = { self.assertIn("TopLevelMappedNode", result["nodes"]) self.assertEqual("ok", result["pack"]["status"]) + def test_decorated_class_mapping_skips_node(self): + source = ''' +def decorator(cls): + return cls + + +@decorator +class DecoratedMappedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DecoratedMappedNode": DecoratedMappedNode, +} +''' + result = self._extract_source(source, "decorated-mapped-class-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_node_mapping_key_uses_assignment_time_env(self): source = ''' KEY = "Original" @@ -1650,6 +1678,58 @@ KEY = "Wrong" self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"]) self.assertEqual("ok", result["pack"]["status"]) + def test_module_class_return_types_patch_after_mapping_skips_node(self): + source = ''' +class PatchedReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "PatchedReturnTypesNode": PatchedReturnTypesNode, +} +PatchedReturnTypesNode.RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "patched-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_module_class_input_types_patch_after_mapping_skips_node(self): + source = ''' +def build_inputs(): + return {"required": {"mask": ("MASK",)}} + + +class PatchedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "PatchedInputTypesNode": PatchedInputTypesNode, +} +PatchedInputTypesNode.INPUT_TYPES = build_inputs +''' + result = self._extract_source(source, "patched-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_mutated_node_class_mapping_skips_node(self): source = ''' class MutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 51bbd5e..ebc7f85 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -79,6 +79,73 @@ def _target_names(target): return set() +def _root_name(node): + while isinstance(node, (ast.Attribute, ast.Subscript)): + node = node.value + if isinstance(node, ast.Name): + return node.id + return None + + +def _attribute_target_base_names(target): + if isinstance(target, ast.Attribute): + name = _root_name(target.value) + return {name} if name else set() + if isinstance(target, ast.Subscript): + return _attribute_target_base_names(target.value) + if isinstance(target, (ast.List, ast.Tuple)): + names = set() + for item in target.elts: + names.update(_attribute_target_base_names(item)) + return names + if isinstance(target, ast.Starred): + return _attribute_target_base_names(target.value) + return set() + + +def _class_attribute_mutation_target_names(stmt): + names = set() + + class AttributeMutationVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, node): + return None + + def visit_AsyncFunctionDef(self, node): + return None + + def visit_ClassDef(self, node): + return None + + def visit_Lambda(self, node): + return None + + def visit_Assign(self, node): + for target in node.targets: + names.update(_attribute_target_base_names(target)) + self.visit(node.value) + + def visit_AnnAssign(self, node): + names.update(_attribute_target_base_names(node.target)) + if node.value is not None: + self.visit(node.value) + + def visit_AugAssign(self, node): + names.update(_attribute_target_base_names(node.target)) + self.visit(node.value) + + def visit_Delete(self, node): + for target in node.targets: + names.update(_attribute_target_base_names(target)) + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: + names.update(_attribute_target_base_names(node.func.value)) + self.generic_visit(node) + + AttributeMutationVisitor().visit(stmt) + return names + + def _pattern_bound_names(pattern): names = set() if isinstance(pattern, ast.MatchAs): @@ -347,7 +414,10 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): env.pop(name, None) if isinstance(stmt, ast.ClassDef): if class_bindings is not None: - class_bindings[stmt.name] = (stmt, dict(env)) + if stmt.decorator_list: + class_bindings.pop(stmt.name, None) + else: + class_bindings[stmt.name] = (stmt, dict(env)) env.pop(stmt.name, None) return if isinstance(stmt, ast.Assign): @@ -635,11 +705,19 @@ def _module_dict_entries(node, env, class_bindings, value_converter): return result -def _final_module_dict(tree, name, value_converter): +def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None): + value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING env = {} class_bindings = {} for stmt in tree.body: + class_attr_names = _class_attribute_mutation_target_names(stmt) + if ( + value not in (_MISSING, _INVALID) + and class_attr_names + and value_invalidated_by_names(value, class_attr_names) + ): + value = _INVALID if name in _mutating_call_target_names(stmt): value = _INVALID if isinstance(stmt, ast.Assign): @@ -712,14 +790,26 @@ def _mapping_value_binding(value, env, class_bindings): class_name = _mapping_value_name(value) if class_name is None: return None - return class_bindings.get(class_name) + binding = class_bindings.get(class_name) + if binding is None: + return None + return class_name, binding + + +def _node_mapping_invalidated_by_names(value, names): + return any(class_name in names for class_name, _binding in value.values()) def _node_class_mappings(tree): if _has_module_wildcard_import(tree): return {} - mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding) - return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None} + mappings = _final_module_dict( + tree, + "NODE_CLASS_MAPPINGS", + _mapping_value_binding, + _node_mapping_invalidated_by_names, + ) + return {str(node_type): binding for node_type, (_class_name, binding) in mappings.items() if node_type} def _display_mappings(tree):