From 1b567980185269f2237a4469c06cb8d057e90906 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 19:04:25 +0200 Subject: [PATCH] Fail closed on definition references and sticky mappings --- .../test_generate_popular_node_signatures.py | 111 ++++++++++++++++++ tools/generate_popular_node_signatures.py | 62 +++++++++- 2 files changed, 170 insertions(+), 3 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index a1f8fd9..4e21db8 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1169,6 +1169,52 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_default_referencing_return_types_skips_node(self): + source = ''' +class DefaultReferencesReturnTypesInputNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls, value=RETURN_TYPES): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DefaultReferencesReturnTypesInputNode": DefaultReferencesReturnTypesInputNode, +} +''' + result = self._extract_source(source, "default-references-return-types-input-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_input_types_return_annotation_referencing_input_types_skips_node(self): + source = ''' +class AnnotationReferencesInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls) -> INPUT_TYPES: + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "AnnotationReferencesInputTypesNode": AnnotationReferencesInputTypesNode, +} +''' + result = self._extract_source(source, "annotation-references-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_decorated_input_types_skips_node(self): source = ''' def replace(fn): @@ -2163,6 +2209,37 @@ NODE_CLASS_MAPPINGS = build_mappings() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_node_class_mapping_assignment_stays_invalid_after_static_reassignment(self): + source = ''' +def build_mappings(): + return {"StickyDynamicMappingNode": StickyDynamicMappingNode} + + +class StickyDynamicMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StickyDynamicMappingNode": StickyDynamicMappingNode, +} +NODE_CLASS_MAPPINGS = build_mappings() +NODE_CLASS_MAPPINGS = { + "StickyDynamicMappingNode": StickyDynamicMappingNode, +} +''' + result = self._extract_source(source, "sticky-dynamic-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_rebound_node_class_name_skips_static_mapping(self): source = ''' def build_node(): @@ -3719,6 +3796,40 @@ NODE_DISPLAY_NAME_MAPPINGS = build_displays() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_display_mapping_assignment_stays_invalid_after_static_reassignment(self): + source = ''' +def build_displays(): + return {"StickyDisplayInvalidatedNode": "Dynamic Display"} + + +class StickyDisplayInvalidatedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "StickyDisplayInvalidatedNode": StickyDisplayInvalidatedNode, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "StickyDisplayInvalidatedNode": "Stale Display", +} +NODE_DISPLAY_NAME_MAPPINGS = build_displays() +NODE_DISPLAY_NAME_MAPPINGS = { + "StickyDisplayInvalidatedNode": "Recovered Display", +} +''' + result = self._extract_source(source, "sticky-display-invalidated-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_non_string_display_mapping_value_skips_node(self): source = ''' class NonStringDisplayValueNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index a9123cb..b3a6e02 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -648,6 +648,35 @@ def _arbitrary_call_observed_names(stmt): return names +def _definition_time_referenced_names(stmt): + names = set() + + def collect_function_definition_expressions(node): + for decorator in node.decorator_list: + names.update(_referenced_names(decorator)) + names.update(_referenced_names(node.args)) + if node.returns is not None: + names.update(_referenced_names(node.returns)) + for type_param in getattr(node, "type_params", ()): + names.update(_referenced_names(type_param)) + + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + collect_function_definition_expressions(stmt) + elif isinstance(stmt, ast.ClassDef): + for decorator in stmt.decorator_list: + names.update(_referenced_names(decorator)) + for base in stmt.bases: + names.update(_referenced_names(base)) + for keyword in stmt.keywords: + names.update(_referenced_names(keyword.value)) + for type_param in getattr(stmt, "type_params", ()): + names.update(_referenced_names(type_param)) + elif isinstance(stmt, ast.Lambda): + names.update(_referenced_names(stmt.args)) + + return names + + def _assigned_names_in_control_flow(stmt): names = _mutating_call_target_names(stmt) | _arbitrary_call_observed_names(stmt) @@ -980,7 +1009,7 @@ def _update_class_attr_aliases_from_unpack(target, value, name, aliases): def _input_types_alias_sources(value, aliases): if isinstance(value, ast.Name): - return value.id == "INPUT_TYPES" or value.id in aliases + return value.id in _CLASS_SIGNATURE_ATTRS or value.id in aliases if isinstance(value, (ast.Tuple, ast.List)): return any(_input_types_alias_sources(item, aliases) for item in value.elts) return False @@ -1134,11 +1163,14 @@ def _input_types(cls, env, decorator_env): for stmt in cls.body: mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) + definition_references = _definition_time_referenced_names(stmt) + protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases input_types_invalidated = ( "INPUT_TYPES" in mutating_targets or bool(aliases.intersection(mutating_targets)) or "INPUT_TYPES" in observed_targets or bool(aliases.intersection(observed_targets)) + or bool(definition_references.intersection(protected_definition_references)) ) if input_types_invalidated: value = _INVALID @@ -1669,6 +1701,7 @@ def _update_module_dict_aliases(stmt, name, aliases): def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False): value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING + sticky_invalid = False env = {} class_bindings = {} class_aliases = {} @@ -1695,74 +1728,97 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N and value_invalidated_by_names(value, class_attr_names) ): value = _INVALID + sticky_invalid = True if _name_invalidated_by(name, _mutating_call_target_names(stmt)): value = _INVALID + sticky_invalid = True if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): value = _INVALID + sticky_invalid = True if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)): value = _INVALID + sticky_invalid = True if _module_dict_alias_invalidated(stmt, module_dict_aliases): value = _INVALID + sticky_invalid = True if isinstance(stmt, ast.Assign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue - if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + if sticky_invalid: + value = _INVALID + elif len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) except UnsupportedStaticExpression: value = _INVALID + sticky_invalid = True else: value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if isinstance(stmt, ast.AnnAssign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue - if isinstance(stmt.target, ast.Name) and stmt.value is not None: + if sticky_invalid: + value = _INVALID + elif isinstance(stmt.target, ast.Name) and stmt.value is not None: try: value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) except UnsupportedStaticExpression: value = _INVALID + sticky_invalid = True else: value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if isinstance(stmt, ast.AugAssign): if _name_is_assigned(stmt, name): value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if isinstance(stmt, ast.Delete): if name in _delete_target_names(stmt): value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if isinstance(stmt, ast.Expr): if name in _mutating_call_target_names(stmt): value = _INVALID + sticky_invalid = True if name in _bound_names(stmt): value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if isinstance(stmt, _CONTROL_FLOW_TYPES): if name in _assigned_names_in_control_flow(stmt): value = _INVALID + sticky_invalid = True if _has_wildcard_import_in_control_flow(stmt): value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if _has_wildcard_import(stmt): value = _INVALID + sticky_invalid = True advance_module_state(stmt) continue if name in _bound_names(stmt): value = _INVALID + sticky_invalid = True advance_module_state(stmt) if return_state: return value