diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 12023d1..b4c49fb 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1889,6 +1889,63 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_with_base_mapping_skips_node(self): + source = ''' +class Base: + def __init_subclass__(cls): + cls.RETURN_TYPES = ("MASK",) + + +class HookedNode(Base): + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "HookedNode": HookedNode, +} +''' + result = self._extract_source(source, "hooked-base-class-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_class_with_metaclass_mapping_skips_node(self): + source = ''' +class Meta(type): + def __new__(mcls, name, bases, attrs): + attrs["RETURN_TYPES"] = ("MASK",) + return super().__new__(mcls, name, bases, attrs) + + +class MetaNode(metaclass=Meta): + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "MetaNode": MetaNode, +} +''' + result = self._extract_source(source, "metaclass-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_node_mapping_key_uses_assignment_time_env(self): source = ''' KEY = "Original" diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 5f9c018..d479296 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -605,6 +605,16 @@ def _invalidate_class_bindings(class_bindings, names): class_bindings.pop(name, None) +def _is_trivially_safe_class_def(stmt): + return ( + isinstance(stmt, ast.ClassDef) + and not stmt.decorator_list + and not stmt.bases + and not stmt.keywords + and not getattr(stmt, "type_params", ()) + ) + + def _apply_module_stmt_to_env(stmt, env, class_bindings=None): names = _mutating_call_target_names(stmt) if _DYNAMIC_NAMESPACE_MUTATION in names: @@ -616,10 +626,10 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): _invalidate_env_names(env, names) if isinstance(stmt, ast.ClassDef): if class_bindings is not None: - if stmt.decorator_list: - class_bindings.pop(stmt.name, None) - else: + if _is_trivially_safe_class_def(stmt): class_bindings[stmt.name] = (stmt, dict(env)) + else: + class_bindings.pop(stmt.name, None) _invalidate_env_name(env, stmt.name) return if isinstance(stmt, ast.Assign): @@ -999,7 +1009,7 @@ def _update_class_aliases(stmt, class_aliases, class_bindings): class_aliases.pop(name, None) if isinstance(stmt, ast.ClassDef): - if stmt.name in class_bindings and not stmt.decorator_list: + if stmt.name in class_bindings and _is_trivially_safe_class_def(stmt): class_aliases[stmt.name] = {stmt.name} return