diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index ce59b12..a5461ce 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1359,6 +1359,36 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_body_plain_function_decorator_invalidates_static_input_env(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +class ClassPlainDecoratorInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @decorator + def helper(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ClassPlainDecoratorInputEnvNode": ClassPlainDecoratorInputEnvNode, +} +''' + result = self._extract_source(source, "class-plain-decorator-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_body_function_body_mutation_does_not_invalidate_static_input_env(self): source = ''' INPUTS = { @@ -4120,6 +4150,33 @@ def helper(x=DefinitionTimeMutatedMappedNode.RETURN_TYPES.clear()): self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_plain_decorated_function_after_mapping_skips_node(self): + source = ''' +class PlainDecoratedAfterMappingNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "PlainDecoratedAfterMappingNode": PlainDecoratedAfterMappingNode, +} + +@decorator +def helper(): + pass +''' + result = self._extract_source(source, "plain-decorated-after-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unhashable_node_mapping_key_skips_repo_without_raising(self): source = ''' KEY = ["UnhashableMappingKeyNode"] diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index b4a7451..b713798 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -668,7 +668,13 @@ def _has_arbitrary_call(stmt): class ArbitraryCallPresenceVisitor(ast.NodeVisitor): def _visit_function_definition_expressions(self, node): + nonlocal found for decorator in node.decorator_list: + if not ( + isinstance(decorator, ast.Name) + and decorator.id == "classmethod" + ): + found = True self.visit(decorator) self.visit(node.args) if node.returns is not None: