From 7e4e85a0bd802fcc3391b8d97d73897c3546bc6c Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 16:27:48 +0200 Subject: [PATCH] Fail closed on dynamic patches and displays --- .../test_generate_popular_node_signatures.py | 63 ++++++++- tools/generate_popular_node_signatures.py | 123 +++++++++++++----- 2 files changed, 148 insertions(+), 38 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index e9d1781..85b9d10 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1081,6 +1081,39 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_shadowed_classmethod_decorator_skips_node(self): + source = ''' +def classmethod(fn): + def replacement(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + return replacement + + +class ShadowedClassmethodInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ShadowedClassmethodInputTypesNode": ShadowedClassmethodInputTypesNode, +} +''' + result = self._extract_source(source, "shadowed-classmethod-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_with_present_non_dict_sections_skips_node(self): source = ''' class InvalidInputSectionsNode: @@ -1860,6 +1893,30 @@ PatchedReturnTypesNode.RETURN_TYPES = ("MASK",) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_setattr_patch_after_mapping_skips_node(self): + source = ''' +class SetattrPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "SetattrPatchedNode": SetattrPatchedNode, +} +setattr(SetattrPatchedNode, "RETURN_TYPES", ("MASK",)) +''' + result = self._extract_source(source, "setattr-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_input_types_patch_after_mapping_skips_node(self): source = ''' def build_inputs(): @@ -2266,7 +2323,7 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) - def test_dynamic_display_mapping_reassignment_falls_back_to_node_type(self): + def test_dynamic_display_mapping_reassignment_skips_node(self): source = ''' def build_displays(): return {"DisplayInvalidatedNode": "Dynamic Display"} @@ -2294,8 +2351,8 @@ NODE_DISPLAY_NAME_MAPPINGS = build_displays() ''' result = self._extract_source(source, "dynamic-display-pack") - self.assertEqual("DisplayInvalidatedNode", result["nodes"]["DisplayInvalidatedNode"]["display"]) - self.assertEqual("ok", result["pack"]["status"]) + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) def test_input_types_with_dynamic_control_flow_is_skipped(self): source = ''' diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 60f66d4..8cb0e03 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -59,12 +59,26 @@ def _literal(node, env, allow_mutable_env=True): return result if isinstance(node, ast.Name) and node.id in env: value = env[node.id] + if value is _INVALID: + raise UnsupportedStaticExpression(f"unsupported env reference {node.id!r}") if not allow_mutable_env and _is_mutable_static_value(value): raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported") return value raise UnsupportedStaticExpression(type(node).__name__) +def _invalidate_env_name(env, name): + if name == "classmethod": + env[name] = _INVALID + else: + env.pop(name, None) + + +def _invalidate_env_names(env, names): + for name in names: + _invalidate_env_name(env, name) + + def _is_mutable_static_value(value): return isinstance(value, (dict, list, set)) @@ -108,6 +122,24 @@ def _attribute_target_base_names(target): return set() +def _setattr_delattr_target_names(node): + if not isinstance(node, ast.Call): + return set() + if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}: + return set() + if len(node.args) < 2: + return set() + attr = node.args[1] + if ( + isinstance(attr, ast.Constant) + and isinstance(attr.value, str) + and attr.value not in _CLASS_SIGNATURE_ATTRS + ): + return set() + name = _root_name(node.args[0]) + return {name} if name else set() + + def _class_attribute_mutation_target_names(stmt): names = set() @@ -161,6 +193,7 @@ def _class_attribute_mutation_target_names(stmt): names.update(_attribute_target_base_names(target)) def visit_Call(self, node): + names.update(_setattr_delattr_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_attribute_target_base_names(node.func.value)) self.generic_visit(node) @@ -328,6 +361,7 @@ def _mutating_call_target_names(stmt): self.visit(node.args) def visit_Call(self, node): + names.update(_setattr_delattr_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_target_names(node.func.value)) self.generic_visit(node) @@ -451,15 +485,14 @@ def _invalidate_class_bindings(class_bindings, names): def _apply_module_stmt_to_env(stmt, env, class_bindings=None): names = _mutating_call_target_names(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) if isinstance(stmt, ast.ClassDef): if class_bindings is not None: if stmt.decorator_list: class_bindings.pop(stmt.name, None) else: class_bindings[stmt.name] = (stmt, dict(env)) - env.pop(stmt.name, None) + _invalidate_env_name(env, stmt.name) return if isinstance(stmt, ast.Assign): names = _assignment_target_names(stmt) @@ -469,7 +502,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): subscript_root = _mutable_env_subscript_root(stmt.value, env) if subscript_root is not None: env.pop(subscript_root, None) - env.pop(name, None) + _invalidate_env_name(env, name) return if ( isinstance(stmt.value, ast.Name) @@ -477,15 +510,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): and _is_mutable_static_value(env[stmt.value.id]) ): env.pop(stmt.value.id, None) - env.pop(name, None) + _invalidate_env_name(env, name) return try: env[name] = _literal(stmt.value, env) except UnsupportedStaticExpression: - env.pop(name, None) + _invalidate_env_name(env, name) else: - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if isinstance(stmt, ast.AnnAssign): names = _assignment_target_names(stmt) @@ -497,7 +529,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): subscript_root = _mutable_env_subscript_root(stmt.value, env) if subscript_root is not None: env.pop(subscript_root, None) - env.pop(name, None) + _invalidate_env_name(env, name) return if ( isinstance(stmt.value, ast.Name) @@ -505,33 +537,29 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): and _is_mutable_static_value(env[stmt.value.id]) ): env.pop(stmt.value.id, None) - env.pop(name, None) + _invalidate_env_name(env, name) return try: env[name] = _literal(stmt.value, env) except UnsupportedStaticExpression: - env.pop(name, None) + _invalidate_env_name(env, name) else: - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if isinstance(stmt, ast.AugAssign): names = _assignment_target_names(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if isinstance(stmt, ast.Delete): names = _delete_target_names(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if isinstance(stmt, ast.Expr): names = _bound_names(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if isinstance(stmt, _CONTROL_FLOW_TYPES): if _has_wildcard_import_in_control_flow(stmt): @@ -541,8 +569,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): return names = _assigned_names_in_control_flow(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) return if _has_wildcard_import(stmt): env.clear() @@ -551,8 +578,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): return names = _bound_names(stmt) _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) + _invalidate_env_names(env, names) def _collect_module_env(tree, class_bindings=None): @@ -586,8 +612,13 @@ def _mutable_env_subscript_root(node, env): return None -def _input_types_decorators_are_supported(decorators): - return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators) +def _input_types_decorators_are_supported(decorators, classmethod_shadowed): + for decorator in decorators: + if not isinstance(decorator, ast.Name) or decorator.id != "classmethod": + return False + if classmethod_shadowed: + return False + return True def _class_attr_alias_sources(value, name, aliases): @@ -721,13 +752,14 @@ def _class_attr(cls, name, env): return value -def _input_types(cls, env): +def _input_types(cls, env, decorator_env): value = _MISSING + classmethod_shadowed = "classmethod" in decorator_env for stmt in cls.body: if "INPUT_TYPES" in _mutating_call_target_names(stmt): value = _INVALID if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": - if not _input_types_decorators_are_supported(stmt.decorator_list): + if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): value = _INVALID continue if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): @@ -743,6 +775,13 @@ def _input_types(cls, env): if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": value = _INVALID continue + if "classmethod" in ( + _assignment_target_names(stmt) + | _delete_target_names(stmt) + | _bound_names(stmt) + | _mutating_call_target_names(stmt) + ): + classmethod_shadowed = True if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)): if "INPUT_TYPES" in _assignment_target_names(stmt): value = _INVALID @@ -894,7 +933,16 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases class_attribute_aliases[stmt.target.id] = sources -def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None): +def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases): + names = _expanded_class_attribute_names( + _class_attribute_mutation_target_names(stmt), + class_aliases, + ) + names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases)) + return names + + +def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False): value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING env = {} @@ -905,18 +953,14 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N def advance_module_state(stmt): _invalidate_class_bindings( class_bindings, - _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases), + _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases), ) _apply_module_stmt_to_env(stmt, env, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings) _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) for stmt in tree.body: - class_attr_names = _expanded_class_attribute_names( - _class_attribute_mutation_target_names(stmt), - class_aliases, - ) - class_attr_names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases)) + class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) if ( value not in (_MISSING, _INVALID) and class_attr_names @@ -986,6 +1030,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N if name in _bound_names(stmt): value = _INVALID advance_module_state(stmt) + if return_state: + return value if value in (_MISSING, _INVALID): return {} return value @@ -1022,12 +1068,17 @@ def _display_mappings(tree): tree, "NODE_DISPLAY_NAME_MAPPINGS", lambda value, env, _class_bindings: _literal(value, env), + return_state=True, ) + if displays is _MISSING: + return {} + if displays is _INVALID: + return _INVALID return {str(k): str(v) for k, v in displays.items()} def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_env): - input_types = _input_types(cls, input_env) + input_types = _input_types(cls, input_env, class_env) return_types = _class_attr(cls, "RETURN_TYPES", class_env) return_names = _class_attr(cls, "RETURN_NAMES", class_env) if return_types is _INVALID or return_names is _INVALID: @@ -1097,6 +1148,8 @@ def extract_repo_signatures(repo_dir, pack_meta): env = _collect_module_env(tree) mappings = _node_class_mappings(tree) displays = _display_mappings(tree) + if displays is _INVALID: + continue for node_type, binding in sorted(mappings.items()): cls, class_env = binding sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)