Invalidate static extraction on rebinding and alias mutation
This commit is contained in:
@@ -150,6 +150,59 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertIn("GoodNode", result["nodes"])
|
||||
self.assertEqual("ok", result["pack"]["status"])
|
||||
|
||||
def test_skips_undecodable_python_files_without_modified_parse(self):
|
||||
undecodable_source = b'''
|
||||
# invalid byte follows: \xff
|
||||
class UndecodableNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UndecodableNode": UndecodableNode,
|
||||
}
|
||||
'''
|
||||
good_source = '''
|
||||
class GoodUtf8Node:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GoodUtf8Node": GoodUtf8Node,
|
||||
}
|
||||
'''
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
Path(tmp, "bad.py").write_bytes(undecodable_source)
|
||||
Path(tmp, "good.py").write_text(textwrap.dedent(good_source), encoding="utf-8")
|
||||
result = extract_repo_signatures(
|
||||
Path(tmp),
|
||||
{
|
||||
"id": "undecodable-pack",
|
||||
"title": "Undecodable Pack",
|
||||
"repository": "https://github.com/example/undecodable-pack",
|
||||
"rank": 1,
|
||||
},
|
||||
)
|
||||
|
||||
self.assertNotIn("UndecodableNode", result["nodes"])
|
||||
self.assertIn("GoodUtf8Node", result["nodes"])
|
||||
self.assertEqual("ok", result["pack"]["status"])
|
||||
|
||||
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
def build_inputs():
|
||||
@@ -191,6 +244,121 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_function_binding_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def INPUTS():
|
||||
return {}
|
||||
|
||||
|
||||
class FunctionRebindNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FunctionRebindNode": FunctionRebindNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "function-rebind-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_class_binding_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class INPUTS:
|
||||
pass
|
||||
|
||||
|
||||
class ClassRebindNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ClassRebindNode": ClassRebindNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "class-rebind-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_import_binding_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
import something as INPUTS
|
||||
|
||||
|
||||
class ImportRebindNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImportRebindNode": ImportRebindNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "import-rebind-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_alias_mutation_invalidates_static_source_value(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
ALIAS = INPUTS
|
||||
ALIAS.clear()
|
||||
|
||||
|
||||
class AliasMutatedInputNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"AliasMutatedInputNode": AliasMutatedInputNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "alias-mutated-input-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_annotated_reassignment_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
def build_inputs():
|
||||
|
||||
@@ -54,6 +54,10 @@ def _literal(node, env):
|
||||
raise UnsupportedStaticExpression(type(node).__name__)
|
||||
|
||||
|
||||
def _is_mutable_static_value(value):
|
||||
return isinstance(value, (dict, list, set))
|
||||
|
||||
|
||||
def _target_names(target):
|
||||
if isinstance(target, ast.Name):
|
||||
return {target.id}
|
||||
@@ -69,6 +73,81 @@ def _target_names(target):
|
||||
return set()
|
||||
|
||||
|
||||
def _pattern_bound_names(pattern):
|
||||
names = set()
|
||||
if isinstance(pattern, ast.MatchAs):
|
||||
if pattern.name:
|
||||
names.add(pattern.name)
|
||||
if pattern.pattern is not None:
|
||||
names.update(_pattern_bound_names(pattern.pattern))
|
||||
elif isinstance(pattern, ast.MatchStar):
|
||||
if pattern.name:
|
||||
names.add(pattern.name)
|
||||
elif isinstance(pattern, ast.MatchMapping):
|
||||
if pattern.rest:
|
||||
names.add(pattern.rest)
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchSequence):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchClass):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
for subpattern in pattern.kwd_patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchOr):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
return names
|
||||
|
||||
|
||||
def _named_expr_target_names(node):
|
||||
names = set()
|
||||
|
||||
class NamedExprVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_Lambda(self, child):
|
||||
return None
|
||||
|
||||
def visit_NamedExpr(self, child):
|
||||
names.update(_target_names(child.target))
|
||||
self.visit(child.value)
|
||||
|
||||
NamedExprVisitor().visit(node)
|
||||
return names
|
||||
|
||||
|
||||
def _bound_names(stmt):
|
||||
names = set()
|
||||
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
names.add(stmt.name)
|
||||
elif isinstance(stmt, ast.Import):
|
||||
for alias in stmt.names:
|
||||
names.add(alias.asname or alias.name.split(".", 1)[0])
|
||||
elif isinstance(stmt, ast.ImportFrom):
|
||||
for alias in stmt.names:
|
||||
if alias.name != "*":
|
||||
names.add(alias.asname or alias.name)
|
||||
elif isinstance(stmt, (ast.With, ast.AsyncWith)):
|
||||
for item in stmt.items:
|
||||
if item.optional_vars is not None:
|
||||
names.update(_target_names(item.optional_vars))
|
||||
elif isinstance(stmt, ast.Match):
|
||||
for case in stmt.cases:
|
||||
names.update(_pattern_bound_names(case.pattern))
|
||||
names.update(_named_expr_target_names(stmt))
|
||||
return names
|
||||
|
||||
|
||||
def _assignment_target_names(stmt):
|
||||
if isinstance(stmt, ast.Assign):
|
||||
names = set()
|
||||
@@ -107,14 +186,23 @@ def _assigned_names_in_control_flow(stmt):
|
||||
|
||||
class AssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_Import(self, node):
|
||||
names.update(_bound_names(node))
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
names.update(_bound_names(node))
|
||||
|
||||
def visit_Assign(self, node):
|
||||
names.update(_assignment_target_names(node))
|
||||
|
||||
@@ -129,6 +217,23 @@ def _assigned_names_in_control_flow(stmt):
|
||||
|
||||
def visit_Expr(self, node):
|
||||
names.update(_mutating_call_target_names(node))
|
||||
names.update(_named_expr_target_names(node))
|
||||
|
||||
def visit_With(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_AsyncWith(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_NamedExpr(self, node):
|
||||
names.update(_target_names(node.target))
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Match(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_For(self, node):
|
||||
names.update(_assignment_target_names(node))
|
||||
@@ -149,6 +254,14 @@ def _collect_module_env(tree):
|
||||
names = _assignment_target_names(stmt)
|
||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
name = stmt.targets[0].id
|
||||
if (
|
||||
isinstance(stmt.value, ast.Name)
|
||||
and stmt.value.id in env
|
||||
and _is_mutable_static_value(env[stmt.value.id])
|
||||
):
|
||||
env.pop(stmt.value.id, None)
|
||||
env.pop(name, None)
|
||||
continue
|
||||
try:
|
||||
env[name] = _literal(stmt.value, env)
|
||||
except UnsupportedStaticExpression:
|
||||
@@ -182,10 +295,15 @@ def _collect_module_env(tree):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
for name in _mutating_call_target_names(stmt):
|
||||
env.pop(name, None)
|
||||
for name in _bound_names(stmt):
|
||||
env.pop(name, None)
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
for name in _assigned_names_in_control_flow(stmt):
|
||||
env.pop(name, None)
|
||||
continue
|
||||
for name in _bound_names(stmt):
|
||||
env.pop(name, None)
|
||||
return env
|
||||
|
||||
|
||||
@@ -235,10 +353,15 @@ def _class_attr(cls, name, env):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
if name in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
if value in (_MISSING, _INVALID):
|
||||
return None
|
||||
return value
|
||||
@@ -289,6 +412,8 @@ def _final_module_dict(tree, env, name, value_converter):
|
||||
for stmt in tree.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if not _name_is_assigned(stmt, name):
|
||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||
value = _INVALID
|
||||
continue
|
||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
try:
|
||||
@@ -320,10 +445,15 @@ def _final_module_dict(tree, env, name, value_converter):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
if name in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
if value in (_MISSING, _INVALID):
|
||||
return {}
|
||||
return value
|
||||
@@ -387,9 +517,6 @@ def _parse_python_file(path):
|
||||
try:
|
||||
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
|
||||
except SyntaxError:
|
||||
return None
|
||||
except SyntaxError:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user