diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 4e21db8..f2e138a 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -999,6 +999,45 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({"image": "IMAGE"}, result["nodes"]["RuntimeBodyMutatedInputEnvNode"]["inputs"]) self.assertEqual("ok", result["pack"]["status"]) + def test_class_body_global_assignment_invalidates_static_input_env(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +def build_inputs(): + return { + "required": { + "mask": ("MASK",), + }, + } + + +class MutatesModuleAtDefinition: + global INPUTS + INPUTS = build_inputs() + + +class ClassGlobalMutatedInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ClassGlobalMutatedInputEnvNode": ClassGlobalMutatedInputEnvNode, +} +''' + result = self._extract_source(source, "class-global-mutated-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_mutable_env_literal_skips_static_node(self): source = ''' REQ = { @@ -1303,6 +1342,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_with_duplicate_required_optional_name_skips_node(self): + source = ''' +class DuplicateInputNameNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "x": ("IMAGE",), + }, + "optional": { + "x": ("MASK",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DuplicateInputNameNode": DuplicateInputNameNode, +} +''' + result = self._extract_source(source, "duplicate-input-name-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_with_non_string_input_name_skips_node(self): source = ''' class NonStringInputNameNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index b3a6e02..45d953c 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -800,8 +800,140 @@ def _is_trivially_safe_class_def(stmt): ) +def _namespace_assignment_target_names(target): + name = _namespace_subscript_name(target) + if name is not None: + return {name} + if isinstance(target, ast.Attribute): + return _namespace_assignment_target_names(target.value) + if isinstance(target, ast.Subscript): + return _namespace_assignment_target_names(target.value) + if isinstance(target, (ast.List, ast.Tuple)): + names = set() + for item in target.elts: + names.update(_namespace_assignment_target_names(item)) + return names + if isinstance(target, ast.Starred): + return _namespace_assignment_target_names(target.value) + return set() + + +def _class_body_global_names(cls): + names = set() + + class GlobalVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, node): + return None + + def visit_AsyncFunctionDef(self, node): + return None + + def visit_ClassDef(self, node): + return None + + def visit_Global(self, node): + names.update(node.names) + + for stmt in cls.body: + GlobalVisitor().visit(stmt) + return names + + +def _class_body_module_mutation_names(cls): + global_names = _class_body_global_names(cls) + names = set() + + def add_assignment_targets(stmt): + names.update(_assignment_target_names(stmt).intersection(global_names)) + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + names.update(_namespace_assignment_target_names(target)) + elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)): + names.update(_namespace_assignment_target_names(stmt.target)) + elif isinstance(stmt, (ast.For, ast.AsyncFor)): + names.update(_namespace_assignment_target_names(stmt.target)) + + class ClassBodyMutationVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + names.update(_mutating_call_target_names(node)) + names.update(_namespace_alias_mutation_target_names(node, set())) + + def visit_FunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_ClassDef(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + names.update(_class_body_module_mutation_names(node)) + + def visit_Assign(self, node): + add_assignment_targets(node) + self.visit(node.value) + + def visit_AnnAssign(self, node): + add_assignment_targets(node) + if node.value is not None: + self.visit(node.value) + + def visit_AugAssign(self, node): + add_assignment_targets(node) + self.visit(node.value) + + def visit_Delete(self, node): + names.update(_delete_target_names(node).intersection(global_names)) + for target in node.targets: + names.update(_namespace_assignment_target_names(target)) + + def visit_For(self, node): + add_assignment_targets(node) + self.generic_visit(node) + + def visit_AsyncFor(self, node): + add_assignment_targets(node) + self.generic_visit(node) + + def visit_With(self, node): + for item in node.items: + if item.optional_vars is not None: + names.update(_target_names(item.optional_vars).intersection(global_names)) + names.update(_namespace_assignment_target_names(item.optional_vars)) + self.generic_visit(node) + + def visit_AsyncWith(self, node): + for item in node.items: + if item.optional_vars is not None: + names.update(_target_names(item.optional_vars).intersection(global_names)) + names.update(_namespace_assignment_target_names(item.optional_vars)) + self.generic_visit(node) + + def visit_Import(self, node): + names.update(_bound_names(node).intersection(global_names)) + + def visit_ImportFrom(self, node): + names.update(_bound_names(node).intersection(global_names)) + + def visit_Call(self, node): + names.update(_namespace_mutating_call_target_names(node)) + self.generic_visit(node) + + for stmt in cls.body: + ClassBodyMutationVisitor().visit(stmt) + return names + + def _apply_module_stmt_to_env(stmt, env, class_bindings=None): names = _mutating_call_target_names(stmt) + if isinstance(stmt, ast.ClassDef): + names.update(_class_body_module_mutation_names(stmt)) if _DYNAMIC_NAMESPACE_MUTATION in names: env.clear() if class_bindings is not None: @@ -1721,6 +1853,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N _update_namespace_aliases(stmt, namespace_aliases) for stmt in tree.body: + class_body_module_mutations = ( + _class_body_module_mutation_names(stmt) if isinstance(stmt, ast.ClassDef) else set() + ) class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) if ( value not in (_MISSING, _INVALID) @@ -1732,6 +1867,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N if _name_invalidated_by(name, _mutating_call_target_names(stmt)): value = _INVALID sticky_invalid = True + if _name_invalidated_by(name, class_body_module_mutations): + value = _INVALID + sticky_invalid = True if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): value = _INVALID sticky_invalid = True @@ -1905,6 +2043,8 @@ def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_e for name, spec in values.items(): if not isinstance(name, str): return None + if name in inputs: + return None input_type = normalise_input_spec(spec) if input_type is None: return None