Invalidate static env after arbitrary calls

This commit is contained in:
2026-07-02 21:19:10 +02:00
parent ee8496174f
commit d7c3fc86c1
2 changed files with 92 additions and 5 deletions
@@ -2482,6 +2482,60 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_no_arg_class_body_arbitrary_call_before_return_types_skips_node(self):
source = '''
RET = ("IMAGE",)
class NoArgClassBodyCallBeforeReturnTypesNode:
mutate()
RETURN_TYPES = RET
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"NoArgClassBodyCallBeforeReturnTypesNode": NoArgClassBodyCallBeforeReturnTypesNode,
}
'''
result = self._extract_source(source, "no-arg-class-body-call-before-return-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_no_arg_class_body_arbitrary_call_before_input_types_skips_node(self):
source = '''
SPEC = ("IMAGE",)
class NoArgClassBodyCallBeforeInputTypesNode:
mutate()
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": SPEC,
},
}
NODE_CLASS_MAPPINGS = {
"NoArgClassBodyCallBeforeInputTypesNode": NoArgClassBodyCallBeforeInputTypesNode,
}
'''
result = self._extract_source(source, "no-arg-class-body-call-before-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_return_types_alias_arbitrary_call_skips_node(self): def test_return_types_alias_arbitrary_call_skips_node(self):
source = ''' source = '''
class AliasArbitraryCallReturnTypesNode: class AliasArbitraryCallReturnTypesNode:
@@ -3196,6 +3250,33 @@ KEY = "Wrong"
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"]) self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_no_arg_arbitrary_call_invalidates_immutable_env_mapping_key(self):
source = '''
KEY = "OldNode"
mutate()
class Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
KEY: Node,
}
'''
result = self._extract_source(source, "no-arg-call-immutable-key-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_non_string_node_mapping_key_skips_node(self): def test_non_string_node_mapping_key_skips_node(self):
source = ''' source = '''
class NonStringMappingKeyNode: class NonStringMappingKeyNode:
+11 -5
View File
@@ -1053,9 +1053,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
if name in env and _is_mutable_static_value(env[name]): if name in env and _is_mutable_static_value(env[name]):
_invalidate_env_name(env, name) _invalidate_env_name(env, name)
if _has_arbitrary_call(stmt): if _has_arbitrary_call(stmt):
for name, value in list(env.items()): env.clear()
if _is_mutable_static_value(value):
_invalidate_env_name(env, name)
if class_bindings is not None and not isinstance( if class_bindings is not None and not isinstance(
stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete) stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete)
): ):
@@ -1278,6 +1276,7 @@ def _update_input_types_aliases_from_unpack(target, value, aliases):
def _class_attr(cls, name, env): def _class_attr(cls, name, env):
value = _MISSING value = _MISSING
sticky_invalid = False
aliases = set() aliases = set()
namespace_mutations = _class_body_namespace_mutation_names(cls) namespace_mutations = _class_body_namespace_mutation_names(cls)
if _name_invalidated_by(name, namespace_mutations): if _name_invalidated_by(name, namespace_mutations):
@@ -1287,8 +1286,9 @@ def _class_attr(cls, name, env):
observed_targets = _arbitrary_call_observed_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt)
expression_references = _class_body_expression_referenced_names(stmt) expression_references = _class_body_expression_referenced_names(stmt)
has_arbitrary_call = _has_arbitrary_call(stmt) has_arbitrary_call = _has_arbitrary_call(stmt)
if value not in (_MISSING, _INVALID) and has_arbitrary_call: if has_arbitrary_call:
value = _INVALID value = _INVALID
sticky_invalid = True
if aliases.intersection(mutating_targets): if aliases.intersection(mutating_targets):
value = _INVALID value = _INVALID
if name in mutating_targets: if name in mutating_targets:
@@ -1334,6 +1334,9 @@ def _class_attr(cls, name, env):
aliases.difference_update(target_names) aliases.difference_update(target_names)
if name not in target_names: if name not in target_names:
continue continue
if sticky_invalid:
value = _INVALID
continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
if _is_mutable_env_reference(stmt.value, env): if _is_mutable_env_reference(stmt.value, env):
value = _INVALID value = _INVALID
@@ -1362,6 +1365,9 @@ def _class_attr(cls, name, env):
continue continue
if isinstance(stmt.target, ast.Name) and stmt.value is None: if isinstance(stmt.target, ast.Name) and stmt.value is None:
continue continue
if sticky_invalid:
value = _INVALID
continue
if not isinstance(stmt.target, ast.Name): if not isinstance(stmt.target, ast.Name):
value = _INVALID value = _INVALID
else: else:
@@ -1432,7 +1438,7 @@ def _input_types(cls, env, decorator_env):
has_arbitrary_call = _has_arbitrary_call(stmt) has_arbitrary_call = _has_arbitrary_call(stmt)
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
input_types_invalidated = ( input_types_invalidated = (
(value not in (_MISSING, _INVALID) and has_arbitrary_call) has_arbitrary_call
or "INPUT_TYPES" in mutating_targets or "INPUT_TYPES" in mutating_targets
or bool(aliases.intersection(mutating_targets)) or bool(aliases.intersection(mutating_targets))
or "INPUT_TYPES" in observed_targets or "INPUT_TYPES" in observed_targets