Invalidate static extraction on rebinding and alias mutation

This commit is contained in:
2026-07-02 12:50:06 +02:00
parent 21a29b8846
commit fae0c312bc
2 changed files with 302 additions and 7 deletions
@@ -150,6 +150,59 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("GoodNode", result["nodes"]) self.assertIn("GoodNode", result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_skips_undecodable_python_files_without_modified_parse(self):
undecodable_source = b'''
# invalid byte follows: \xff
class UndecodableNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"UndecodableNode": UndecodableNode,
}
'''
good_source = '''
class GoodUtf8Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"GoodUtf8Node": GoodUtf8Node,
}
'''
with tempfile.TemporaryDirectory() as tmp:
Path(tmp, "bad.py").write_bytes(undecodable_source)
Path(tmp, "good.py").write_text(textwrap.dedent(good_source), encoding="utf-8")
result = extract_repo_signatures(
Path(tmp),
{
"id": "undecodable-pack",
"title": "Undecodable Pack",
"repository": "https://github.com/example/undecodable-pack",
"rank": 1,
},
)
self.assertNotIn("UndecodableNode", result["nodes"])
self.assertIn("GoodUtf8Node", result["nodes"])
self.assertEqual("ok", result["pack"]["status"])
def test_unsupported_reassignment_invalidates_static_env_value(self): def test_unsupported_reassignment_invalidates_static_env_value(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
@@ -191,6 +244,121 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_function_binding_invalidates_static_env_value(self):
source = '''
INPUTS = {
"required": {
"image": ("IMAGE",),
},
}
def INPUTS():
return {}
class FunctionRebindNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return INPUTS
NODE_CLASS_MAPPINGS = {
"FunctionRebindNode": FunctionRebindNode,
}
'''
result = self._extract_source(source, "function-rebind-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_class_binding_invalidates_static_env_value(self):
source = '''
INPUTS = {
"required": {
"image": ("IMAGE",),
},
}
class INPUTS:
pass
class ClassRebindNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return INPUTS
NODE_CLASS_MAPPINGS = {
"ClassRebindNode": ClassRebindNode,
}
'''
result = self._extract_source(source, "class-rebind-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_import_binding_invalidates_static_env_value(self):
source = '''
INPUTS = {
"required": {
"image": ("IMAGE",),
},
}
import something as INPUTS
class ImportRebindNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return INPUTS
NODE_CLASS_MAPPINGS = {
"ImportRebindNode": ImportRebindNode,
}
'''
result = self._extract_source(source, "import-rebind-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_alias_mutation_invalidates_static_source_value(self):
source = '''
INPUTS = {
"required": {
"image": ("IMAGE",),
},
}
ALIAS = INPUTS
ALIAS.clear()
class AliasMutatedInputNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return INPUTS
NODE_CLASS_MAPPINGS = {
"AliasMutatedInputNode": AliasMutatedInputNode,
}
'''
result = self._extract_source(source, "alias-mutated-input-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_annotated_reassignment_invalidates_static_env_value(self): def test_annotated_reassignment_invalidates_static_env_value(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
+133 -6
View File
@@ -54,6 +54,10 @@ def _literal(node, env):
raise UnsupportedStaticExpression(type(node).__name__) raise UnsupportedStaticExpression(type(node).__name__)
def _is_mutable_static_value(value):
return isinstance(value, (dict, list, set))
def _target_names(target): def _target_names(target):
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
return {target.id} return {target.id}
@@ -69,6 +73,81 @@ def _target_names(target):
return set() return set()
def _pattern_bound_names(pattern):
names = set()
if isinstance(pattern, ast.MatchAs):
if pattern.name:
names.add(pattern.name)
if pattern.pattern is not None:
names.update(_pattern_bound_names(pattern.pattern))
elif isinstance(pattern, ast.MatchStar):
if pattern.name:
names.add(pattern.name)
elif isinstance(pattern, ast.MatchMapping):
if pattern.rest:
names.add(pattern.rest)
for subpattern in pattern.patterns:
names.update(_pattern_bound_names(subpattern))
elif isinstance(pattern, ast.MatchSequence):
for subpattern in pattern.patterns:
names.update(_pattern_bound_names(subpattern))
elif isinstance(pattern, ast.MatchClass):
for subpattern in pattern.patterns:
names.update(_pattern_bound_names(subpattern))
for subpattern in pattern.kwd_patterns:
names.update(_pattern_bound_names(subpattern))
elif isinstance(pattern, ast.MatchOr):
for subpattern in pattern.patterns:
names.update(_pattern_bound_names(subpattern))
return names
def _named_expr_target_names(node):
names = set()
class NamedExprVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, child):
return None
def visit_AsyncFunctionDef(self, child):
return None
def visit_ClassDef(self, child):
return None
def visit_Lambda(self, child):
return None
def visit_NamedExpr(self, child):
names.update(_target_names(child.target))
self.visit(child.value)
NamedExprVisitor().visit(node)
return names
def _bound_names(stmt):
names = set()
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
names.add(stmt.name)
elif isinstance(stmt, ast.Import):
for alias in stmt.names:
names.add(alias.asname or alias.name.split(".", 1)[0])
elif isinstance(stmt, ast.ImportFrom):
for alias in stmt.names:
if alias.name != "*":
names.add(alias.asname or alias.name)
elif isinstance(stmt, (ast.With, ast.AsyncWith)):
for item in stmt.items:
if item.optional_vars is not None:
names.update(_target_names(item.optional_vars))
elif isinstance(stmt, ast.Match):
for case in stmt.cases:
names.update(_pattern_bound_names(case.pattern))
names.update(_named_expr_target_names(stmt))
return names
def _assignment_target_names(stmt): def _assignment_target_names(stmt):
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
names = set() names = set()
@@ -107,14 +186,23 @@ def _assigned_names_in_control_flow(stmt):
class AssignmentVisitor(ast.NodeVisitor): class AssignmentVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
names.add(node.name)
return None return None
def visit_AsyncFunctionDef(self, node): def visit_AsyncFunctionDef(self, node):
names.add(node.name)
return None return None
def visit_ClassDef(self, node): def visit_ClassDef(self, node):
names.add(node.name)
return None return None
def visit_Import(self, node):
names.update(_bound_names(node))
def visit_ImportFrom(self, node):
names.update(_bound_names(node))
def visit_Assign(self, node): def visit_Assign(self, node):
names.update(_assignment_target_names(node)) names.update(_assignment_target_names(node))
@@ -129,6 +217,23 @@ def _assigned_names_in_control_flow(stmt):
def visit_Expr(self, node): def visit_Expr(self, node):
names.update(_mutating_call_target_names(node)) names.update(_mutating_call_target_names(node))
names.update(_named_expr_target_names(node))
def visit_With(self, node):
names.update(_bound_names(node))
self.generic_visit(node)
def visit_AsyncWith(self, node):
names.update(_bound_names(node))
self.generic_visit(node)
def visit_NamedExpr(self, node):
names.update(_target_names(node.target))
self.visit(node.value)
def visit_Match(self, node):
names.update(_bound_names(node))
self.generic_visit(node)
def visit_For(self, node): def visit_For(self, node):
names.update(_assignment_target_names(node)) names.update(_assignment_target_names(node))
@@ -149,6 +254,14 @@ def _collect_module_env(tree):
names = _assignment_target_names(stmt) names = _assignment_target_names(stmt)
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
name = stmt.targets[0].id name = stmt.targets[0].id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
continue
try: try:
env[name] = _literal(stmt.value, env) env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
@@ -182,10 +295,15 @@ def _collect_module_env(tree):
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
for name in _mutating_call_target_names(stmt): for name in _mutating_call_target_names(stmt):
env.pop(name, None) env.pop(name, None)
for name in _bound_names(stmt):
env.pop(name, None)
continue continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
for name in _assigned_names_in_control_flow(stmt): for name in _assigned_names_in_control_flow(stmt):
env.pop(name, None) env.pop(name, None)
continue
for name in _bound_names(stmt):
env.pop(name, None)
return env return env
@@ -235,10 +353,15 @@ def _class_attr(cls, name, env):
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
if name in _bound_names(stmt):
value = _INVALID
continue continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if name in _assigned_names_in_control_flow(stmt): if name in _assigned_names_in_control_flow(stmt):
value = _INVALID value = _INVALID
continue
if name in _bound_names(stmt):
value = _INVALID
if value in (_MISSING, _INVALID): if value in (_MISSING, _INVALID):
return None return None
return value return value
@@ -289,6 +412,8 @@ def _final_module_dict(tree, env, name, value_converter):
for stmt in tree.body: for stmt in tree.body:
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
if not _name_is_assigned(stmt, name): if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID
continue continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
try: try:
@@ -320,10 +445,15 @@ def _final_module_dict(tree, env, name, value_converter):
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
if name in _bound_names(stmt):
value = _INVALID
continue continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if name in _assigned_names_in_control_flow(stmt): if name in _assigned_names_in_control_flow(stmt):
value = _INVALID value = _INVALID
continue
if name in _bound_names(stmt):
value = _INVALID
if value in (_MISSING, _INVALID): if value in (_MISSING, _INVALID):
return {} return {}
return value return value
@@ -387,9 +517,6 @@ def _parse_python_file(path):
try: try:
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
except UnicodeDecodeError: except UnicodeDecodeError:
try:
return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
except SyntaxError:
return None return None
except SyntaxError: except SyntaxError:
return None return None