From 45e3cbaad83c30885c6400cac5736066f60a0450 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 14:29:04 +0200 Subject: [PATCH] Detect mutating calls inside statements --- .../test_generate_popular_node_signatures.py | 75 +++++++++++++++++++ tools/generate_popular_node_signatures.py | 45 ++++++++--- 2 files changed, 111 insertions(+), 9 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 3c42548..b531ccd 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -590,6 +590,33 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +X = INPUTS.clear() + + +class RhsMutatedInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "RhsMutatedInputEnvNode": RhsMutatedInputEnvNode, +} +''' + result = self._extract_source(source, "rhs-mutated-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_mutable_env_literal_skips_static_node(self): source = ''' REQ = { @@ -784,6 +811,30 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_to_return_types_skips_node(self): + source = ''' +class RhsMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + X = RETURN_TYPES.pop() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "RhsMutatedReturnTypesNode": RhsMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "rhs-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_mutation_skips_node(self): source = ''' class AliasMutatedReturnTypesNode: @@ -1244,6 +1295,30 @@ NODE_CLASS_MAPPINGS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rhs_mutating_call_to_node_mapping_skips_node(self): + source = ''' +class RhsMutatedMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "RhsMutatedMappingNode": RhsMutatedMappingNode, +} +X = NODE_CLASS_MAPPINGS.clear() +''' + result = self._extract_source(source, "rhs-mutated-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_alias_mutation_invalidates_static_node_mapping(self): source = ''' class AnnotatedAliasMutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index b14b59c..9a657ae 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -178,18 +178,32 @@ def _delete_target_names(stmt): def _mutating_call_target_names(stmt): - if not isinstance(stmt, ast.Expr): - return set() - call = stmt.value - if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Attribute): - return set() - if call.func.attr not in _MUTATING_METHODS: - return set() - return _target_names(call.func.value) + names = set() + + class MutatingCallVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, node): + return None + + def visit_AsyncFunctionDef(self, node): + return None + + def visit_ClassDef(self, node): + return None + + def visit_Lambda(self, node): + return None + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: + names.update(_target_names(node.func.value)) + self.generic_visit(node) + + MutatingCallVisitor().visit(stmt) + return names def _assigned_names_in_control_flow(stmt): - names = set() + names = _mutating_call_target_names(stmt) class AssignmentVisitor(ast.NodeVisitor): def visit_FunctionDef(self, node): @@ -296,6 +310,10 @@ def _invalidate_class_bindings(class_bindings, names): def _collect_module_env(tree, class_bindings=None): env = {} for stmt in tree.body: + names = _mutating_call_target_names(stmt) + _invalidate_class_bindings(class_bindings, names) + for name in names: + env.pop(name, None) if isinstance(stmt, ast.ClassDef): if class_bindings is not None: class_bindings[stmt.name] = (stmt, dict(env)) @@ -409,6 +427,11 @@ def _class_attr(cls, name, env): value = _MISSING aliases = set() for stmt in cls.body: + mutating_targets = _mutating_call_target_names(stmt) + if aliases.intersection(mutating_targets): + value = _INVALID + if name in mutating_targets: + value = _INVALID if isinstance(stmt, ast.Assign): target_names = _assignment_target_names(stmt) if ( @@ -510,6 +533,8 @@ def _class_attr(cls, name, env): def _input_types(cls, env): value = _MISSING for stmt in cls.body: + if "INPUT_TYPES" in _mutating_call_target_names(stmt): + value = _INVALID if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): value = _INVALID @@ -580,6 +605,8 @@ def _module_dict_entries(node, env, value_converter): def _final_module_dict(tree, env, name, value_converter): value = _MISSING for stmt in tree.body: + if name in _mutating_call_target_names(stmt): + value = _INVALID if isinstance(stmt, ast.Assign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: