diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 26b277e..0084351 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2458,6 +2458,30 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_no_arg_class_body_arbitrary_call_after_return_types_skips_node(self): + source = ''' +class NoArgClassBodyCallReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + mutate() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NoArgClassBodyCallReturnTypesNode": NoArgClassBodyCallReturnTypesNode, +} +''' + result = self._extract_source(source, "no-arg-class-body-call-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_arbitrary_call_skips_node(self): source = ''' class AliasArbitraryCallReturnTypesNode: @@ -3705,6 +3729,34 @@ mutate(ObserveAttributeCallNode.RETURN_TYPES) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_no_arg_arbitrary_call_after_mapping_skips_mapped_class_signature(self): + source = ''' +class NoArgMappedClassMutationNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +def mutate(): + NoArgMappedClassMutationNode.RETURN_TYPES.clear() + + +NODE_CLASS_MAPPINGS = { + "NoArgMappedClassMutationNode": NoArgMappedClassMutationNode, +} +mutate() +''' + result = self._extract_source(source, "no-arg-mapped-class-mutation-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self): source = ''' class ObserveAttributeAliasCallNode: @@ -4308,6 +4360,30 @@ mutate(NODE_CLASS_MAPPINGS) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_no_arg_arbitrary_call_after_node_mapping_skips_node(self): + source = ''' +class NoArgMappingCallNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NoArgMappingCallNode": NoArgMappingCallNode, +} +mutate() +''' + result = self._extract_source(source, "no-arg-mapping-call-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_alias_arbitrary_call_invalidates_static_node_mapping(self): source = ''' class AliasArbitraryCallMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 3b2fa95..996dadf 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -663,6 +663,48 @@ def _arbitrary_call_observed_names(stmt): return names +def _has_arbitrary_call(stmt): + found = False + + class ArbitraryCallPresenceVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + + def visit_FunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_ClassDef(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + for child in node.body: + self.visit(child) + + def visit_Lambda(self, node): + self.visit(node.args) + + def visit_Call(self, node): + nonlocal found + found = True + + ArbitraryCallPresenceVisitor().visit(stmt) + return found + + def _definition_time_referenced_names(stmt): names = set() @@ -1010,6 +1052,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): for name in observed_names: if name in env and _is_mutable_static_value(env[name]): _invalidate_env_name(env, name) + if _has_arbitrary_call(stmt): + for name, value in list(env.items()): + if _is_mutable_static_value(value): + _invalidate_env_name(env, name) + if class_bindings is not None and not isinstance( + stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete) + ): + class_bindings.clear() if isinstance(stmt, ast.ClassDef): if class_bindings is not None: if _is_trivially_safe_class_def(stmt): @@ -1236,6 +1286,9 @@ def _class_attr(cls, name, env): mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) expression_references = _class_body_expression_referenced_names(stmt) + has_arbitrary_call = _has_arbitrary_call(stmt) + if value not in (_MISSING, _INVALID) and has_arbitrary_call: + value = _INVALID if aliases.intersection(mutating_targets): value = _INVALID if name in mutating_targets: @@ -1376,9 +1429,11 @@ def _input_types(cls, env, decorator_env): observed_targets = _arbitrary_call_observed_names(stmt) definition_references = _definition_time_referenced_names(stmt) expression_references = _class_body_expression_referenced_names(stmt) + has_arbitrary_call = _has_arbitrary_call(stmt) protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases input_types_invalidated = ( - "INPUT_TYPES" in mutating_targets + (value not in (_MISSING, _INVALID) and has_arbitrary_call) + or "INPUT_TYPES" in mutating_targets or bool(aliases.intersection(mutating_targets)) or "INPUT_TYPES" in observed_targets or bool(aliases.intersection(observed_targets)) @@ -1389,6 +1444,10 @@ def _input_types(cls, env, decorator_env): value = _INVALID sticky_invalid = True if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": + if has_arbitrary_call: + value = _INVALID + sticky_invalid = True + continue if input_types_invalidated or sticky_invalid: continue if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): @@ -2053,6 +2112,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): value = _INVALID sticky_invalid = True + if value not in (_MISSING, _INVALID) and _has_arbitrary_call(stmt): + value = _INVALID + sticky_invalid = True if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)): value = _INVALID sticky_invalid = True