diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index eeca5c1..2f60eeb 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2532,6 +2532,55 @@ RET.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_arbitrary_call_after_mapping_skips_node(self): + source = ''' +class ObserveAttributeCallNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ObserveAttributeCallNode": ObserveAttributeCallNode, +} +mutate(ObserveAttributeCallNode.RETURN_TYPES) +''' + result = self._extract_source(source, "observe-attribute-call-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self): + source = ''' +class ObserveAttributeAliasCallNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ObserveAttributeAliasCallNode": ObserveAttributeAliasCallNode, +} +RET = ObserveAttributeAliasCallNode.RETURN_TYPES +mutate(RET) +''' + result = self._extract_source(source, "observe-attribute-alias-call-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_getattr_return_types_mutation_after_mapping_skips_node(self): source = ''' class GetattrReturnTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 85bd45e..db8986a 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -323,6 +323,72 @@ def _class_attribute_mutation_target_names(stmt): return names +def _signature_attribute_reference_names(node): + names = set() + + class SignatureAttributeReferenceVisitor(ast.NodeVisitor): + def visit_Attribute(self, child): + if child.attr in _CLASS_SIGNATURE_ATTRS: + name = _root_name(child.value) + if name is not None: + names.add(name) + self.generic_visit(child) + + def visit_Call(self, child): + names.update(_getattr_signature_target_names(child)) + self.generic_visit(child) + + SignatureAttributeReferenceVisitor().visit(node) + return names + + +def _class_attribute_observed_target_names(stmt): + names = set() + + class AttributeObservationVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + + def visit_FunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_ClassDef(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + for child in node.body: + self.visit(child) + + def visit_Lambda(self, node): + self.visit(node.args) + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute): + names.update(_signature_attribute_reference_names(node.func.value)) + for arg in node.args: + names.update(_signature_attribute_reference_names(arg)) + for keyword in node.keywords: + names.update(_signature_attribute_reference_names(keyword.value)) + self.generic_visit(node) + + AttributeObservationVisitor().visit(stmt) + return names + + def _pattern_bound_names(pattern): names = set() if isinstance(pattern, ast.MatchAs): @@ -1179,6 +1245,7 @@ def _update_class_attribute_alias_from_unpack( def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases): names = ( _mutating_call_target_names(stmt) + | _arbitrary_call_observed_names(stmt) | _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) @@ -1227,6 +1294,12 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu _class_attribute_mutation_target_names(stmt), class_aliases, ) + names.update( + _expanded_class_attribute_names( + _class_attribute_observed_target_names(stmt), + class_aliases, + ) + ) names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases)) return names