Invalidate static extraction on rebinding and alias mutation
This commit is contained in:
@@ -150,6 +150,59 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertIn("GoodNode", result["nodes"])
|
self.assertIn("GoodNode", result["nodes"])
|
||||||
self.assertEqual("ok", result["pack"]["status"])
|
self.assertEqual("ok", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_skips_undecodable_python_files_without_modified_parse(self):
|
||||||
|
undecodable_source = b'''
|
||||||
|
# invalid byte follows: \xff
|
||||||
|
class UndecodableNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"UndecodableNode": UndecodableNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
good_source = '''
|
||||||
|
class GoodUtf8Node:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"GoodUtf8Node": GoodUtf8Node,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
Path(tmp, "bad.py").write_bytes(undecodable_source)
|
||||||
|
Path(tmp, "good.py").write_text(textwrap.dedent(good_source), encoding="utf-8")
|
||||||
|
result = extract_repo_signatures(
|
||||||
|
Path(tmp),
|
||||||
|
{
|
||||||
|
"id": "undecodable-pack",
|
||||||
|
"title": "Undecodable Pack",
|
||||||
|
"repository": "https://github.com/example/undecodable-pack",
|
||||||
|
"rank": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertNotIn("UndecodableNode", result["nodes"])
|
||||||
|
self.assertIn("GoodUtf8Node", result["nodes"])
|
||||||
|
self.assertEqual("ok", result["pack"]["status"])
|
||||||
|
|
||||||
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_inputs():
|
def build_inputs():
|
||||||
@@ -191,6 +244,121 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_function_binding_invalidates_static_env_value(self):
|
||||||
|
source = '''
|
||||||
|
INPUTS = {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def INPUTS():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionRebindNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"FunctionRebindNode": FunctionRebindNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "function-rebind-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_class_binding_invalidates_static_env_value(self):
|
||||||
|
source = '''
|
||||||
|
INPUTS = {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class INPUTS:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClassRebindNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ClassRebindNode": ClassRebindNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "class-rebind-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_import_binding_invalidates_static_env_value(self):
|
||||||
|
source = '''
|
||||||
|
INPUTS = {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
import something as INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
class ImportRebindNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImportRebindNode": ImportRebindNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "import-rebind-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_alias_mutation_invalidates_static_source_value(self):
|
||||||
|
source = '''
|
||||||
|
INPUTS = {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ALIAS = INPUTS
|
||||||
|
ALIAS.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class AliasMutatedInputNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"AliasMutatedInputNode": AliasMutatedInputNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "alias-mutated-input-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_annotated_reassignment_invalidates_static_env_value(self):
|
def test_annotated_reassignment_invalidates_static_env_value(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_inputs():
|
def build_inputs():
|
||||||
|
|||||||
@@ -54,6 +54,10 @@ def _literal(node, env):
|
|||||||
raise UnsupportedStaticExpression(type(node).__name__)
|
raise UnsupportedStaticExpression(type(node).__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_mutable_static_value(value):
|
||||||
|
return isinstance(value, (dict, list, set))
|
||||||
|
|
||||||
|
|
||||||
def _target_names(target):
|
def _target_names(target):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
return {target.id}
|
return {target.id}
|
||||||
@@ -69,6 +73,81 @@ def _target_names(target):
|
|||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_bound_names(pattern):
|
||||||
|
names = set()
|
||||||
|
if isinstance(pattern, ast.MatchAs):
|
||||||
|
if pattern.name:
|
||||||
|
names.add(pattern.name)
|
||||||
|
if pattern.pattern is not None:
|
||||||
|
names.update(_pattern_bound_names(pattern.pattern))
|
||||||
|
elif isinstance(pattern, ast.MatchStar):
|
||||||
|
if pattern.name:
|
||||||
|
names.add(pattern.name)
|
||||||
|
elif isinstance(pattern, ast.MatchMapping):
|
||||||
|
if pattern.rest:
|
||||||
|
names.add(pattern.rest)
|
||||||
|
for subpattern in pattern.patterns:
|
||||||
|
names.update(_pattern_bound_names(subpattern))
|
||||||
|
elif isinstance(pattern, ast.MatchSequence):
|
||||||
|
for subpattern in pattern.patterns:
|
||||||
|
names.update(_pattern_bound_names(subpattern))
|
||||||
|
elif isinstance(pattern, ast.MatchClass):
|
||||||
|
for subpattern in pattern.patterns:
|
||||||
|
names.update(_pattern_bound_names(subpattern))
|
||||||
|
for subpattern in pattern.kwd_patterns:
|
||||||
|
names.update(_pattern_bound_names(subpattern))
|
||||||
|
elif isinstance(pattern, ast.MatchOr):
|
||||||
|
for subpattern in pattern.patterns:
|
||||||
|
names.update(_pattern_bound_names(subpattern))
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _named_expr_target_names(node):
|
||||||
|
names = set()
|
||||||
|
|
||||||
|
class NamedExprVisitor(ast.NodeVisitor):
|
||||||
|
def visit_FunctionDef(self, child):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def visit_AsyncFunctionDef(self, child):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def visit_ClassDef(self, child):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def visit_Lambda(self, child):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def visit_NamedExpr(self, child):
|
||||||
|
names.update(_target_names(child.target))
|
||||||
|
self.visit(child.value)
|
||||||
|
|
||||||
|
NamedExprVisitor().visit(node)
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _bound_names(stmt):
|
||||||
|
names = set()
|
||||||
|
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||||
|
names.add(stmt.name)
|
||||||
|
elif isinstance(stmt, ast.Import):
|
||||||
|
for alias in stmt.names:
|
||||||
|
names.add(alias.asname or alias.name.split(".", 1)[0])
|
||||||
|
elif isinstance(stmt, ast.ImportFrom):
|
||||||
|
for alias in stmt.names:
|
||||||
|
if alias.name != "*":
|
||||||
|
names.add(alias.asname or alias.name)
|
||||||
|
elif isinstance(stmt, (ast.With, ast.AsyncWith)):
|
||||||
|
for item in stmt.items:
|
||||||
|
if item.optional_vars is not None:
|
||||||
|
names.update(_target_names(item.optional_vars))
|
||||||
|
elif isinstance(stmt, ast.Match):
|
||||||
|
for case in stmt.cases:
|
||||||
|
names.update(_pattern_bound_names(case.pattern))
|
||||||
|
names.update(_named_expr_target_names(stmt))
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
def _assignment_target_names(stmt):
|
def _assignment_target_names(stmt):
|
||||||
if isinstance(stmt, ast.Assign):
|
if isinstance(stmt, ast.Assign):
|
||||||
names = set()
|
names = set()
|
||||||
@@ -107,14 +186,23 @@ def _assigned_names_in_control_flow(stmt):
|
|||||||
|
|
||||||
class AssignmentVisitor(ast.NodeVisitor):
|
class AssignmentVisitor(ast.NodeVisitor):
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
|
names.add(node.name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def visit_AsyncFunctionDef(self, node):
|
def visit_AsyncFunctionDef(self, node):
|
||||||
|
names.add(node.name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def visit_ClassDef(self, node):
|
def visit_ClassDef(self, node):
|
||||||
|
names.add(node.name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def visit_Import(self, node):
|
||||||
|
names.update(_bound_names(node))
|
||||||
|
|
||||||
|
def visit_ImportFrom(self, node):
|
||||||
|
names.update(_bound_names(node))
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
names.update(_assignment_target_names(node))
|
names.update(_assignment_target_names(node))
|
||||||
|
|
||||||
@@ -129,6 +217,23 @@ def _assigned_names_in_control_flow(stmt):
|
|||||||
|
|
||||||
def visit_Expr(self, node):
|
def visit_Expr(self, node):
|
||||||
names.update(_mutating_call_target_names(node))
|
names.update(_mutating_call_target_names(node))
|
||||||
|
names.update(_named_expr_target_names(node))
|
||||||
|
|
||||||
|
def visit_With(self, node):
|
||||||
|
names.update(_bound_names(node))
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_AsyncWith(self, node):
|
||||||
|
names.update(_bound_names(node))
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_NamedExpr(self, node):
|
||||||
|
names.update(_target_names(node.target))
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_Match(self, node):
|
||||||
|
names.update(_bound_names(node))
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_For(self, node):
|
||||||
names.update(_assignment_target_names(node))
|
names.update(_assignment_target_names(node))
|
||||||
@@ -149,6 +254,14 @@ def _collect_module_env(tree):
|
|||||||
names = _assignment_target_names(stmt)
|
names = _assignment_target_names(stmt)
|
||||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
name = stmt.targets[0].id
|
name = stmt.targets[0].id
|
||||||
|
if (
|
||||||
|
isinstance(stmt.value, ast.Name)
|
||||||
|
and stmt.value.id in env
|
||||||
|
and _is_mutable_static_value(env[stmt.value.id])
|
||||||
|
):
|
||||||
|
env.pop(stmt.value.id, None)
|
||||||
|
env.pop(name, None)
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
env[name] = _literal(stmt.value, env)
|
env[name] = _literal(stmt.value, env)
|
||||||
except UnsupportedStaticExpression:
|
except UnsupportedStaticExpression:
|
||||||
@@ -182,10 +295,15 @@ def _collect_module_env(tree):
|
|||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
for name in _mutating_call_target_names(stmt):
|
for name in _mutating_call_target_names(stmt):
|
||||||
env.pop(name, None)
|
env.pop(name, None)
|
||||||
|
for name in _bound_names(stmt):
|
||||||
|
env.pop(name, None)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||||
for name in _assigned_names_in_control_flow(stmt):
|
for name in _assigned_names_in_control_flow(stmt):
|
||||||
env.pop(name, None)
|
env.pop(name, None)
|
||||||
|
continue
|
||||||
|
for name in _bound_names(stmt):
|
||||||
|
env.pop(name, None)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@@ -235,10 +353,15 @@ def _class_attr(cls, name, env):
|
|||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
if name in _mutating_call_target_names(stmt):
|
if name in _mutating_call_target_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
if name in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||||
if name in _assigned_names_in_control_flow(stmt):
|
if name in _assigned_names_in_control_flow(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
if name in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
if value in (_MISSING, _INVALID):
|
if value in (_MISSING, _INVALID):
|
||||||
return None
|
return None
|
||||||
return value
|
return value
|
||||||
@@ -289,6 +412,8 @@ def _final_module_dict(tree, env, name, value_converter):
|
|||||||
for stmt in tree.body:
|
for stmt in tree.body:
|
||||||
if isinstance(stmt, ast.Assign):
|
if isinstance(stmt, ast.Assign):
|
||||||
if not _name_is_assigned(stmt, name):
|
if not _name_is_assigned(stmt, name):
|
||||||
|
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||||
|
value = _INVALID
|
||||||
continue
|
continue
|
||||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
try:
|
try:
|
||||||
@@ -320,10 +445,15 @@ def _final_module_dict(tree, env, name, value_converter):
|
|||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
if name in _mutating_call_target_names(stmt):
|
if name in _mutating_call_target_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
if name in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||||
if name in _assigned_names_in_control_flow(stmt):
|
if name in _assigned_names_in_control_flow(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
if name in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
if value in (_MISSING, _INVALID):
|
if value in (_MISSING, _INVALID):
|
||||||
return {}
|
return {}
|
||||||
return value
|
return value
|
||||||
@@ -387,9 +517,6 @@ def _parse_python_file(path):
|
|||||||
try:
|
try:
|
||||||
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
try:
|
|
||||||
return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
|
|
||||||
except SyntaxError:
|
|
||||||
return None
|
return None
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user