Resolve class lookups through namespace aliases

This commit is contained in:
2026-07-02 20:51:12 +02:00
parent 42aeafd0e9
commit f23d4ae69a
2 changed files with 159 additions and 45 deletions
@@ -3773,6 +3773,57 @@ globals().get("GlobalsGetReturnTypesNode").RETURN_TYPES.clear()
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_namespace_alias_class_return_types_mutation_after_mapping_skips_node(self):
source = '''
class NamespaceAliasReturnTypesNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"NamespaceAliasReturnTypesNode": NamespaceAliasReturnTypesNode,
}
ns = globals()
ns["NamespaceAliasReturnTypesNode"].RETURN_TYPES.clear()
'''
result = self._extract_source(source, "namespace-alias-return-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_namespace_alias_get_class_alias_patch_after_mapping_skips_node(self):
source = '''
class NamespaceAliasGetPatchedNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"NamespaceAliasGetPatchedNode": NamespaceAliasGetPatchedNode,
}
ns = globals()
Alias = ns.get("NamespaceAliasGetPatchedNode")
Alias.RETURN_TYPES.clear()
'''
result = self._extract_source(source, "namespace-alias-get-patched-node-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_getattr_class_attribute_alias_mutation_after_mapping_skips_node(self): def test_getattr_class_attribute_alias_mutation_after_mapping_skips_node(self):
source = ''' source = '''
class GetattrAttributeAliasNode: class GetattrAttributeAliasNode:
+108 -45
View File
@@ -157,12 +157,19 @@ def _direct_target_names(target):
return set() return set()
def _root_name(node): def _root_name(node, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
while True: while True:
name = _namespace_lookup_name(node) name = _namespace_lookup_name(node)
if name is not None: if name is not None:
return name return name
name = _namespace_subscript_name(node) name = _namespace_subscript_name(node)
if name is not None:
return name
name = _namespace_alias_lookup_name(node, namespace_aliases)
if name is not None:
return name
name = _namespace_alias_subscript_name(node, namespace_aliases)
if name is not None: if name is not None:
return name return name
if not isinstance(node, (ast.Attribute, ast.Subscript)): if not isinstance(node, (ast.Attribute, ast.Subscript)):
@@ -173,14 +180,14 @@ def _root_name(node):
return None return None
def _getattr_signature_target_names(node): def _getattr_signature_target_names(node, namespace_aliases=None):
if not isinstance(node, ast.Call): if not isinstance(node, ast.Call):
return set() return set()
if not isinstance(node.func, ast.Name) or node.func.id != "getattr": if not isinstance(node.func, ast.Name) or node.func.id != "getattr":
return set() return set()
if len(node.args) < 2: if len(node.args) < 2:
return set() return set()
name = _root_name(node.args[0]) name = _root_name(node.args[0], namespace_aliases)
if name is None: if name is None:
return set() return set()
attr = node.args[1] attr = node.args[1]
@@ -241,26 +248,26 @@ def _name_invalidated_by(name, names):
return name in names or _DYNAMIC_NAMESPACE_MUTATION in names return name in names or _DYNAMIC_NAMESPACE_MUTATION in names
def _attribute_target_base_names(target): def _attribute_target_base_names(target, namespace_aliases=None):
if isinstance(target, ast.Attribute): if isinstance(target, ast.Attribute):
name = _root_name(target.value) name = _root_name(target.value, namespace_aliases)
return {name} if name else set() return {name} if name else set()
names = _getattr_signature_target_names(target) names = _getattr_signature_target_names(target, namespace_aliases)
if names: if names:
return names return names
if isinstance(target, ast.Subscript): if isinstance(target, ast.Subscript):
return _attribute_target_base_names(target.value) return _attribute_target_base_names(target.value, namespace_aliases)
if isinstance(target, (ast.List, ast.Tuple)): if isinstance(target, (ast.List, ast.Tuple)):
names = set() names = set()
for item in target.elts: for item in target.elts:
names.update(_attribute_target_base_names(item)) names.update(_attribute_target_base_names(item, namespace_aliases))
return names return names
if isinstance(target, ast.Starred): if isinstance(target, ast.Starred):
return _attribute_target_base_names(target.value) return _attribute_target_base_names(target.value, namespace_aliases)
return set() return set()
def _setattr_delattr_target_names(node): def _setattr_delattr_target_names(node, namespace_aliases=None):
if not isinstance(node, ast.Call): if not isinstance(node, ast.Call):
return set() return set()
if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}: if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}:
@@ -274,11 +281,12 @@ def _setattr_delattr_target_names(node):
and attr.value not in _CLASS_SIGNATURE_ATTRS and attr.value not in _CLASS_SIGNATURE_ATTRS
): ):
return set() return set()
name = _root_name(node.args[0]) name = _root_name(node.args[0], namespace_aliases)
return {name} if name else set() return {name} if name else set()
def _class_attribute_mutation_target_names(stmt): def _class_attribute_mutation_target_names(stmt, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
names = set() names = set()
class AttributeMutationVisitor(ast.NodeVisitor): class AttributeMutationVisitor(ast.NodeVisitor):
@@ -314,54 +322,56 @@ def _class_attribute_mutation_target_names(stmt):
def visit_Assign(self, node): def visit_Assign(self, node):
for target in node.targets: for target in node.targets:
names.update(_attribute_target_base_names(target)) names.update(_attribute_target_base_names(target, namespace_aliases))
self.visit(node.value) self.visit(node.value)
def visit_AnnAssign(self, node): def visit_AnnAssign(self, node):
names.update(_attribute_target_base_names(node.target)) names.update(_attribute_target_base_names(node.target, namespace_aliases))
if node.value is not None: if node.value is not None:
self.visit(node.value) self.visit(node.value)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
names.update(_attribute_target_base_names(node.target)) names.update(_attribute_target_base_names(node.target, namespace_aliases))
self.visit(node.value) self.visit(node.value)
def visit_Delete(self, node): def visit_Delete(self, node):
for target in node.targets: for target in node.targets:
names.update(_attribute_target_base_names(target)) names.update(_attribute_target_base_names(target, namespace_aliases))
def visit_Call(self, node): def visit_Call(self, node):
names.update(_setattr_delattr_target_names(node)) names.update(_setattr_delattr_target_names(node, namespace_aliases))
names.update(_getattr_mutating_method_target_names(node)) names.update(_getattr_mutating_method_target_names(node))
names.update(_namespace_mutating_call_target_names(node)) names.update(_namespace_mutating_call_target_names(node))
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_attribute_target_base_names(node.func.value)) names.update(_attribute_target_base_names(node.func.value, namespace_aliases))
self.generic_visit(node) self.generic_visit(node)
AttributeMutationVisitor().visit(stmt) AttributeMutationVisitor().visit(stmt)
return names return names
def _signature_attribute_reference_names(node): def _signature_attribute_reference_names(node, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
names = set() names = set()
class SignatureAttributeReferenceVisitor(ast.NodeVisitor): class SignatureAttributeReferenceVisitor(ast.NodeVisitor):
def visit_Attribute(self, child): def visit_Attribute(self, child):
if child.attr in _CLASS_SIGNATURE_ATTRS: if child.attr in _CLASS_SIGNATURE_ATTRS:
name = _root_name(child.value) name = _root_name(child.value, namespace_aliases)
if name is not None: if name is not None:
names.add(name) names.add(name)
self.generic_visit(child) self.generic_visit(child)
def visit_Call(self, child): def visit_Call(self, child):
names.update(_getattr_signature_target_names(child)) names.update(_getattr_signature_target_names(child, namespace_aliases))
self.generic_visit(child) self.generic_visit(child)
SignatureAttributeReferenceVisitor().visit(node) SignatureAttributeReferenceVisitor().visit(node)
return names return names
def _class_attribute_observed_target_names(stmt): def _class_attribute_observed_target_names(stmt, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
names = set() names = set()
class AttributeObservationVisitor(ast.NodeVisitor): class AttributeObservationVisitor(ast.NodeVisitor):
@@ -397,11 +407,11 @@ def _class_attribute_observed_target_names(stmt):
def visit_Call(self, node): def visit_Call(self, node):
if isinstance(node.func, ast.Attribute): if isinstance(node.func, ast.Attribute):
names.update(_signature_attribute_reference_names(node.func.value)) names.update(_signature_attribute_reference_names(node.func.value, namespace_aliases))
for arg in node.args: for arg in node.args:
names.update(_signature_attribute_reference_names(arg)) names.update(_signature_attribute_reference_names(arg, namespace_aliases))
for keyword in node.keywords: for keyword in node.keywords:
names.update(_signature_attribute_reference_names(keyword.value)) names.update(_signature_attribute_reference_names(keyword.value, namespace_aliases))
self.generic_visit(node) self.generic_visit(node)
AttributeObservationVisitor().visit(stmt) AttributeObservationVisitor().visit(stmt)
@@ -1514,7 +1524,8 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
return result return result
def _class_alias_sources(value, class_aliases, class_bindings): def _class_alias_sources(value, class_aliases, class_bindings, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
if isinstance(value, ast.Name): if isinstance(value, ast.Name):
if value.id in class_aliases: if value.id in class_aliases:
return set(class_aliases[value.id]) return set(class_aliases[value.id])
@@ -1524,10 +1535,12 @@ def _class_alias_sources(value, class_aliases, class_bindings):
if isinstance(value, (ast.Tuple, ast.List)): if isinstance(value, (ast.Tuple, ast.List)):
sources = set() sources = set()
for item in value.elts: for item in value.elts:
sources.update(_class_alias_sources(item, class_aliases, class_bindings)) sources.update(_class_alias_sources(item, class_aliases, class_bindings, namespace_aliases))
return sources return sources
name = _namespace_subscript_name(value) or _namespace_lookup_name(value) name = _namespace_subscript_name(value) or _namespace_lookup_name(value)
name = name or _namespace_alias_subscript_name(value, namespace_aliases)
name = name or _namespace_alias_lookup_name(value, namespace_aliases)
if name in class_aliases: if name in class_aliases:
return set(class_aliases[name]) return set(class_aliases[name])
if name in class_bindings: if name in class_bindings:
@@ -1535,17 +1548,18 @@ def _class_alias_sources(value, class_aliases, class_bindings):
return set() return set()
def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings): def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings, namespace_aliases):
for target_item, value_item in _unpack_target_value_pairs(target, value): for target_item, value_item in _unpack_target_value_pairs(target, value):
target_name = _alias_target_name(target_item) target_name = _alias_target_name(target_item)
if target_name is None: if target_name is None:
continue continue
sources = _class_alias_sources(value_item, class_aliases, class_bindings) sources = _class_alias_sources(value_item, class_aliases, class_bindings, namespace_aliases)
if sources: if sources:
class_aliases[target_name] = sources class_aliases[target_name] = sources
def _update_class_aliases(stmt, class_aliases, class_bindings): def _update_class_aliases(stmt, class_aliases, class_bindings, namespace_aliases=None):
namespace_aliases = namespace_aliases or set()
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for name in rebound_names: for name in rebound_names:
class_aliases.pop(name, None) class_aliases.pop(name, None)
@@ -1556,20 +1570,26 @@ def _update_class_aliases(stmt, class_aliases, class_bindings):
return return
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases)
if sources: if sources:
class_aliases[stmt.targets[0].id] = sources class_aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1: elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1:
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases)
if sources: if sources:
for target in stmt.targets: for target in stmt.targets:
target_name = _alias_target_name(target) target_name = _alias_target_name(target)
if target_name is not None: if target_name is not None:
class_aliases[target_name] = sources class_aliases[target_name] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
_update_class_alias_from_unpack(stmt.targets[0], stmt.value, class_aliases, class_bindings) _update_class_alias_from_unpack(
stmt.targets[0],
stmt.value,
class_aliases,
class_bindings,
namespace_aliases,
)
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases)
if sources: if sources:
class_aliases[stmt.target.id] = sources class_aliases[stmt.target.id] = sources
@@ -1581,7 +1601,14 @@ def _expanded_class_attribute_names(names, class_aliases):
return expanded return expanded
def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings): def _class_attribute_alias_sources(
value,
class_attribute_aliases,
class_aliases,
class_bindings,
namespace_aliases=None,
):
namespace_aliases = namespace_aliases or set()
if isinstance(value, ast.Name): if isinstance(value, ast.Name):
return set(class_attribute_aliases.get(value.id, ())) return set(class_attribute_aliases.get(value.id, ()))
if isinstance(value, (ast.Tuple, ast.List)): if isinstance(value, (ast.Tuple, ast.List)):
@@ -1593,17 +1620,18 @@ def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
) )
return sources return sources
names = set() names = set()
if isinstance(value, ast.Attribute) and value.attr in _CLASS_SIGNATURE_ATTRS: if isinstance(value, ast.Attribute) and value.attr in _CLASS_SIGNATURE_ATTRS:
name = _root_name(value.value) name = _root_name(value.value, namespace_aliases)
if name is not None: if name is not None:
names.add(name) names.add(name)
else: else:
names.update(_getattr_signature_target_names(value)) names.update(_getattr_signature_target_names(value, namespace_aliases))
sources = set() sources = set()
for name in names: for name in names:
@@ -1620,6 +1648,7 @@ def _update_class_attribute_alias_from_unpack(
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
): ):
for target_item, value_item in _unpack_target_value_pairs(target, value): for target_item, value_item in _unpack_target_value_pairs(target, value):
target_name = _alias_target_name(target_item) target_name = _alias_target_name(target_item)
@@ -1630,6 +1659,7 @@ def _update_class_attribute_alias_from_unpack(
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
if sources: if sources:
class_attribute_aliases[target_name] = sources class_attribute_aliases[target_name] = sources
@@ -1649,7 +1679,14 @@ def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
return invalidated return invalidated
def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings): def _update_class_attribute_aliases(
stmt,
class_attribute_aliases,
class_aliases,
class_bindings,
namespace_aliases=None,
):
namespace_aliases = namespace_aliases or set()
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for name in rebound_names: for name in rebound_names:
class_attribute_aliases.pop(name, None) class_attribute_aliases.pop(name, None)
@@ -1660,6 +1697,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
if sources: if sources:
class_attribute_aliases[stmt.targets[0].id] = sources class_attribute_aliases[stmt.targets[0].id] = sources
@@ -1669,6 +1707,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
if sources: if sources:
for target in stmt.targets: for target in stmt.targets:
@@ -1682,6 +1721,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _class_attribute_alias_sources( sources = _class_attribute_alias_sources(
@@ -1689,19 +1729,26 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases, class_attribute_aliases,
class_aliases, class_aliases,
class_bindings, class_bindings,
namespace_aliases,
) )
if sources: if sources:
class_attribute_aliases[stmt.target.id] = sources class_attribute_aliases[stmt.target.id] = sources
def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases): def _module_class_attribute_invalidated_names(
stmt,
class_aliases,
class_attribute_aliases,
namespace_aliases=None,
):
namespace_aliases = namespace_aliases or set()
names = _expanded_class_attribute_names( names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt), _class_attribute_mutation_target_names(stmt, namespace_aliases),
class_aliases, class_aliases,
) )
names.update( names.update(
_expanded_class_attribute_names( _expanded_class_attribute_names(
_class_attribute_observed_target_names(stmt), _class_attribute_observed_target_names(stmt, namespace_aliases),
class_aliases, class_aliases,
) )
) )
@@ -1957,11 +2004,22 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
def advance_module_state(stmt): def advance_module_state(stmt):
_invalidate_class_bindings( _invalidate_class_bindings(
class_bindings, class_bindings,
_module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases), _module_class_attribute_invalidated_names(
stmt,
class_aliases,
class_attribute_aliases,
namespace_aliases,
),
) )
_apply_module_stmt_to_env(stmt, env, class_bindings) _apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings, namespace_aliases)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) _update_class_attribute_aliases(
stmt,
class_attribute_aliases,
class_aliases,
class_bindings,
namespace_aliases,
)
_update_module_dict_aliases(stmt, name, module_dict_aliases, namespace_aliases) _update_module_dict_aliases(stmt, name, module_dict_aliases, namespace_aliases)
_update_namespace_aliases(stmt, namespace_aliases) _update_namespace_aliases(stmt, namespace_aliases)
@@ -1969,7 +2027,12 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
class_body_module_mutations = ( class_body_module_mutations = (
_class_body_module_mutation_names(stmt) if isinstance(stmt, ast.ClassDef) else set() _class_body_module_mutation_names(stmt) if isinstance(stmt, ast.ClassDef) else set()
) )
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) class_attr_names = _module_class_attribute_invalidated_names(
stmt,
class_aliases,
class_attribute_aliases,
namespace_aliases,
)
if ( if (
value not in (_MISSING, _INVALID) value not in (_MISSING, _INVALID)
and class_attr_names and class_attr_names