Fail closed on namespace aliases and input observations

This commit is contained in:
2026-07-02 18:19:48 +02:00
parent bf46f9b389
commit f7143e7bac
2 changed files with 255 additions and 3 deletions
@@ -2978,6 +2978,56 @@ globals().update(NODE_CLASS_MAPPINGS={})
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_globals_alias_subscript_assignment_invalidates_static_node_mapping(self):
source = '''
class GlobalAliasSubscriptAssignmentNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"GlobalAliasSubscriptAssignmentNode": GlobalAliasSubscriptAssignmentNode,
}
G = globals()
G["NODE_CLASS_MAPPINGS"] = {}
'''
result = self._extract_source(source, "global-alias-subscript-assignment-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_globals_alias_update_invalidates_static_node_mapping(self):
source = '''
class GlobalAliasUpdateNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"GlobalAliasUpdateNode": GlobalAliasUpdateNode,
}
G = globals()
G.update(NODE_CLASS_MAPPINGS={})
'''
result = self._extract_source(source, "global-alias-update-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_arbitrary_call_invalidates_static_node_mapping(self): def test_arbitrary_call_invalidates_static_node_mapping(self):
source = ''' source = '''
class ArbitraryCallMappingNode: class ArbitraryCallMappingNode:
@@ -3493,6 +3543,52 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_default_observed_by_arbitrary_call_skips_node(self):
source = '''
class DefaultObservedInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls, value=observe(INPUT_TYPES)):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DefaultObservedInputTypesNode": DefaultObservedInputTypesNode,
}
'''
result = self._extract_source(source, "default-observed-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_return_annotation_observed_by_arbitrary_call_skips_node(self):
source = '''
class AnnotationObservedInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls) -> observe(INPUT_TYPES):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"AnnotationObservedInputTypesNode": AnnotationObservedInputTypesNode,
}
'''
result = self._extract_source(source, "annotation-observed-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_write_artifact_is_deterministic(self): def test_write_artifact_is_deterministic(self):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
out_one = Path(tmp, "one.json") out_one = Path(tmp, "one.json")
+159 -3
View File
@@ -142,6 +142,19 @@ def _target_names(target):
return set() return set()
def _direct_target_names(target):
if isinstance(target, ast.Name):
return {target.id}
if isinstance(target, (ast.List, ast.Tuple)):
names = set()
for item in target.elts:
names.update(_direct_target_names(item))
return names
if isinstance(target, ast.Starred):
return _direct_target_names(target.value)
return set()
def _root_name(node): def _root_name(node):
while True: while True:
name = _namespace_lookup_name(node) name = _namespace_lookup_name(node)
@@ -1099,11 +1112,17 @@ def _input_types(cls, env, decorator_env):
for stmt in cls.body: for stmt in cls.body:
mutating_targets = _mutating_call_target_names(stmt) mutating_targets = _mutating_call_target_names(stmt)
observed_targets = _arbitrary_call_observed_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt)
if "INPUT_TYPES" in mutating_targets or aliases.intersection(mutating_targets): input_types_invalidated = (
value = _INVALID "INPUT_TYPES" in mutating_targets
if "INPUT_TYPES" in observed_targets or aliases.intersection(observed_targets): or bool(aliases.intersection(mutating_targets))
or "INPUT_TYPES" in observed_targets
or bool(aliases.intersection(observed_targets))
)
if input_types_invalidated:
value = _INVALID value = _INVALID
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if input_types_invalidated:
continue
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
value = _INVALID value = _INVALID
continue continue
@@ -1435,6 +1454,139 @@ def _module_dict_alias_invalidated(stmt, aliases):
return any(name in aliases for name in names) return any(name in aliases for name in names)
def _namespace_alias_sources(value, aliases):
if _namespace_call_function_name(value) is not None:
return True
if isinstance(value, ast.Name):
return value.id in aliases
if isinstance(value, (ast.Tuple, ast.List)):
return any(_namespace_alias_sources(item, aliases) for item in value.elts)
return False
def _namespace_alias_subscript_name(node, aliases):
if not isinstance(node, ast.Subscript):
return None
if not isinstance(node.value, ast.Name) or node.value.id not in aliases:
return None
if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str):
return node.slice.value
return None
def _namespace_alias_target_names(target, aliases):
name = _namespace_alias_subscript_name(target, aliases)
if name is not None:
return {name}
if isinstance(target, (ast.Tuple, ast.List)):
names = set()
for item in target.elts:
names.update(_namespace_alias_target_names(item, aliases))
return names
if isinstance(target, ast.Starred):
return _namespace_alias_target_names(target.value, aliases)
if isinstance(target, (ast.Attribute, ast.Subscript)):
return _namespace_alias_target_names(target.value, aliases)
return set()
def _namespace_alias_mutation_target_names(stmt, aliases):
names = set()
class NamespaceAliasMutationVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
self.visit(node.args)
if node.returns is not None:
self.visit(node.returns)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
def visit_FunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
for child in node.body:
self.visit(child)
def visit_Lambda(self, node):
self.visit(node.args)
def visit_Assign(self, node):
for target in node.targets:
names.update(_namespace_alias_target_names(target, aliases))
self.visit(node.value)
def visit_AnnAssign(self, node):
names.update(_namespace_alias_target_names(node.target, aliases))
if node.value is not None:
self.visit(node.value)
def visit_AugAssign(self, node):
names.update(_namespace_alias_target_names(node.target, aliases))
self.visit(node.value)
def visit_Delete(self, node):
for target in node.targets:
names.update(_namespace_alias_target_names(target, aliases))
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name) and node.func.value.id in aliases:
if node.func.attr == "update":
for keyword in node.keywords:
names.add(_DYNAMIC_NAMESPACE_MUTATION if keyword.arg is None else keyword.arg)
if node.args or not node.keywords:
names.add(_DYNAMIC_NAMESPACE_MUTATION)
elif node.func.attr in _MUTATING_METHODS:
names.add(_DYNAMIC_NAMESPACE_MUTATION)
namespace_name = _namespace_alias_subscript_name(node.func.value, aliases)
if namespace_name is not None and node.func.attr in _MUTATING_METHODS:
names.add(namespace_name)
self.generic_visit(node)
NamespaceAliasMutationVisitor().visit(stmt)
return names
def _update_namespace_aliases(stmt, aliases):
direct_names = set()
if isinstance(stmt, ast.Assign):
for target in stmt.targets:
direct_names.update(_direct_target_names(target))
elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)):
direct_names.update(_direct_target_names(stmt.target))
elif isinstance(stmt, ast.Delete):
for target in stmt.targets:
direct_names.update(_direct_target_names(target))
direct_names.update(_bound_names(stmt))
aliases.difference_update(direct_names)
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
if _namespace_alias_sources(stmt.value, aliases):
aliases.add(stmt.targets[0].id)
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
for target_item, value_item in _unpack_target_value_pairs(stmt.targets[0], stmt.value):
target_name = _alias_target_name(target_item)
if target_name is not None and _namespace_alias_sources(value_item, aliases):
aliases.add(target_name)
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
if _namespace_alias_sources(stmt.value, aliases):
aliases.add(stmt.target.id)
def _update_module_dict_aliases(stmt, name, aliases): def _update_module_dict_aliases(stmt, name, aliases):
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for rebound_name in rebound_names: for rebound_name in rebound_names:
@@ -1460,6 +1612,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
class_aliases = {} class_aliases = {}
class_attribute_aliases = {} class_attribute_aliases = {}
module_dict_aliases = {} module_dict_aliases = {}
namespace_aliases = set()
def advance_module_state(stmt): def advance_module_state(stmt):
_invalidate_class_bindings( _invalidate_class_bindings(
@@ -1470,6 +1623,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
_update_class_aliases(stmt, class_aliases, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
_update_module_dict_aliases(stmt, name, module_dict_aliases) _update_module_dict_aliases(stmt, name, module_dict_aliases)
_update_namespace_aliases(stmt, namespace_aliases)
for stmt in tree.body: for stmt in tree.body:
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
@@ -1483,6 +1637,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
value = _INVALID value = _INVALID
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
value = _INVALID value = _INVALID
if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)):
value = _INVALID
if _module_dict_alias_invalidated(stmt, module_dict_aliases): if _module_dict_alias_invalidated(stmt, module_dict_aliases):
value = _INVALID value = _INVALID
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):