Invalidate static extraction on rebinding and alias mutation
This commit is contained in:
@@ -54,6 +54,10 @@ def _literal(node, env):
|
||||
raise UnsupportedStaticExpression(type(node).__name__)
|
||||
|
||||
|
||||
def _is_mutable_static_value(value):
|
||||
return isinstance(value, (dict, list, set))
|
||||
|
||||
|
||||
def _target_names(target):
|
||||
if isinstance(target, ast.Name):
|
||||
return {target.id}
|
||||
@@ -69,6 +73,81 @@ def _target_names(target):
|
||||
return set()
|
||||
|
||||
|
||||
def _pattern_bound_names(pattern):
|
||||
names = set()
|
||||
if isinstance(pattern, ast.MatchAs):
|
||||
if pattern.name:
|
||||
names.add(pattern.name)
|
||||
if pattern.pattern is not None:
|
||||
names.update(_pattern_bound_names(pattern.pattern))
|
||||
elif isinstance(pattern, ast.MatchStar):
|
||||
if pattern.name:
|
||||
names.add(pattern.name)
|
||||
elif isinstance(pattern, ast.MatchMapping):
|
||||
if pattern.rest:
|
||||
names.add(pattern.rest)
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchSequence):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchClass):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
for subpattern in pattern.kwd_patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
elif isinstance(pattern, ast.MatchOr):
|
||||
for subpattern in pattern.patterns:
|
||||
names.update(_pattern_bound_names(subpattern))
|
||||
return names
|
||||
|
||||
|
||||
def _named_expr_target_names(node):
|
||||
names = set()
|
||||
|
||||
class NamedExprVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, child):
|
||||
return None
|
||||
|
||||
def visit_Lambda(self, child):
|
||||
return None
|
||||
|
||||
def visit_NamedExpr(self, child):
|
||||
names.update(_target_names(child.target))
|
||||
self.visit(child.value)
|
||||
|
||||
NamedExprVisitor().visit(node)
|
||||
return names
|
||||
|
||||
|
||||
def _bound_names(stmt):
|
||||
names = set()
|
||||
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
names.add(stmt.name)
|
||||
elif isinstance(stmt, ast.Import):
|
||||
for alias in stmt.names:
|
||||
names.add(alias.asname or alias.name.split(".", 1)[0])
|
||||
elif isinstance(stmt, ast.ImportFrom):
|
||||
for alias in stmt.names:
|
||||
if alias.name != "*":
|
||||
names.add(alias.asname or alias.name)
|
||||
elif isinstance(stmt, (ast.With, ast.AsyncWith)):
|
||||
for item in stmt.items:
|
||||
if item.optional_vars is not None:
|
||||
names.update(_target_names(item.optional_vars))
|
||||
elif isinstance(stmt, ast.Match):
|
||||
for case in stmt.cases:
|
||||
names.update(_pattern_bound_names(case.pattern))
|
||||
names.update(_named_expr_target_names(stmt))
|
||||
return names
|
||||
|
||||
|
||||
def _assignment_target_names(stmt):
|
||||
if isinstance(stmt, ast.Assign):
|
||||
names = set()
|
||||
@@ -107,14 +186,23 @@ def _assigned_names_in_control_flow(stmt):
|
||||
|
||||
class AssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
names.add(node.name)
|
||||
return None
|
||||
|
||||
def visit_Import(self, node):
|
||||
names.update(_bound_names(node))
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
names.update(_bound_names(node))
|
||||
|
||||
def visit_Assign(self, node):
|
||||
names.update(_assignment_target_names(node))
|
||||
|
||||
@@ -129,6 +217,23 @@ def _assigned_names_in_control_flow(stmt):
|
||||
|
||||
def visit_Expr(self, node):
|
||||
names.update(_mutating_call_target_names(node))
|
||||
names.update(_named_expr_target_names(node))
|
||||
|
||||
def visit_With(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_AsyncWith(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_NamedExpr(self, node):
|
||||
names.update(_target_names(node.target))
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Match(self, node):
|
||||
names.update(_bound_names(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_For(self, node):
|
||||
names.update(_assignment_target_names(node))
|
||||
@@ -149,6 +254,14 @@ def _collect_module_env(tree):
|
||||
names = _assignment_target_names(stmt)
|
||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
name = stmt.targets[0].id
|
||||
if (
|
||||
isinstance(stmt.value, ast.Name)
|
||||
and stmt.value.id in env
|
||||
and _is_mutable_static_value(env[stmt.value.id])
|
||||
):
|
||||
env.pop(stmt.value.id, None)
|
||||
env.pop(name, None)
|
||||
continue
|
||||
try:
|
||||
env[name] = _literal(stmt.value, env)
|
||||
except UnsupportedStaticExpression:
|
||||
@@ -182,10 +295,15 @@ def _collect_module_env(tree):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
for name in _mutating_call_target_names(stmt):
|
||||
env.pop(name, None)
|
||||
for name in _bound_names(stmt):
|
||||
env.pop(name, None)
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
for name in _assigned_names_in_control_flow(stmt):
|
||||
env.pop(name, None)
|
||||
continue
|
||||
for name in _bound_names(stmt):
|
||||
env.pop(name, None)
|
||||
return env
|
||||
|
||||
|
||||
@@ -235,10 +353,15 @@ def _class_attr(cls, name, env):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
if name in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
if value in (_MISSING, _INVALID):
|
||||
return None
|
||||
return value
|
||||
@@ -289,6 +412,8 @@ def _final_module_dict(tree, env, name, value_converter):
|
||||
for stmt in tree.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if not _name_is_assigned(stmt, name):
|
||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||
value = _INVALID
|
||||
continue
|
||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
try:
|
||||
@@ -320,10 +445,15 @@ def _final_module_dict(tree, env, name, value_converter):
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)):
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
if name in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
if value in (_MISSING, _INVALID):
|
||||
return {}
|
||||
return value
|
||||
@@ -387,10 +517,7 @@ def _parse_python_file(path):
|
||||
try:
|
||||
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
|
||||
except SyntaxError:
|
||||
return None
|
||||
return None
|
||||
except SyntaxError:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user