Invalidate mapped classes on signature attribute observation
This commit is contained in:
@@ -2532,6 +2532,55 @@ RET.clear()
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_module_class_attribute_arbitrary_call_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class ObserveAttributeCallNode:
|
||||
RETURN_TYPES = ["IMAGE"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ObserveAttributeCallNode": ObserveAttributeCallNode,
|
||||
}
|
||||
mutate(ObserveAttributeCallNode.RETURN_TYPES)
|
||||
'''
|
||||
result = self._extract_source(source, "observe-attribute-call-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class ObserveAttributeAliasCallNode:
|
||||
RETURN_TYPES = ["IMAGE"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ObserveAttributeAliasCallNode": ObserveAttributeAliasCallNode,
|
||||
}
|
||||
RET = ObserveAttributeAliasCallNode.RETURN_TYPES
|
||||
mutate(RET)
|
||||
'''
|
||||
result = self._extract_source(source, "observe-attribute-alias-call-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_getattr_return_types_mutation_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class GetattrReturnTypesNode:
|
||||
|
||||
@@ -323,6 +323,72 @@ def _class_attribute_mutation_target_names(stmt):
|
||||
return names
|
||||
|
||||
|
||||
def _signature_attribute_reference_names(node):
|
||||
names = set()
|
||||
|
||||
class SignatureAttributeReferenceVisitor(ast.NodeVisitor):
|
||||
def visit_Attribute(self, child):
|
||||
if child.attr in _CLASS_SIGNATURE_ATTRS:
|
||||
name = _root_name(child.value)
|
||||
if name is not None:
|
||||
names.add(name)
|
||||
self.generic_visit(child)
|
||||
|
||||
def visit_Call(self, child):
|
||||
names.update(_getattr_signature_target_names(child))
|
||||
self.generic_visit(child)
|
||||
|
||||
SignatureAttributeReferenceVisitor().visit(node)
|
||||
return names
|
||||
|
||||
|
||||
def _class_attribute_observed_target_names(stmt):
|
||||
names = set()
|
||||
|
||||
class AttributeObservationVisitor(ast.NodeVisitor):
|
||||
def _visit_function_definition_expressions(self, node):
|
||||
for decorator in node.decorator_list:
|
||||
self.visit(decorator)
|
||||
self.visit(node.args)
|
||||
if node.returns is not None:
|
||||
self.visit(node.returns)
|
||||
for type_param in getattr(node, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self._visit_function_definition_expressions(node)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
self._visit_function_definition_expressions(node)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
for decorator in node.decorator_list:
|
||||
self.visit(decorator)
|
||||
for base in node.bases:
|
||||
self.visit(base)
|
||||
for keyword in node.keywords:
|
||||
self.visit(keyword.value)
|
||||
for type_param in getattr(node, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
self.visit(node.args)
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
names.update(_signature_attribute_reference_names(node.func.value))
|
||||
for arg in node.args:
|
||||
names.update(_signature_attribute_reference_names(arg))
|
||||
for keyword in node.keywords:
|
||||
names.update(_signature_attribute_reference_names(keyword.value))
|
||||
self.generic_visit(node)
|
||||
|
||||
AttributeObservationVisitor().visit(stmt)
|
||||
return names
|
||||
|
||||
|
||||
def _pattern_bound_names(pattern):
|
||||
names = set()
|
||||
if isinstance(pattern, ast.MatchAs):
|
||||
@@ -1179,6 +1245,7 @@ def _update_class_attribute_alias_from_unpack(
|
||||
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
|
||||
names = (
|
||||
_mutating_call_target_names(stmt)
|
||||
| _arbitrary_call_observed_names(stmt)
|
||||
| _assignment_target_names(stmt)
|
||||
| _delete_target_names(stmt)
|
||||
| _bound_names(stmt)
|
||||
@@ -1227,6 +1294,12 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu
|
||||
_class_attribute_mutation_target_names(stmt),
|
||||
class_aliases,
|
||||
)
|
||||
names.update(
|
||||
_expanded_class_attribute_names(
|
||||
_class_attribute_observed_target_names(stmt),
|
||||
class_aliases,
|
||||
)
|
||||
)
|
||||
names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
|
||||
return names
|
||||
|
||||
|
||||
Reference in New Issue
Block a user