diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 72f13d3..f25107d 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1288,6 +1288,55 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_getattr_method_mutation_to_return_types_skips_node(self): + source = ''' +class GetattrMethodMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + getattr(RETURN_TYPES, "clear")() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GetattrMethodMutatedReturnTypesNode": GetattrMethodMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "getattr-method-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_getattr_method_mutation_to_return_names_skips_node(self): + source = ''' +class GetattrMethodMutatedReturnNamesNode: + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ["image"] + getattr(RETURN_NAMES, "clear")() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GetattrMethodMutatedReturnNamesNode": GetattrMethodMutatedReturnNamesNode, +} +''' + result = self._extract_source(source, "getattr-method-mutated-return-names-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_except_handler_binding_to_return_types_skips_node(self): source = ''' class ExceptHandlerBoundReturnTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 5ebca0f..ff85e1d 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -150,6 +150,23 @@ def _getattr_signature_target_names(node): return {name} +def _getattr_mutating_method_target_names(node): + if not isinstance(node, ast.Call): + return set() + if not isinstance(node.func, ast.Call): + return set() + getattr_call = node.func + if not isinstance(getattr_call.func, ast.Name) or getattr_call.func.id != "getattr": + return set() + if len(getattr_call.args) < 2: + return set() + method = getattr_call.args[1] + if isinstance(method, ast.Constant) and isinstance(method.value, str): + if method.value not in _MUTATING_METHODS: + return set() + return _target_names(getattr_call.args[0]) + + def _attribute_target_base_names(target): if isinstance(target, ast.Attribute): name = _root_name(target.value) @@ -241,6 +258,7 @@ def _class_attribute_mutation_target_names(stmt): def visit_Call(self, node): names.update(_setattr_delattr_target_names(node)) + names.update(_getattr_mutating_method_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_attribute_target_base_names(node.func.value)) self.generic_visit(node) @@ -409,6 +427,7 @@ def _mutating_call_target_names(stmt): def visit_Call(self, node): names.update(_setattr_delattr_target_names(node)) + names.update(_getattr_mutating_method_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_target_names(node.func.value)) self.generic_visit(node)