From 05fa411d47112675c107ec30a8bf69cbe08ee47b Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 16:00:05 +0200 Subject: [PATCH] Fail closed on walrus bindings and invalid input sections --- .../test_generate_popular_node_signatures.py | 97 +++++++++++++++++++ tools/generate_popular_node_signatures.py | 56 +++++++++-- 2 files changed, 145 insertions(+), 8 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 832d76a..932243b 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -934,6 +934,30 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unhashable_literal_input_key_skips_repo_without_raising(self): + source = ''' +INPUTS = { + ["bad"]: ("IMAGE",), +} + + +class UnhashableLiteralInputKeyNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "UnhashableLiteralInputKeyNode": UnhashableLiteralInputKeyNode, +} +''' + result = self._extract_source(source, "unhashable-literal-input-key-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_post_class_input_reassignment_skips_static_node(self): source = ''' def build_inputs(): @@ -1057,6 +1081,28 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_with_present_non_dict_sections_skips_node(self): + source = ''' +class InvalidInputSectionsNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": [], + "optional": None, + } + + +NODE_CLASS_MAPPINGS = { + "InvalidInputSectionsNode": InvalidInputSectionsNode, +} +''' + result = self._extract_source(source, "invalid-input-sections-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_return_types_reassignment_skips_node(self): source = ''' def build_outputs(): @@ -1157,6 +1203,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_function_default_walrus_to_return_types_skips_node(self): + source = ''' +class DefaultWalrusReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + + def helper(self, x=(RETURN_TYPES := ("MASK",))): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DefaultWalrusReturnTypesNode": DefaultWalrusReturnTypesNode, +} +''' + result = self._extract_source(source, "default-walrus-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_function_default_mutation_to_return_types_skips_node(self): source = ''' class DefaultMutatedReturnTypesNode: @@ -1844,6 +1916,31 @@ Alias.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_tuple_alias_patch_after_mapping_skips_node(self): + source = ''' +class TupleAliasPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "TupleAliasPatchedNode": TupleAliasPatchedNode, +} +Alias, = (TupleAliasPatchedNode,) +Alias.RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "tuple-alias-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self): source = ''' class DefinitionTimeMutatedMappedNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 58f112c..af5fb7f 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -50,7 +50,11 @@ def _literal(node, env, allow_mutable_env=True): for key, value in zip(node.keys, node.values): if key is None: raise UnsupportedStaticExpression("dict unpacking is not supported") - result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False) + key_value = _literal(key, env, allow_mutable_env=False) + try: + result[key_value] = _literal(value, env, allow_mutable_env=False) + except TypeError as exc: + raise UnsupportedStaticExpression("unhashable dict key") from exc return result if isinstance(node, ast.Name) and node.id in env: value = env[node.id] @@ -197,17 +201,35 @@ def _named_expr_target_names(node): names = set() class NamedExprVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, child): + for decorator in child.decorator_list: + self.visit(decorator) + self.visit(child.args) + if child.returns is not None: + self.visit(child.returns) + for type_param in getattr(child, "type_params", ()): + self.visit(type_param) + def visit_FunctionDef(self, child): - return None + self._visit_function_definition_expressions(child) def visit_AsyncFunctionDef(self, child): - return None + self._visit_function_definition_expressions(child) def visit_ClassDef(self, child): - return None + for decorator in child.decorator_list: + self.visit(decorator) + for base in child.bases: + self.visit(base) + for keyword in child.keywords: + self.visit(keyword.value) + for type_param in getattr(child, "type_params", ()): + self.visit(type_param) + for stmt in child.body: + self.visit(stmt) def visit_Lambda(self, child): - return None + self.visit(child.args) def visit_NamedExpr(self, child): names.update(_target_names(child.target)) @@ -766,6 +788,19 @@ def _class_alias_sources(value, class_aliases, class_bindings): return set() +def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings): + if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)): + return + if len(target.elts) != len(value.elts): + return + for target_item, value_item in zip(target.elts, value.elts): + if not isinstance(target_item, ast.Name): + continue + sources = _class_alias_sources(value_item, class_aliases, class_bindings) + if sources: + class_aliases[target_item.id] = sources + + def _update_class_aliases(stmt, class_aliases, class_bindings): rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) for name in rebound_names: @@ -780,6 +815,8 @@ def _update_class_aliases(stmt, class_aliases, class_bindings): sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) if sources: class_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: + _update_class_alias_from_unpack(stmt.targets[0], stmt.value, class_aliases, class_bindings) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) if sources: @@ -930,9 +967,12 @@ def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_e inputs = {} required = [] for section in ("required", "optional"): - values = input_types.get(section) or {} - if not isinstance(values, dict): - return None + if section in input_types: + values = input_types[section] + if not isinstance(values, dict): + return None + else: + values = {} for name, spec in values.items(): inputs[str(name)] = normalise_input_spec(spec) if section == "required":