diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 3ba51ea..832d76a 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -906,6 +906,34 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_mutable_env_subscript_alias_skips_static_node(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +REQ = INPUTS["required"] +REQ.clear() + + +class NestedMutableEnvSubscriptAliasNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "NestedMutableEnvSubscriptAliasNode": NestedMutableEnvSubscriptAliasNode, +} +''' + result = self._extract_source(source, "nested-mutable-env-subscript-alias-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_post_class_input_reassignment_skips_static_node(self): source = ''' def build_inputs(): @@ -996,6 +1024,39 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_decorated_input_types_skips_node(self): + source = ''' +def replace(fn): + def replacement(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + return replacement + + +class DecoratedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @replace + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DecoratedInputTypesNode": DecoratedInputTypesNode, +} +''' + result = self._extract_source(source, "decorated-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_return_types_reassignment_skips_node(self): source = ''' def build_outputs(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index b8ea88b..58f112c 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -443,6 +443,11 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): _invalidate_class_bindings(class_bindings, names) if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): name = stmt.targets[0].id + subscript_root = _mutable_env_subscript_root(stmt.value, env) + if subscript_root is not None: + env.pop(subscript_root, None) + env.pop(name, None) + return if ( isinstance(stmt.value, ast.Name) and stmt.value.id in env @@ -466,6 +471,11 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): return if isinstance(stmt.target, ast.Name): name = stmt.target.id + subscript_root = _mutable_env_subscript_root(stmt.value, env) + if subscript_root is not None: + env.pop(subscript_root, None) + env.pop(name, None) + return if ( isinstance(stmt.value, ast.Name) and stmt.value.id in env @@ -544,6 +554,19 @@ def _is_mutable_env_reference(node, env): return isinstance(node, ast.Name) and node.id in env and _is_mutable_static_value(env[node.id]) +def _mutable_env_subscript_root(node, env): + if not isinstance(node, ast.Subscript): + return None + name = _root_name(node) + if name in env and _is_mutable_static_value(env[name]): + return name + return None + + +def _input_types_decorators_are_supported(decorators): + return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators) + + def _class_attr(cls, name, env): value = _MISSING aliases = set() @@ -657,6 +680,9 @@ def _input_types(cls, env): if "INPUT_TYPES" in _mutating_call_target_names(stmt): value = _INVALID if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": + if not _input_types_decorators_are_supported(stmt.decorator_list): + value = _INVALID + continue if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): value = _INVALID continue