Invalidate mappings on duplicate keys and class aliases

This commit is contained in:
2026-07-02 15:34:50 +02:00
parent 51db0d16e5
commit 86ea12924c
2 changed files with 192 additions and 17 deletions
@@ -1730,6 +1730,110 @@ PatchedInputTypesNode.INPUT_TYPES = build_inputs
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_duplicate_node_mapping_key_with_dynamic_value_skips_node(self):
source = '''
def build_node():
return object()
class DuplicateMappingKeyNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DuplicateMappingKeyNode": DuplicateMappingKeyNode,
"DuplicateMappingKeyNode": build_node(),
}
'''
result = self._extract_source(source, "duplicate-mapping-key-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_alias_patch_after_mapping_skips_node(self):
source = '''
class AliasPatchedNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"AliasPatchedNode": AliasPatchedNode,
}
Alias = AliasPatchedNode
Alias.RETURN_TYPES = ("MASK",)
'''
result = self._extract_source(source, "alias-patched-node-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self):
source = '''
class DefinitionTimeMutatedMappedNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DefinitionTimeMutatedMappedNode": DefinitionTimeMutatedMappedNode,
}
def helper(x=DefinitionTimeMutatedMappedNode.RETURN_TYPES.clear()):
pass
'''
result = self._extract_source(source, "definition-time-mutated-mapped-node-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_unhashable_node_mapping_key_skips_repo_without_raising(self):
source = '''
KEY = ["UnhashableMappingKeyNode"]
class UnhashableMappingKeyNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
KEY: UnhashableMappingKeyNode,
}
'''
result = self._extract_source(source, "unhashable-mapping-key-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_mutated_node_class_mapping_skips_node(self):
source = '''
class MutatedMappingNode:
+88 -17
View File
@@ -107,17 +107,35 @@ def _class_attribute_mutation_target_names(stmt):
names = set()
class AttributeMutationVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
self.visit(node.args)
if node.returns is not None:
self.visit(node.returns)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
def visit_FunctionDef(self, node):
return None
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
return None
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
return None
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
for child in node.body:
self.visit(child)
def visit_Lambda(self, node):
return None
self.visit(node.args)
def visit_Assign(self, node):
for target in node.targets:
@@ -698,20 +716,73 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
for key, value in zip(node.keys, node.values):
if key is None:
raise UnsupportedStaticExpression("dict unpacking is not supported")
key_value = _literal(key, env)
try:
hash(key_value)
except TypeError as exc:
raise UnsupportedStaticExpression("unhashable dict key") from exc
if key_value in result:
raise UnsupportedStaticExpression("duplicate dict key")
converted_value = value_converter(value, env, class_bindings)
if converted_value is None:
continue
result[_literal(key, env)] = converted_value
raise UnsupportedStaticExpression("unsupported dict value")
result[key_value] = converted_value
return result
def _class_alias_sources(value, class_aliases, class_bindings):
if not isinstance(value, ast.Name):
return set()
if value.id in class_aliases:
return set(class_aliases[value.id])
if value.id in class_bindings:
return {value.id}
return set()
def _update_class_aliases(stmt, class_aliases, class_bindings):
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for name in rebound_names:
class_aliases.pop(name, None)
if isinstance(stmt, ast.ClassDef):
if stmt.name in class_bindings and not stmt.decorator_list:
class_aliases[stmt.name] = {stmt.name}
return
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
if sources:
class_aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
if sources:
class_aliases[stmt.target.id] = sources
def _expanded_class_attribute_names(names, class_aliases):
expanded = set(names)
for name in names:
expanded.update(class_aliases.get(name, ()))
return expanded
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING
env = {}
class_bindings = {}
class_aliases = {}
def advance_module_state(stmt):
_apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings)
for stmt in tree.body:
class_attr_names = _class_attribute_mutation_target_names(stmt)
class_attr_names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
if (
value not in (_MISSING, _INVALID)
and class_attr_names
@@ -724,7 +795,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
try:
@@ -733,13 +804,13 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
value = _INVALID
else:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt, ast.AnnAssign):
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt.target, ast.Name) and stmt.value is not None:
try:
@@ -748,39 +819,39 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
value = _INVALID
else:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt, ast.AugAssign):
if _name_is_assigned(stmt, name):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt, ast.Delete):
if name in _delete_target_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt):
value = _INVALID
if name in _bound_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if isinstance(stmt, _CONTROL_FLOW_TYPES):
if name in _assigned_names_in_control_flow(stmt):
value = _INVALID
if _has_wildcard_import_in_control_flow(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if _has_wildcard_import(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
continue
if name in _bound_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
advance_module_state(stmt)
if value in (_MISSING, _INVALID):
return {}
return value