Invalidate static env after arbitrary calls
This commit is contained in:
@@ -2482,6 +2482,60 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_class_body_arbitrary_call_before_return_types_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
RET = ("IMAGE",)
|
||||||
|
|
||||||
|
|
||||||
|
class NoArgClassBodyCallBeforeReturnTypesNode:
|
||||||
|
mutate()
|
||||||
|
RETURN_TYPES = RET
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NoArgClassBodyCallBeforeReturnTypesNode": NoArgClassBodyCallBeforeReturnTypesNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-class-body-call-before-return-types-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_class_body_arbitrary_call_before_input_types_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
SPEC = ("IMAGE",)
|
||||||
|
|
||||||
|
|
||||||
|
class NoArgClassBodyCallBeforeInputTypesNode:
|
||||||
|
mutate()
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": SPEC,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NoArgClassBodyCallBeforeInputTypesNode": NoArgClassBodyCallBeforeInputTypesNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-class-body-call-before-input-types-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_return_types_alias_arbitrary_call_skips_node(self):
|
def test_return_types_alias_arbitrary_call_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class AliasArbitraryCallReturnTypesNode:
|
class AliasArbitraryCallReturnTypesNode:
|
||||||
@@ -3196,6 +3250,33 @@ KEY = "Wrong"
|
|||||||
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
|
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
|
||||||
self.assertEqual("ok", result["pack"]["status"])
|
self.assertEqual("ok", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_arbitrary_call_invalidates_immutable_env_mapping_key(self):
|
||||||
|
source = '''
|
||||||
|
KEY = "OldNode"
|
||||||
|
mutate()
|
||||||
|
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
KEY: Node,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-call-immutable-key-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_non_string_node_mapping_key_skips_node(self):
|
def test_non_string_node_mapping_key_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class NonStringMappingKeyNode:
|
class NonStringMappingKeyNode:
|
||||||
|
|||||||
@@ -1053,9 +1053,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
|
|||||||
if name in env and _is_mutable_static_value(env[name]):
|
if name in env and _is_mutable_static_value(env[name]):
|
||||||
_invalidate_env_name(env, name)
|
_invalidate_env_name(env, name)
|
||||||
if _has_arbitrary_call(stmt):
|
if _has_arbitrary_call(stmt):
|
||||||
for name, value in list(env.items()):
|
env.clear()
|
||||||
if _is_mutable_static_value(value):
|
|
||||||
_invalidate_env_name(env, name)
|
|
||||||
if class_bindings is not None and not isinstance(
|
if class_bindings is not None and not isinstance(
|
||||||
stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete)
|
stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete)
|
||||||
):
|
):
|
||||||
@@ -1278,6 +1276,7 @@ def _update_input_types_aliases_from_unpack(target, value, aliases):
|
|||||||
|
|
||||||
def _class_attr(cls, name, env):
|
def _class_attr(cls, name, env):
|
||||||
value = _MISSING
|
value = _MISSING
|
||||||
|
sticky_invalid = False
|
||||||
aliases = set()
|
aliases = set()
|
||||||
namespace_mutations = _class_body_namespace_mutation_names(cls)
|
namespace_mutations = _class_body_namespace_mutation_names(cls)
|
||||||
if _name_invalidated_by(name, namespace_mutations):
|
if _name_invalidated_by(name, namespace_mutations):
|
||||||
@@ -1287,8 +1286,9 @@ def _class_attr(cls, name, env):
|
|||||||
observed_targets = _arbitrary_call_observed_names(stmt)
|
observed_targets = _arbitrary_call_observed_names(stmt)
|
||||||
expression_references = _class_body_expression_referenced_names(stmt)
|
expression_references = _class_body_expression_referenced_names(stmt)
|
||||||
has_arbitrary_call = _has_arbitrary_call(stmt)
|
has_arbitrary_call = _has_arbitrary_call(stmt)
|
||||||
if value not in (_MISSING, _INVALID) and has_arbitrary_call:
|
if has_arbitrary_call:
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
sticky_invalid = True
|
||||||
if aliases.intersection(mutating_targets):
|
if aliases.intersection(mutating_targets):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
if name in mutating_targets:
|
if name in mutating_targets:
|
||||||
@@ -1334,6 +1334,9 @@ def _class_attr(cls, name, env):
|
|||||||
aliases.difference_update(target_names)
|
aliases.difference_update(target_names)
|
||||||
if name not in target_names:
|
if name not in target_names:
|
||||||
continue
|
continue
|
||||||
|
if sticky_invalid:
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
if _is_mutable_env_reference(stmt.value, env):
|
if _is_mutable_env_reference(stmt.value, env):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
@@ -1362,6 +1365,9 @@ def _class_attr(cls, name, env):
|
|||||||
continue
|
continue
|
||||||
if isinstance(stmt.target, ast.Name) and stmt.value is None:
|
if isinstance(stmt.target, ast.Name) and stmt.value is None:
|
||||||
continue
|
continue
|
||||||
|
if sticky_invalid:
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
if not isinstance(stmt.target, ast.Name):
|
if not isinstance(stmt.target, ast.Name):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
else:
|
else:
|
||||||
@@ -1432,7 +1438,7 @@ def _input_types(cls, env, decorator_env):
|
|||||||
has_arbitrary_call = _has_arbitrary_call(stmt)
|
has_arbitrary_call = _has_arbitrary_call(stmt)
|
||||||
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
|
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
|
||||||
input_types_invalidated = (
|
input_types_invalidated = (
|
||||||
(value not in (_MISSING, _INVALID) and has_arbitrary_call)
|
has_arbitrary_call
|
||||||
or "INPUT_TYPES" in mutating_targets
|
or "INPUT_TYPES" in mutating_targets
|
||||||
or bool(aliases.intersection(mutating_targets))
|
or bool(aliases.intersection(mutating_targets))
|
||||||
or "INPUT_TYPES" in observed_targets
|
or "INPUT_TYPES" in observed_targets
|
||||||
|
|||||||
Reference in New Issue
Block a user