Skip decorated and patched mapped classes
This commit is contained in:
@@ -1621,6 +1621,34 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertIn("TopLevelMappedNode", result["nodes"])
|
||||
self.assertEqual("ok", result["pack"]["status"])
|
||||
|
||||
def test_decorated_class_mapping_skips_node(self):
|
||||
source = '''
|
||||
def decorator(cls):
|
||||
return cls
|
||||
|
||||
|
||||
@decorator
|
||||
class DecoratedMappedNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DecoratedMappedNode": DecoratedMappedNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "decorated-mapped-class-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_node_mapping_key_uses_assignment_time_env(self):
|
||||
source = '''
|
||||
KEY = "Original"
|
||||
@@ -1650,6 +1678,58 @@ KEY = "Wrong"
|
||||
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
|
||||
self.assertEqual("ok", result["pack"]["status"])
|
||||
|
||||
def test_module_class_return_types_patch_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class PatchedReturnTypesNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PatchedReturnTypesNode": PatchedReturnTypesNode,
|
||||
}
|
||||
PatchedReturnTypesNode.RETURN_TYPES = ("MASK",)
|
||||
'''
|
||||
result = self._extract_source(source, "patched-return-types-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_module_class_input_types_patch_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
def build_inputs():
|
||||
return {"required": {"mask": ("MASK",)}}
|
||||
|
||||
|
||||
class PatchedInputTypesNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PatchedInputTypesNode": PatchedInputTypesNode,
|
||||
}
|
||||
PatchedInputTypesNode.INPUT_TYPES = build_inputs
|
||||
'''
|
||||
result = self._extract_source(source, "patched-input-types-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_mutated_node_class_mapping_skips_node(self):
|
||||
source = '''
|
||||
class MutatedMappingNode:
|
||||
|
||||
@@ -79,6 +79,73 @@ def _target_names(target):
|
||||
return set()
|
||||
|
||||
|
||||
def _root_name(node):
|
||||
while isinstance(node, (ast.Attribute, ast.Subscript)):
|
||||
node = node.value
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
return None
|
||||
|
||||
|
||||
def _attribute_target_base_names(target):
|
||||
if isinstance(target, ast.Attribute):
|
||||
name = _root_name(target.value)
|
||||
return {name} if name else set()
|
||||
if isinstance(target, ast.Subscript):
|
||||
return _attribute_target_base_names(target.value)
|
||||
if isinstance(target, (ast.List, ast.Tuple)):
|
||||
names = set()
|
||||
for item in target.elts:
|
||||
names.update(_attribute_target_base_names(item))
|
||||
return names
|
||||
if isinstance(target, ast.Starred):
|
||||
return _attribute_target_base_names(target.value)
|
||||
return set()
|
||||
|
||||
|
||||
def _class_attribute_mutation_target_names(stmt):
|
||||
names = set()
|
||||
|
||||
class AttributeMutationVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
return None
|
||||
|
||||
def visit_Assign(self, node):
|
||||
for target in node.targets:
|
||||
names.update(_attribute_target_base_names(target))
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
names.update(_attribute_target_base_names(node.target))
|
||||
if node.value is not None:
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
names.update(_attribute_target_base_names(node.target))
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Delete(self, node):
|
||||
for target in node.targets:
|
||||
names.update(_attribute_target_base_names(target))
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
|
||||
names.update(_attribute_target_base_names(node.func.value))
|
||||
self.generic_visit(node)
|
||||
|
||||
AttributeMutationVisitor().visit(stmt)
|
||||
return names
|
||||
|
||||
|
||||
def _pattern_bound_names(pattern):
|
||||
names = set()
|
||||
if isinstance(pattern, ast.MatchAs):
|
||||
@@ -347,6 +414,9 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
|
||||
env.pop(name, None)
|
||||
if isinstance(stmt, ast.ClassDef):
|
||||
if class_bindings is not None:
|
||||
if stmt.decorator_list:
|
||||
class_bindings.pop(stmt.name, None)
|
||||
else:
|
||||
class_bindings[stmt.name] = (stmt, dict(env))
|
||||
env.pop(stmt.name, None)
|
||||
return
|
||||
@@ -635,11 +705,19 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
|
||||
return result
|
||||
|
||||
|
||||
def _final_module_dict(tree, name, value_converter):
|
||||
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
|
||||
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
||||
value = _MISSING
|
||||
env = {}
|
||||
class_bindings = {}
|
||||
for stmt in tree.body:
|
||||
class_attr_names = _class_attribute_mutation_target_names(stmt)
|
||||
if (
|
||||
value not in (_MISSING, _INVALID)
|
||||
and class_attr_names
|
||||
and value_invalidated_by_names(value, class_attr_names)
|
||||
):
|
||||
value = _INVALID
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if isinstance(stmt, ast.Assign):
|
||||
@@ -712,14 +790,26 @@ def _mapping_value_binding(value, env, class_bindings):
|
||||
class_name = _mapping_value_name(value)
|
||||
if class_name is None:
|
||||
return None
|
||||
return class_bindings.get(class_name)
|
||||
binding = class_bindings.get(class_name)
|
||||
if binding is None:
|
||||
return None
|
||||
return class_name, binding
|
||||
|
||||
|
||||
def _node_mapping_invalidated_by_names(value, names):
|
||||
return any(class_name in names for class_name, _binding in value.values())
|
||||
|
||||
|
||||
def _node_class_mappings(tree):
|
||||
if _has_module_wildcard_import(tree):
|
||||
return {}
|
||||
mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding)
|
||||
return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None}
|
||||
mappings = _final_module_dict(
|
||||
tree,
|
||||
"NODE_CLASS_MAPPINGS",
|
||||
_mapping_value_binding,
|
||||
_node_mapping_invalidated_by_names,
|
||||
)
|
||||
return {str(node_type): binding for node_type, (_class_name, binding) in mappings.items() if node_type}
|
||||
|
||||
|
||||
def _display_mappings(tree):
|
||||
|
||||
Reference in New Issue
Block a user