Skip decorated and patched mapped classes

This commit is contained in:
2026-07-02 15:19:22 +02:00
parent fc92d1db24
commit 51db0d16e5
2 changed files with 175 additions and 5 deletions
@@ -1621,6 +1621,34 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("TopLevelMappedNode", result["nodes"]) self.assertIn("TopLevelMappedNode", result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_decorated_class_mapping_skips_node(self):
source = '''
def decorator(cls):
return cls
@decorator
class DecoratedMappedNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DecoratedMappedNode": DecoratedMappedNode,
}
'''
result = self._extract_source(source, "decorated-mapped-class-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_node_mapping_key_uses_assignment_time_env(self): def test_node_mapping_key_uses_assignment_time_env(self):
source = ''' source = '''
KEY = "Original" KEY = "Original"
@@ -1650,6 +1678,58 @@ KEY = "Wrong"
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"]) self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_module_class_return_types_patch_after_mapping_skips_node(self):
source = '''
class PatchedReturnTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"PatchedReturnTypesNode": PatchedReturnTypesNode,
}
PatchedReturnTypesNode.RETURN_TYPES = ("MASK",)
'''
result = self._extract_source(source, "patched-return-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_input_types_patch_after_mapping_skips_node(self):
source = '''
def build_inputs():
return {"required": {"mask": ("MASK",)}}
class PatchedInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"PatchedInputTypesNode": PatchedInputTypesNode,
}
PatchedInputTypesNode.INPUT_TYPES = build_inputs
'''
result = self._extract_source(source, "patched-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_mutated_node_class_mapping_skips_node(self): def test_mutated_node_class_mapping_skips_node(self):
source = ''' source = '''
class MutatedMappingNode: class MutatedMappingNode:
+95 -5
View File
@@ -79,6 +79,73 @@ def _target_names(target):
return set() return set()
def _root_name(node):
while isinstance(node, (ast.Attribute, ast.Subscript)):
node = node.value
if isinstance(node, ast.Name):
return node.id
return None
def _attribute_target_base_names(target):
if isinstance(target, ast.Attribute):
name = _root_name(target.value)
return {name} if name else set()
if isinstance(target, ast.Subscript):
return _attribute_target_base_names(target.value)
if isinstance(target, (ast.List, ast.Tuple)):
names = set()
for item in target.elts:
names.update(_attribute_target_base_names(item))
return names
if isinstance(target, ast.Starred):
return _attribute_target_base_names(target.value)
return set()
def _class_attribute_mutation_target_names(stmt):
names = set()
class AttributeMutationVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, node):
return None
def visit_AsyncFunctionDef(self, node):
return None
def visit_ClassDef(self, node):
return None
def visit_Lambda(self, node):
return None
def visit_Assign(self, node):
for target in node.targets:
names.update(_attribute_target_base_names(target))
self.visit(node.value)
def visit_AnnAssign(self, node):
names.update(_attribute_target_base_names(node.target))
if node.value is not None:
self.visit(node.value)
def visit_AugAssign(self, node):
names.update(_attribute_target_base_names(node.target))
self.visit(node.value)
def visit_Delete(self, node):
for target in node.targets:
names.update(_attribute_target_base_names(target))
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_attribute_target_base_names(node.func.value))
self.generic_visit(node)
AttributeMutationVisitor().visit(stmt)
return names
def _pattern_bound_names(pattern): def _pattern_bound_names(pattern):
names = set() names = set()
if isinstance(pattern, ast.MatchAs): if isinstance(pattern, ast.MatchAs):
@@ -347,7 +414,10 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
env.pop(name, None) env.pop(name, None)
if isinstance(stmt, ast.ClassDef): if isinstance(stmt, ast.ClassDef):
if class_bindings is not None: if class_bindings is not None:
class_bindings[stmt.name] = (stmt, dict(env)) if stmt.decorator_list:
class_bindings.pop(stmt.name, None)
else:
class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None) env.pop(stmt.name, None)
return return
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
@@ -635,11 +705,19 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
return result return result
def _final_module_dict(tree, name, value_converter): def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING value = _MISSING
env = {} env = {}
class_bindings = {} class_bindings = {}
for stmt in tree.body: for stmt in tree.body:
class_attr_names = _class_attribute_mutation_target_names(stmt)
if (
value not in (_MISSING, _INVALID)
and class_attr_names
and value_invalidated_by_names(value, class_attr_names)
):
value = _INVALID
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
@@ -712,14 +790,26 @@ def _mapping_value_binding(value, env, class_bindings):
class_name = _mapping_value_name(value) class_name = _mapping_value_name(value)
if class_name is None: if class_name is None:
return None return None
return class_bindings.get(class_name) binding = class_bindings.get(class_name)
if binding is None:
return None
return class_name, binding
def _node_mapping_invalidated_by_names(value, names):
return any(class_name in names for class_name, _binding in value.values())
def _node_class_mappings(tree): def _node_class_mappings(tree):
if _has_module_wildcard_import(tree): if _has_module_wildcard_import(tree):
return {} return {}
mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding) mappings = _final_module_dict(
return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None} tree,
"NODE_CLASS_MAPPINGS",
_mapping_value_binding,
_node_mapping_invalidated_by_names,
)
return {str(node_type): binding for node_type, (_class_name, binding) in mappings.items() if node_type}
def _display_mappings(tree): def _display_mappings(tree):