From d7c3fc86c14ce434b6bdb100e2e9e939dac43ba7 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 21:19:10 +0200 Subject: [PATCH] Invalidate static env after arbitrary calls --- .../test_generate_popular_node_signatures.py | 81 +++++++++++++++++++ tools/generate_popular_node_signatures.py | 16 ++-- 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 0084351..45842b0 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2482,6 +2482,60 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_no_arg_class_body_arbitrary_call_before_return_types_skips_node(self): + source = ''' +RET = ("IMAGE",) + + +class NoArgClassBodyCallBeforeReturnTypesNode: + mutate() + RETURN_TYPES = RET + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NoArgClassBodyCallBeforeReturnTypesNode": NoArgClassBodyCallBeforeReturnTypesNode, +} +''' + result = self._extract_source(source, "no-arg-class-body-call-before-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_no_arg_class_body_arbitrary_call_before_input_types_skips_node(self): + source = ''' +SPEC = ("IMAGE",) + + +class NoArgClassBodyCallBeforeInputTypesNode: + mutate() + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": SPEC, + }, + } + + +NODE_CLASS_MAPPINGS = { + "NoArgClassBodyCallBeforeInputTypesNode": NoArgClassBodyCallBeforeInputTypesNode, +} +''' + result = self._extract_source(source, "no-arg-class-body-call-before-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_arbitrary_call_skips_node(self): source = ''' class AliasArbitraryCallReturnTypesNode: @@ -3196,6 +3250,33 @@ KEY = "Wrong" self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"]) self.assertEqual("ok", result["pack"]["status"]) + def test_no_arg_arbitrary_call_invalidates_immutable_env_mapping_key(self): + source = ''' +KEY = "OldNode" +mutate() + + +class Node: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + KEY: Node, +} +''' + result = self._extract_source(source, "no-arg-call-immutable-key-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_non_string_node_mapping_key_skips_node(self): source = ''' class NonStringMappingKeyNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 996dadf..a61757f 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -1053,9 +1053,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): if name in env and _is_mutable_static_value(env[name]): _invalidate_env_name(env, name) if _has_arbitrary_call(stmt): - for name, value in list(env.items()): - if _is_mutable_static_value(value): - _invalidate_env_name(env, name) + env.clear() if class_bindings is not None and not isinstance( stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete) ): @@ -1278,6 +1276,7 @@ def _update_input_types_aliases_from_unpack(target, value, aliases): def _class_attr(cls, name, env): value = _MISSING + sticky_invalid = False aliases = set() namespace_mutations = _class_body_namespace_mutation_names(cls) if _name_invalidated_by(name, namespace_mutations): @@ -1287,8 +1286,9 @@ def _class_attr(cls, name, env): observed_targets = _arbitrary_call_observed_names(stmt) expression_references = _class_body_expression_referenced_names(stmt) has_arbitrary_call = _has_arbitrary_call(stmt) - if value not in (_MISSING, _INVALID) and has_arbitrary_call: + if has_arbitrary_call: value = _INVALID + sticky_invalid = True if aliases.intersection(mutating_targets): value = _INVALID if name in mutating_targets: @@ -1334,6 +1334,9 @@ def _class_attr(cls, name, env): aliases.difference_update(target_names) if name not in target_names: continue + if sticky_invalid: + value = _INVALID + continue if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if _is_mutable_env_reference(stmt.value, env): value = _INVALID @@ -1362,6 +1365,9 @@ def _class_attr(cls, name, env): continue if isinstance(stmt.target, ast.Name) and stmt.value is None: continue + if sticky_invalid: + value = _INVALID + continue if not isinstance(stmt.target, ast.Name): value = _INVALID else: @@ -1432,7 +1438,7 @@ def _input_types(cls, env, decorator_env): has_arbitrary_call = _has_arbitrary_call(stmt) protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases input_types_invalidated = ( - (value not in (_MISSING, _INVALID) and has_arbitrary_call) + has_arbitrary_call or "INPUT_TYPES" in mutating_targets or bool(aliases.intersection(mutating_targets)) or "INPUT_TYPES" in observed_targets