diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 603f6b8..1cf79af 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -3773,6 +3773,57 @@ globals().get("GlobalsGetReturnTypesNode").RETURN_TYPES.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_namespace_alias_class_return_types_mutation_after_mapping_skips_node(self): + source = ''' +class NamespaceAliasReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NamespaceAliasReturnTypesNode": NamespaceAliasReturnTypesNode, +} +ns = globals() +ns["NamespaceAliasReturnTypesNode"].RETURN_TYPES.clear() +''' + result = self._extract_source(source, "namespace-alias-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_namespace_alias_get_class_alias_patch_after_mapping_skips_node(self): + source = ''' +class NamespaceAliasGetPatchedNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NamespaceAliasGetPatchedNode": NamespaceAliasGetPatchedNode, +} +ns = globals() +Alias = ns.get("NamespaceAliasGetPatchedNode") +Alias.RETURN_TYPES.clear() +''' + result = self._extract_source(source, "namespace-alias-get-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_getattr_class_attribute_alias_mutation_after_mapping_skips_node(self): source = ''' class GetattrAttributeAliasNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index aff5b06..8214946 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -157,12 +157,19 @@ def _direct_target_names(target): return set() -def _root_name(node): +def _root_name(node, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() while True: name = _namespace_lookup_name(node) if name is not None: return name name = _namespace_subscript_name(node) + if name is not None: + return name + name = _namespace_alias_lookup_name(node, namespace_aliases) + if name is not None: + return name + name = _namespace_alias_subscript_name(node, namespace_aliases) if name is not None: return name if not isinstance(node, (ast.Attribute, ast.Subscript)): @@ -173,14 +180,14 @@ def _root_name(node): return None -def _getattr_signature_target_names(node): +def _getattr_signature_target_names(node, namespace_aliases=None): if not isinstance(node, ast.Call): return set() if not isinstance(node.func, ast.Name) or node.func.id != "getattr": return set() if len(node.args) < 2: return set() - name = _root_name(node.args[0]) + name = _root_name(node.args[0], namespace_aliases) if name is None: return set() attr = node.args[1] @@ -241,26 +248,26 @@ def _name_invalidated_by(name, names): return name in names or _DYNAMIC_NAMESPACE_MUTATION in names -def _attribute_target_base_names(target): +def _attribute_target_base_names(target, namespace_aliases=None): if isinstance(target, ast.Attribute): - name = _root_name(target.value) + name = _root_name(target.value, namespace_aliases) return {name} if name else set() - names = _getattr_signature_target_names(target) + names = _getattr_signature_target_names(target, namespace_aliases) if names: return names if isinstance(target, ast.Subscript): - return _attribute_target_base_names(target.value) + return _attribute_target_base_names(target.value, namespace_aliases) if isinstance(target, (ast.List, ast.Tuple)): names = set() for item in target.elts: - names.update(_attribute_target_base_names(item)) + names.update(_attribute_target_base_names(item, namespace_aliases)) return names if isinstance(target, ast.Starred): - return _attribute_target_base_names(target.value) + return _attribute_target_base_names(target.value, namespace_aliases) return set() -def _setattr_delattr_target_names(node): +def _setattr_delattr_target_names(node, namespace_aliases=None): if not isinstance(node, ast.Call): return set() if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}: @@ -274,11 +281,12 @@ def _setattr_delattr_target_names(node): and attr.value not in _CLASS_SIGNATURE_ATTRS ): return set() - name = _root_name(node.args[0]) + name = _root_name(node.args[0], namespace_aliases) return {name} if name else set() -def _class_attribute_mutation_target_names(stmt): +def _class_attribute_mutation_target_names(stmt, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() names = set() class AttributeMutationVisitor(ast.NodeVisitor): @@ -314,54 +322,56 @@ def _class_attribute_mutation_target_names(stmt): def visit_Assign(self, node): for target in node.targets: - names.update(_attribute_target_base_names(target)) + names.update(_attribute_target_base_names(target, namespace_aliases)) self.visit(node.value) def visit_AnnAssign(self, node): - names.update(_attribute_target_base_names(node.target)) + names.update(_attribute_target_base_names(node.target, namespace_aliases)) if node.value is not None: self.visit(node.value) def visit_AugAssign(self, node): - names.update(_attribute_target_base_names(node.target)) + names.update(_attribute_target_base_names(node.target, namespace_aliases)) self.visit(node.value) def visit_Delete(self, node): for target in node.targets: - names.update(_attribute_target_base_names(target)) + names.update(_attribute_target_base_names(target, namespace_aliases)) def visit_Call(self, node): - names.update(_setattr_delattr_target_names(node)) + names.update(_setattr_delattr_target_names(node, namespace_aliases)) names.update(_getattr_mutating_method_target_names(node)) names.update(_namespace_mutating_call_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: - names.update(_attribute_target_base_names(node.func.value)) + names.update(_attribute_target_base_names(node.func.value, namespace_aliases)) self.generic_visit(node) AttributeMutationVisitor().visit(stmt) return names -def _signature_attribute_reference_names(node): +def _signature_attribute_reference_names(node, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() names = set() class SignatureAttributeReferenceVisitor(ast.NodeVisitor): def visit_Attribute(self, child): if child.attr in _CLASS_SIGNATURE_ATTRS: - name = _root_name(child.value) + name = _root_name(child.value, namespace_aliases) if name is not None: names.add(name) self.generic_visit(child) def visit_Call(self, child): - names.update(_getattr_signature_target_names(child)) + names.update(_getattr_signature_target_names(child, namespace_aliases)) self.generic_visit(child) SignatureAttributeReferenceVisitor().visit(node) return names -def _class_attribute_observed_target_names(stmt): +def _class_attribute_observed_target_names(stmt, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() names = set() class AttributeObservationVisitor(ast.NodeVisitor): @@ -397,11 +407,11 @@ def _class_attribute_observed_target_names(stmt): def visit_Call(self, node): if isinstance(node.func, ast.Attribute): - names.update(_signature_attribute_reference_names(node.func.value)) + names.update(_signature_attribute_reference_names(node.func.value, namespace_aliases)) for arg in node.args: - names.update(_signature_attribute_reference_names(arg)) + names.update(_signature_attribute_reference_names(arg, namespace_aliases)) for keyword in node.keywords: - names.update(_signature_attribute_reference_names(keyword.value)) + names.update(_signature_attribute_reference_names(keyword.value, namespace_aliases)) self.generic_visit(node) AttributeObservationVisitor().visit(stmt) @@ -1514,7 +1524,8 @@ def _module_dict_entries(node, env, class_bindings, value_converter): return result -def _class_alias_sources(value, class_aliases, class_bindings): +def _class_alias_sources(value, class_aliases, class_bindings, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() if isinstance(value, ast.Name): if value.id in class_aliases: return set(class_aliases[value.id]) @@ -1524,10 +1535,12 @@ def _class_alias_sources(value, class_aliases, class_bindings): if isinstance(value, (ast.Tuple, ast.List)): sources = set() for item in value.elts: - sources.update(_class_alias_sources(item, class_aliases, class_bindings)) + sources.update(_class_alias_sources(item, class_aliases, class_bindings, namespace_aliases)) return sources name = _namespace_subscript_name(value) or _namespace_lookup_name(value) + name = name or _namespace_alias_subscript_name(value, namespace_aliases) + name = name or _namespace_alias_lookup_name(value, namespace_aliases) if name in class_aliases: return set(class_aliases[name]) if name in class_bindings: @@ -1535,17 +1548,18 @@ def _class_alias_sources(value, class_aliases, class_bindings): return set() -def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings): +def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings, namespace_aliases): for target_item, value_item in _unpack_target_value_pairs(target, value): target_name = _alias_target_name(target_item) if target_name is None: continue - sources = _class_alias_sources(value_item, class_aliases, class_bindings) + sources = _class_alias_sources(value_item, class_aliases, class_bindings, namespace_aliases) if sources: class_aliases[target_name] = sources -def _update_class_aliases(stmt, class_aliases, class_bindings): +def _update_class_aliases(stmt, class_aliases, class_bindings, namespace_aliases=None): + namespace_aliases = namespace_aliases or set() rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) for name in rebound_names: class_aliases.pop(name, None) @@ -1556,20 +1570,26 @@ def _update_class_aliases(stmt, class_aliases, class_bindings): return if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): - sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases) if sources: class_aliases[stmt.targets[0].id] = sources elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1: - sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases) if sources: for target in stmt.targets: target_name = _alias_target_name(target) if target_name is not None: class_aliases[target_name] = sources elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: - _update_class_alias_from_unpack(stmt.targets[0], stmt.value, class_aliases, class_bindings) + _update_class_alias_from_unpack( + stmt.targets[0], + stmt.value, + class_aliases, + class_bindings, + namespace_aliases, + ) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: - sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings, namespace_aliases) if sources: class_aliases[stmt.target.id] = sources @@ -1581,7 +1601,14 @@ def _expanded_class_attribute_names(names, class_aliases): return expanded -def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings): +def _class_attribute_alias_sources( + value, + class_attribute_aliases, + class_aliases, + class_bindings, + namespace_aliases=None, +): + namespace_aliases = namespace_aliases or set() if isinstance(value, ast.Name): return set(class_attribute_aliases.get(value.id, ())) if isinstance(value, (ast.Tuple, ast.List)): @@ -1593,17 +1620,18 @@ def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) ) return sources names = set() if isinstance(value, ast.Attribute) and value.attr in _CLASS_SIGNATURE_ATTRS: - name = _root_name(value.value) + name = _root_name(value.value, namespace_aliases) if name is not None: names.add(name) else: - names.update(_getattr_signature_target_names(value)) + names.update(_getattr_signature_target_names(value, namespace_aliases)) sources = set() for name in names: @@ -1620,6 +1648,7 @@ def _update_class_attribute_alias_from_unpack( class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ): for target_item, value_item in _unpack_target_value_pairs(target, value): target_name = _alias_target_name(target_item) @@ -1630,6 +1659,7 @@ def _update_class_attribute_alias_from_unpack( class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) if sources: class_attribute_aliases[target_name] = sources @@ -1649,7 +1679,14 @@ def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases): return invalidated -def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings): +def _update_class_attribute_aliases( + stmt, + class_attribute_aliases, + class_aliases, + class_bindings, + namespace_aliases=None, +): + namespace_aliases = namespace_aliases or set() rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) for name in rebound_names: class_attribute_aliases.pop(name, None) @@ -1660,6 +1697,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) if sources: class_attribute_aliases[stmt.targets[0].id] = sources @@ -1669,6 +1707,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) if sources: for target in stmt.targets: @@ -1682,6 +1721,7 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: sources = _class_attribute_alias_sources( @@ -1689,19 +1729,26 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases, class_aliases, class_bindings, + namespace_aliases, ) if sources: class_attribute_aliases[stmt.target.id] = sources -def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases): +def _module_class_attribute_invalidated_names( + stmt, + class_aliases, + class_attribute_aliases, + namespace_aliases=None, +): + namespace_aliases = namespace_aliases or set() names = _expanded_class_attribute_names( - _class_attribute_mutation_target_names(stmt), + _class_attribute_mutation_target_names(stmt, namespace_aliases), class_aliases, ) names.update( _expanded_class_attribute_names( - _class_attribute_observed_target_names(stmt), + _class_attribute_observed_target_names(stmt, namespace_aliases), class_aliases, ) ) @@ -1957,11 +2004,22 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N def advance_module_state(stmt): _invalidate_class_bindings( class_bindings, - _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases), + _module_class_attribute_invalidated_names( + stmt, + class_aliases, + class_attribute_aliases, + namespace_aliases, + ), ) _apply_module_stmt_to_env(stmt, env, class_bindings) - _update_class_aliases(stmt, class_aliases, class_bindings) - _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) + _update_class_aliases(stmt, class_aliases, class_bindings, namespace_aliases) + _update_class_attribute_aliases( + stmt, + class_attribute_aliases, + class_aliases, + class_bindings, + namespace_aliases, + ) _update_module_dict_aliases(stmt, name, module_dict_aliases, namespace_aliases) _update_namespace_aliases(stmt, namespace_aliases) @@ -1969,7 +2027,12 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N class_body_module_mutations = ( _class_body_module_mutation_names(stmt) if isinstance(stmt, ast.ClassDef) else set() ) - class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) + class_attr_names = _module_class_attribute_invalidated_names( + stmt, + class_aliases, + class_attribute_aliases, + namespace_aliases, + ) if ( value not in (_MISSING, _INVALID) and class_attr_names