From 97929892168803176f751b208b261072177983fe Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 19:46:55 +0200 Subject: [PATCH] Fail closed on mapping mutation keys and bare input specs --- .../test_generate_popular_node_signatures.py | 76 ++++++++++++++++ tools/generate_popular_node_signatures.py | 89 ++++++++++++++++++- 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 8b94935..37e20b1 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -39,6 +39,7 @@ class StaticExtractionTests(unittest.TestCase): self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],))) self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",))) self.assertEqual("FLOAT", normalise_input_spec(("FLOAT", {"default": 1.0}))) + self.assertIsNone(normalise_input_spec("IMAGE")) def test_extracts_static_node_mapping_and_signatures(self): source = ''' @@ -347,6 +348,58 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_duplicate_node_id_from_mapping_update_skips_static_node(self): + source_a = ''' +class StaticDupNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DupNode": StaticDupNode, +} +''' + source_b = ''' +class DynamicDupNode: + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update({ + "DupNode": DynamicDupNode, +}) +''' + with tempfile.TemporaryDirectory() as tmp: + Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8") + Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8") + result = extract_repo_signatures( + Path(tmp), + { + "id": "update-duplicate-node-pack", + "title": "Update Duplicate Node Pack", + "repository": "https://github.com/example/update-duplicate-node-pack", + "rank": 1, + }, + ) + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unsupported_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): @@ -1458,6 +1511,29 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_with_bare_string_input_spec_skips_node(self): + source = ''' +class BareStringInputSpecNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": "IMAGE", + }, + } + + +NODE_CLASS_MAPPINGS = { + "BareStringInputSpecNode": BareStringInputSpecNode, +} +''' + result = self._extract_source(source, "bare-string-input-spec-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_return_types_reassignment_skips_node(self): source = ''' def build_outputs(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index d146b2e..f012137 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -1060,7 +1060,9 @@ def _collect_module_env(tree, class_bindings=None): def normalise_input_spec(spec): - first = spec[0] if isinstance(spec, (list, tuple)) and spec else spec + if not isinstance(spec, (list, tuple)) or not spec: + return None + first = spec[0] if isinstance(first, list): return "COMBO" if all(isinstance(value, str) for value in first) else None return first if isinstance(first, str) else None @@ -2056,6 +2058,90 @@ def _literal_module_dict_string_keys(node, env): return keys +def _mapping_subscript_target_key(target, mapping_name, env): + if not isinstance(target, ast.Subscript): + return None + if _root_name(target.value) != mapping_name: + return None + try: + key_value = _literal(target.slice, env) + except UnsupportedStaticExpression: + return None + return key_value if isinstance(key_value, str) and key_value else None + + +def _node_class_mapping_mutation_string_keys(stmt, env): + keys = set() + + class MappingMutationKeyVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + + def visit_FunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_ClassDef(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + for child in node.body: + self.visit(child) + + def visit_Assign(self, node): + for target in node.targets: + key = _mapping_subscript_target_key(target, "NODE_CLASS_MAPPINGS", env) + if key is not None: + keys.add(key) + self.visit(node.value) + + def visit_AnnAssign(self, node): + key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env) + if key is not None: + keys.add(key) + if node.value is not None: + self.visit(node.value) + + def visit_AugAssign(self, node): + key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env) + if key is not None: + keys.add(key) + self.visit(node.value) + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute) and _root_name(node.func.value) == "NODE_CLASS_MAPPINGS": + if node.func.attr == "update": + for arg in node.args: + keys.update(_literal_module_dict_string_keys(arg, env)) + for keyword in node.keywords: + if keyword.arg: + keys.add(keyword.arg) + elif node.func.attr == "setdefault" and node.args: + try: + key_value = _literal(node.args[0], env) + except UnsupportedStaticExpression: + key_value = None + if isinstance(key_value, str) and key_value: + keys.add(key_value) + self.generic_visit(node) + + MappingMutationKeyVisitor().visit(stmt) + return keys + + def _node_class_mapping_keys(tree): if _has_module_wildcard_import(tree): return set() @@ -2071,6 +2157,7 @@ def _node_class_mapping_keys(tree): and stmt.value is not None ): keys.update(_literal_module_dict_string_keys(stmt.value, env)) + keys.update(_node_class_mapping_mutation_string_keys(stmt, env)) _apply_module_stmt_to_env(stmt, env, class_bindings) return keys