From bf46f9b389df5ccd1cfb98af4cdb3f0eea585e56 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 18:08:48 +0200 Subject: [PATCH] Fail closed on duplicate nodes and observed input types --- .../test_generate_popular_node_signatures.py | 102 ++++++++++++++++++ tools/generate_popular_node_signatures.py | 77 ++++++++++++- 2 files changed, 175 insertions(+), 4 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 70b9535..674a1b6 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -209,6 +209,57 @@ NODE_CLASS_MAPPINGS = { self.assertIn("GoodUtf8Node", result["nodes"]) self.assertEqual("ok", result["pack"]["status"]) + def test_duplicate_node_ids_across_files_are_skipped(self): + source_a = ''' +class FirstDupNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DupNode": FirstDupNode, +} +''' + source_b = ''' +class SecondDupNode: + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DupNode": SecondDupNode, +} +''' + with tempfile.TemporaryDirectory() as tmp: + Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8") + Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8") + result = extract_repo_signatures( + Path(tmp), + { + "id": "duplicate-node-pack", + "title": "Duplicate Node Pack", + "repository": "https://github.com/example/duplicate-node-pack", + "rank": 1, + }, + ) + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unsupported_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): @@ -3391,6 +3442,57 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_observed_by_arbitrary_call_skips_node(self): + source = ''' +class ObservedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + observe(INPUT_TYPES) + + +NODE_CLASS_MAPPINGS = { + "ObservedInputTypesNode": ObservedInputTypesNode, +} +''' + result = self._extract_source(source, "observed-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_input_types_alias_observed_by_arbitrary_call_skips_node(self): + source = ''' +class AliasObservedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + ALIAS = INPUT_TYPES + observe(ALIAS) + + +NODE_CLASS_MAPPINGS = { + "AliasObservedInputTypesNode": AliasObservedInputTypesNode, +} +''' + result = self._extract_source(source, "alias-observed-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_write_artifact_is_deterministic(self): with tempfile.TemporaryDirectory() as tmp: out_one = Path(tmp, "one.json") diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index d3c3bc9..cb75fbf 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -955,6 +955,26 @@ def _update_class_attr_aliases_from_unpack(target, value, name, aliases): return found +def _input_types_alias_sources(value, aliases): + if isinstance(value, ast.Name): + return value.id == "INPUT_TYPES" or value.id in aliases + if isinstance(value, (ast.Tuple, ast.List)): + return any(_input_types_alias_sources(item, aliases) for item in value.elts) + return False + + +def _update_input_types_aliases_from_unpack(target, value, aliases): + found = False + for target_item, value_item in _unpack_target_value_pairs(target, value): + target_name = _alias_target_name(target_item) + if target_name is None: + continue + if _input_types_alias_sources(value_item, aliases): + aliases.add(target_name) + found = True + return found + + def _class_attr(cls, name, env): value = _MISSING aliases = set() @@ -1074,9 +1094,14 @@ def _class_attr(cls, name, env): def _input_types(cls, env, decorator_env): value = _MISSING + aliases = set() classmethod_shadowed = "classmethod" in decorator_env for stmt in cls.body: - if "INPUT_TYPES" in _mutating_call_target_names(stmt): + mutating_targets = _mutating_call_target_names(stmt) + observed_targets = _arbitrary_call_observed_names(stmt) + if "INPUT_TYPES" in mutating_targets or aliases.intersection(mutating_targets): + value = _INVALID + if "INPUT_TYPES" in observed_targets or aliases.intersection(observed_targets): value = _INVALID if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): @@ -1095,14 +1120,48 @@ def _input_types(cls, env, decorator_env): if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": value = _INVALID continue + rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) + aliases.difference_update(rebound_names) if "classmethod" in ( _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) - | _mutating_call_target_names(stmt) + | mutating_targets ): classmethod_shadowed = True - if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)): + if isinstance(stmt, ast.Assign): + target_names = _assignment_target_names(stmt) + if ( + len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and stmt.targets[0].id != "INPUT_TYPES" + and _input_types_alias_sources(stmt.value, aliases) + ): + aliases.add(stmt.targets[0].id) + continue + if ( + len(stmt.targets) == 1 + and "INPUT_TYPES" not in target_names + and _update_input_types_aliases_from_unpack(stmt.targets[0], stmt.value, aliases) + ): + continue + if "INPUT_TYPES" in target_names: + value = _INVALID + continue + if isinstance(stmt, ast.AnnAssign): + target_names = _assignment_target_names(stmt) + if ( + isinstance(stmt.target, ast.Name) + and stmt.target.id != "INPUT_TYPES" + and stmt.value is not None + and _input_types_alias_sources(stmt.value, aliases) + ): + aliases.add(stmt.target.id) + continue + if "INPUT_TYPES" in target_names: + value = _INVALID + continue + if isinstance(stmt, ast.AugAssign): if "INPUT_TYPES" in _assignment_target_names(stmt): value = _INVALID continue @@ -1111,7 +1170,7 @@ def _input_types(cls, env, decorator_env): value = _INVALID continue if isinstance(stmt, ast.Expr): - if "INPUT_TYPES" in _mutating_call_target_names(stmt): + if "INPUT_TYPES" in mutating_targets: value = _INVALID if "INPUT_TYPES" in _bound_names(stmt): value = _INVALID @@ -1602,6 +1661,8 @@ def _parse_python_file(path): def extract_repo_signatures(repo_dir, pack_meta): nodes = {} + node_sources = {} + duplicate_node_types = set() for path in sorted(_python_files(repo_dir)): tree = _parse_python_file(path) if tree is None: @@ -1612,6 +1673,14 @@ def extract_repo_signatures(repo_dir, pack_meta): if displays is _INVALID: continue for node_type, binding in sorted(mappings.items()): + prior_path = node_sources.get(node_type) + if prior_path is not None and prior_path != path: + duplicate_node_types.add(node_type) + nodes.pop(node_type, None) + continue + node_sources.setdefault(node_type, path) + if node_type in duplicate_node_types: + continue cls, class_env = binding sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env) if sig is not None: