Detect mutating calls inside statements

This commit is contained in:
2026-07-02 14:29:04 +02:00
parent 6c2653c803
commit 45e3cbaad8
2 changed files with 111 additions and 9 deletions
+36 -9
View File
@@ -178,18 +178,32 @@ def _delete_target_names(stmt):
def _mutating_call_target_names(stmt):
if not isinstance(stmt, ast.Expr):
return set()
call = stmt.value
if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Attribute):
return set()
if call.func.attr not in _MUTATING_METHODS:
return set()
return _target_names(call.func.value)
names = set()
class MutatingCallVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, node):
return None
def visit_AsyncFunctionDef(self, node):
return None
def visit_ClassDef(self, node):
return None
def visit_Lambda(self, node):
return None
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_target_names(node.func.value))
self.generic_visit(node)
MutatingCallVisitor().visit(stmt)
return names
def _assigned_names_in_control_flow(stmt):
names = set()
names = _mutating_call_target_names(stmt)
class AssignmentVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, node):
@@ -296,6 +310,10 @@ def _invalidate_class_bindings(class_bindings, names):
def _collect_module_env(tree, class_bindings=None):
env = {}
for stmt in tree.body:
names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
class_bindings[stmt.name] = (stmt, dict(env))
@@ -409,6 +427,11 @@ def _class_attr(cls, name, env):
value = _MISSING
aliases = set()
for stmt in cls.body:
mutating_targets = _mutating_call_target_names(stmt)
if aliases.intersection(mutating_targets):
value = _INVALID
if name in mutating_targets:
value = _INVALID
if isinstance(stmt, ast.Assign):
target_names = _assignment_target_names(stmt)
if (
@@ -510,6 +533,8 @@ def _class_attr(cls, name, env):
def _input_types(cls, env):
value = _MISSING
for stmt in cls.body:
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
value = _INVALID
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
value = _INVALID
@@ -580,6 +605,8 @@ def _module_dict_entries(node, env, value_converter):
def _final_module_dict(tree, env, name, value_converter):
value = _MISSING
for stmt in tree.body:
if name in _mutating_call_target_names(stmt):
value = _INVALID
if isinstance(stmt, ast.Assign):
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: