diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index f7c0676..8b94935 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1659,6 +1659,107 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_body_locals_alias_return_types_setitem_skips_node(self): + source = ''' +class LocalsAliasReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + ns = locals() + ns["RETURN_TYPES"] = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "LocalsAliasReturnTypesNode": LocalsAliasReturnTypesNode, +} +''' + result = self._extract_source(source, "locals-alias-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_class_body_vars_alias_return_types_update_skips_node(self): + source = ''' +class VarsAliasReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + ns = vars() + ns.update(RETURN_TYPES=("MASK",)) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "VarsAliasReturnTypesNode": VarsAliasReturnTypesNode, +} +''' + result = self._extract_source(source, "vars-alias-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_class_body_chained_namespace_alias_return_types_setitem_skips_node(self): + source = ''' +class ChainedNamespaceAliasReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + ns = other = locals() + other["RETURN_TYPES"] = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ChainedNamespaceAliasReturnTypesNode": ChainedNamespaceAliasReturnTypesNode, +} +''' + result = self._extract_source(source, "chained-namespace-alias-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_class_body_namespace_alias_return_names_update_skips_node(self): + source = ''' +class NamespaceAliasReturnNamesNode: + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ["image"] + ns = locals() + ns.update(RETURN_NAMES=["mask"]) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NamespaceAliasReturnNamesNode": NamespaceAliasReturnNamesNode, +} +''' + result = self._extract_source(source, "namespace-alias-return-names-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_except_handler_binding_to_return_types_skips_node(self): source = ''' class ExceptHandlerBoundReturnTypesNode: @@ -4234,6 +4335,40 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_body_namespace_alias_input_types_patch_skips_node(self): + source = ''' +def build_inputs(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + + +class NamespaceAliasInputTypesPatchNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + ns = locals() + ns["INPUT_TYPES"] = build_inputs + + +NODE_CLASS_MAPPINGS = { + "NamespaceAliasInputTypesPatchNode": NamespaceAliasInputTypesPatchNode, +} +''' + result = self._extract_source(source, "namespace-alias-input-types-patch-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_alias_observed_by_arbitrary_call_skips_node(self): source = ''' class AliasObservedInputTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 31561f8..d146b2e 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -933,6 +933,15 @@ def _class_body_module_mutation_names(cls): return names +def _class_body_namespace_mutation_names(cls): + names = set() + namespace_aliases = set() + for stmt in cls.body: + names.update(_namespace_alias_mutation_target_names(stmt, namespace_aliases)) + _update_namespace_aliases(stmt, namespace_aliases) + return names + + def _apply_module_stmt_to_env(stmt, env, class_bindings=None): names = _mutating_call_target_names(stmt) if isinstance(stmt, ast.ClassDef): @@ -1165,6 +1174,9 @@ def _update_input_types_aliases_from_unpack(target, value, aliases): def _class_attr(cls, name, env): value = _MISSING aliases = set() + namespace_mutations = _class_body_namespace_mutation_names(cls) + if _name_invalidated_by(name, namespace_mutations): + return _INVALID for stmt in cls.body: mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) @@ -1295,6 +1307,9 @@ def _input_types(cls, env, decorator_env): value = _MISSING aliases = set() classmethod_shadowed = "classmethod" in decorator_env + namespace_mutations = _class_body_namespace_mutation_names(cls) + if _name_invalidated_by("INPUT_TYPES", namespace_mutations): + return None for stmt in cls.body: mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt)