Invalidate mapped classes on signature attribute observation

This commit is contained in:
2026-07-02 17:53:04 +02:00
parent 07822bc3ec
commit c6d2b2d645
2 changed files with 122 additions and 0 deletions
@@ -2532,6 +2532,55 @@ RET.clear()
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_attribute_arbitrary_call_after_mapping_skips_node(self):
source = '''
class ObserveAttributeCallNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"ObserveAttributeCallNode": ObserveAttributeCallNode,
}
mutate(ObserveAttributeCallNode.RETURN_TYPES)
'''
result = self._extract_source(source, "observe-attribute-call-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self):
source = '''
class ObserveAttributeAliasCallNode:
RETURN_TYPES = ["IMAGE"]
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"ObserveAttributeAliasCallNode": ObserveAttributeAliasCallNode,
}
RET = ObserveAttributeAliasCallNode.RETURN_TYPES
mutate(RET)
'''
result = self._extract_source(source, "observe-attribute-alias-call-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_getattr_return_types_mutation_after_mapping_skips_node(self):
source = '''
class GetattrReturnTypesNode:
+73
View File
@@ -323,6 +323,72 @@ def _class_attribute_mutation_target_names(stmt):
return names
def _signature_attribute_reference_names(node):
names = set()
class SignatureAttributeReferenceVisitor(ast.NodeVisitor):
def visit_Attribute(self, child):
if child.attr in _CLASS_SIGNATURE_ATTRS:
name = _root_name(child.value)
if name is not None:
names.add(name)
self.generic_visit(child)
def visit_Call(self, child):
names.update(_getattr_signature_target_names(child))
self.generic_visit(child)
SignatureAttributeReferenceVisitor().visit(node)
return names
def _class_attribute_observed_target_names(stmt):
names = set()
class AttributeObservationVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
self.visit(node.args)
if node.returns is not None:
self.visit(node.returns)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
def visit_FunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
for child in node.body:
self.visit(child)
def visit_Lambda(self, node):
self.visit(node.args)
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
names.update(_signature_attribute_reference_names(node.func.value))
for arg in node.args:
names.update(_signature_attribute_reference_names(arg))
for keyword in node.keywords:
names.update(_signature_attribute_reference_names(keyword.value))
self.generic_visit(node)
AttributeObservationVisitor().visit(stmt)
return names
def _pattern_bound_names(pattern):
names = set()
if isinstance(pattern, ast.MatchAs):
@@ -1179,6 +1245,7 @@ def _update_class_attribute_alias_from_unpack(
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
names = (
_mutating_call_target_names(stmt)
| _arbitrary_call_observed_names(stmt)
| _assignment_target_names(stmt)
| _delete_target_names(stmt)
| _bound_names(stmt)
@@ -1227,6 +1294,12 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
names.update(
_expanded_class_attribute_names(
_class_attribute_observed_target_names(stmt),
class_aliases,
)
)
names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
return names