From fae0c312bc644872b3de9aa414ce431b4c563814 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 2 Jul 2026 12:50:06 +0200 Subject: [PATCH] Invalidate static extraction on rebinding and alias mutation --- .../test_generate_popular_node_signatures.py | 168 ++++++++++++++++++ tools/generate_popular_node_signatures.py | 141 ++++++++++++++- 2 files changed, 302 insertions(+), 7 deletions(-) diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 35c1bb2..8f6f7cc 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -150,6 +150,59 @@ NODE_CLASS_MAPPINGS = { self.assertIn("GoodNode", result["nodes"]) self.assertEqual("ok", result["pack"]["status"]) + def test_skips_undecodable_python_files_without_modified_parse(self): + undecodable_source = b''' +# invalid byte follows: \xff +class UndecodableNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "UndecodableNode": UndecodableNode, +} +''' + good_source = ''' +class GoodUtf8Node: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GoodUtf8Node": GoodUtf8Node, +} +''' + with tempfile.TemporaryDirectory() as tmp: + Path(tmp, "bad.py").write_bytes(undecodable_source) + Path(tmp, "good.py").write_text(textwrap.dedent(good_source), encoding="utf-8") + result = extract_repo_signatures( + Path(tmp), + { + "id": "undecodable-pack", + "title": "Undecodable Pack", + "repository": "https://github.com/example/undecodable-pack", + "rank": 1, + }, + ) + + self.assertNotIn("UndecodableNode", result["nodes"]) + self.assertIn("GoodUtf8Node", result["nodes"]) + self.assertEqual("ok", result["pack"]["status"]) + def test_unsupported_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): @@ -191,6 +244,121 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_function_binding_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +def INPUTS(): + return {} + + +class FunctionRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "FunctionRebindNode": FunctionRebindNode, +} +''' + result = self._extract_source(source, "function-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_class_binding_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} + + +class INPUTS: + pass + + +class ClassRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ClassRebindNode": ClassRebindNode, +} +''' + result = self._extract_source(source, "class-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_import_binding_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +import something as INPUTS + + +class ImportRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ImportRebindNode": ImportRebindNode, +} +''' + result = self._extract_source(source, "import-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_alias_mutation_invalidates_static_source_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +ALIAS = INPUTS +ALIAS.clear() + + +class AliasMutatedInputNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "AliasMutatedInputNode": AliasMutatedInputNode, +} +''' + result = self._extract_source(source, "alias-mutated-input-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 6dbf6df..417ff07 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -54,6 +54,10 @@ def _literal(node, env): raise UnsupportedStaticExpression(type(node).__name__) +def _is_mutable_static_value(value): + return isinstance(value, (dict, list, set)) + + def _target_names(target): if isinstance(target, ast.Name): return {target.id} @@ -69,6 +73,81 @@ def _target_names(target): return set() +def _pattern_bound_names(pattern): + names = set() + if isinstance(pattern, ast.MatchAs): + if pattern.name: + names.add(pattern.name) + if pattern.pattern is not None: + names.update(_pattern_bound_names(pattern.pattern)) + elif isinstance(pattern, ast.MatchStar): + if pattern.name: + names.add(pattern.name) + elif isinstance(pattern, ast.MatchMapping): + if pattern.rest: + names.add(pattern.rest) + for subpattern in pattern.patterns: + names.update(_pattern_bound_names(subpattern)) + elif isinstance(pattern, ast.MatchSequence): + for subpattern in pattern.patterns: + names.update(_pattern_bound_names(subpattern)) + elif isinstance(pattern, ast.MatchClass): + for subpattern in pattern.patterns: + names.update(_pattern_bound_names(subpattern)) + for subpattern in pattern.kwd_patterns: + names.update(_pattern_bound_names(subpattern)) + elif isinstance(pattern, ast.MatchOr): + for subpattern in pattern.patterns: + names.update(_pattern_bound_names(subpattern)) + return names + + +def _named_expr_target_names(node): + names = set() + + class NamedExprVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, child): + return None + + def visit_AsyncFunctionDef(self, child): + return None + + def visit_ClassDef(self, child): + return None + + def visit_Lambda(self, child): + return None + + def visit_NamedExpr(self, child): + names.update(_target_names(child.target)) + self.visit(child.value) + + NamedExprVisitor().visit(node) + return names + + +def _bound_names(stmt): + names = set() + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + names.add(stmt.name) + elif isinstance(stmt, ast.Import): + for alias in stmt.names: + names.add(alias.asname or alias.name.split(".", 1)[0]) + elif isinstance(stmt, ast.ImportFrom): + for alias in stmt.names: + if alias.name != "*": + names.add(alias.asname or alias.name) + elif isinstance(stmt, (ast.With, ast.AsyncWith)): + for item in stmt.items: + if item.optional_vars is not None: + names.update(_target_names(item.optional_vars)) + elif isinstance(stmt, ast.Match): + for case in stmt.cases: + names.update(_pattern_bound_names(case.pattern)) + names.update(_named_expr_target_names(stmt)) + return names + + def _assignment_target_names(stmt): if isinstance(stmt, ast.Assign): names = set() @@ -107,14 +186,23 @@ def _assigned_names_in_control_flow(stmt): class AssignmentVisitor(ast.NodeVisitor): def visit_FunctionDef(self, node): + names.add(node.name) return None def visit_AsyncFunctionDef(self, node): + names.add(node.name) return None def visit_ClassDef(self, node): + names.add(node.name) return None + def visit_Import(self, node): + names.update(_bound_names(node)) + + def visit_ImportFrom(self, node): + names.update(_bound_names(node)) + def visit_Assign(self, node): names.update(_assignment_target_names(node)) @@ -129,6 +217,23 @@ def _assigned_names_in_control_flow(stmt): def visit_Expr(self, node): names.update(_mutating_call_target_names(node)) + names.update(_named_expr_target_names(node)) + + def visit_With(self, node): + names.update(_bound_names(node)) + self.generic_visit(node) + + def visit_AsyncWith(self, node): + names.update(_bound_names(node)) + self.generic_visit(node) + + def visit_NamedExpr(self, node): + names.update(_target_names(node.target)) + self.visit(node.value) + + def visit_Match(self, node): + names.update(_bound_names(node)) + self.generic_visit(node) def visit_For(self, node): names.update(_assignment_target_names(node)) @@ -149,6 +254,14 @@ def _collect_module_env(tree): names = _assignment_target_names(stmt) if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): name = stmt.targets[0].id + if ( + isinstance(stmt.value, ast.Name) + and stmt.value.id in env + and _is_mutable_static_value(env[stmt.value.id]) + ): + env.pop(stmt.value.id, None) + env.pop(name, None) + continue try: env[name] = _literal(stmt.value, env) except UnsupportedStaticExpression: @@ -182,10 +295,15 @@ def _collect_module_env(tree): if isinstance(stmt, ast.Expr): for name in _mutating_call_target_names(stmt): env.pop(name, None) + for name in _bound_names(stmt): + env.pop(name, None) continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): for name in _assigned_names_in_control_flow(stmt): env.pop(name, None) + continue + for name in _bound_names(stmt): + env.pop(name, None) return env @@ -235,10 +353,15 @@ def _class_attr(cls, name, env): if isinstance(stmt, ast.Expr): if name in _mutating_call_target_names(stmt): value = _INVALID + if name in _bound_names(stmt): + value = _INVALID continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if name in _assigned_names_in_control_flow(stmt): value = _INVALID + continue + if name in _bound_names(stmt): + value = _INVALID if value in (_MISSING, _INVALID): return None return value @@ -289,6 +412,8 @@ def _final_module_dict(tree, env, name, value_converter): for stmt in tree.body: if isinstance(stmt, ast.Assign): if not _name_is_assigned(stmt, name): + if isinstance(stmt.value, ast.Name) and stmt.value.id == name: + value = _INVALID continue if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: @@ -320,10 +445,15 @@ def _final_module_dict(tree, env, name, value_converter): if isinstance(stmt, ast.Expr): if name in _mutating_call_target_names(stmt): value = _INVALID + if name in _bound_names(stmt): + value = _INVALID continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if name in _assigned_names_in_control_flow(stmt): value = _INVALID + continue + if name in _bound_names(stmt): + value = _INVALID if value in (_MISSING, _INVALID): return {} return value @@ -387,10 +517,7 @@ def _parse_python_file(path): try: return ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) except UnicodeDecodeError: - try: - return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path)) - except SyntaxError: - return None + return None except SyntaxError: return None