From 317788572e9067b663fc2f82f2b92fe092c06731 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 14:50:03 +0200 Subject: [PATCH] Detect definition-time mutating expressions --- .../test_generate_popular_node_signatures.py | 93 +++++++++++++++++++ tools/generate_popular_node_signatures.py | 24 ++++- 2 files changed, 113 insertions(+), 4 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 23b9fd7..d18768c 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -617,6 +617,73 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_function_default_mutation_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +def helper(x=INPUTS.clear()): + pass + + +class DefaultMutatedInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "DefaultMutatedInputEnvNode": DefaultMutatedInputEnvNode, +} +''' + result = self._extract_source(source, "default-mutated-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_function_decorator_mutation_invalidates_static_env_value(self): + source = ''' +def decorator(value): + def wrap(fn): + return fn + return wrap + + +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +@decorator(INPUTS.clear()) +def helper(): + pass + + +class DecoratorMutatedInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "DecoratorMutatedInputEnvNode": DecoratorMutatedInputEnvNode, +} +''' + result = self._extract_source(source, "decorator-mutated-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_mutable_env_literal_skips_static_node(self): source = ''' REQ = { @@ -835,6 +902,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_function_default_mutation_to_return_types_skips_node(self): + source = ''' +class DefaultMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + + def helper(self, x=RETURN_TYPES.clear()): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DefaultMutatedReturnTypesNode": DefaultMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "default-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_mutation_skips_node(self): source = ''' class AliasMutatedReturnTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 77cd19d..a87cd0e 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -181,17 +181,33 @@ def _mutating_call_target_names(stmt): names = set() class MutatingCallVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + def visit_FunctionDef(self, node): - return None + self._visit_function_definition_expressions(node) def visit_AsyncFunctionDef(self, node): - return None + self._visit_function_definition_expressions(node) def visit_ClassDef(self, node): - return None + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) def visit_Lambda(self, node): - return None + self.visit(node.args) def visit_Call(self, node): if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: