Fail closed on class-body mutations and duplicate inputs

This commit is contained in:
2026-07-02 19:14:53 +02:00
parent 1b56798018
commit c45bf3c230
2 changed files with 205 additions and 0 deletions
+140
View File
@@ -800,8 +800,140 @@ def _is_trivially_safe_class_def(stmt):
)
def _namespace_assignment_target_names(target):
name = _namespace_subscript_name(target)
if name is not None:
return {name}
if isinstance(target, ast.Attribute):
return _namespace_assignment_target_names(target.value)
if isinstance(target, ast.Subscript):
return _namespace_assignment_target_names(target.value)
if isinstance(target, (ast.List, ast.Tuple)):
names = set()
for item in target.elts:
names.update(_namespace_assignment_target_names(item))
return names
if isinstance(target, ast.Starred):
return _namespace_assignment_target_names(target.value)
return set()
def _class_body_global_names(cls):
names = set()
class GlobalVisitor(ast.NodeVisitor):
def visit_FunctionDef(self, node):
return None
def visit_AsyncFunctionDef(self, node):
return None
def visit_ClassDef(self, node):
return None
def visit_Global(self, node):
names.update(node.names)
for stmt in cls.body:
GlobalVisitor().visit(stmt)
return names
def _class_body_module_mutation_names(cls):
global_names = _class_body_global_names(cls)
names = set()
def add_assignment_targets(stmt):
names.update(_assignment_target_names(stmt).intersection(global_names))
if isinstance(stmt, ast.Assign):
for target in stmt.targets:
names.update(_namespace_assignment_target_names(target))
elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)):
names.update(_namespace_assignment_target_names(stmt.target))
elif isinstance(stmt, (ast.For, ast.AsyncFor)):
names.update(_namespace_assignment_target_names(stmt.target))
class ClassBodyMutationVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
names.update(_mutating_call_target_names(node))
names.update(_namespace_alias_mutation_target_names(node, set()))
def visit_FunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
names.update(_class_body_module_mutation_names(node))
def visit_Assign(self, node):
add_assignment_targets(node)
self.visit(node.value)
def visit_AnnAssign(self, node):
add_assignment_targets(node)
if node.value is not None:
self.visit(node.value)
def visit_AugAssign(self, node):
add_assignment_targets(node)
self.visit(node.value)
def visit_Delete(self, node):
names.update(_delete_target_names(node).intersection(global_names))
for target in node.targets:
names.update(_namespace_assignment_target_names(target))
def visit_For(self, node):
add_assignment_targets(node)
self.generic_visit(node)
def visit_AsyncFor(self, node):
add_assignment_targets(node)
self.generic_visit(node)
def visit_With(self, node):
for item in node.items:
if item.optional_vars is not None:
names.update(_target_names(item.optional_vars).intersection(global_names))
names.update(_namespace_assignment_target_names(item.optional_vars))
self.generic_visit(node)
def visit_AsyncWith(self, node):
for item in node.items:
if item.optional_vars is not None:
names.update(_target_names(item.optional_vars).intersection(global_names))
names.update(_namespace_assignment_target_names(item.optional_vars))
self.generic_visit(node)
def visit_Import(self, node):
names.update(_bound_names(node).intersection(global_names))
def visit_ImportFrom(self, node):
names.update(_bound_names(node).intersection(global_names))
def visit_Call(self, node):
names.update(_namespace_mutating_call_target_names(node))
self.generic_visit(node)
for stmt in cls.body:
ClassBodyMutationVisitor().visit(stmt)
return names
def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
names = _mutating_call_target_names(stmt)
if isinstance(stmt, ast.ClassDef):
names.update(_class_body_module_mutation_names(stmt))
if _DYNAMIC_NAMESPACE_MUTATION in names:
env.clear()
if class_bindings is not None:
@@ -1721,6 +1853,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
_update_namespace_aliases(stmt, namespace_aliases)
for stmt in tree.body:
class_body_module_mutations = (
_class_body_module_mutation_names(stmt) if isinstance(stmt, ast.ClassDef) else set()
)
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
if (
value not in (_MISSING, _INVALID)
@@ -1732,6 +1867,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
if _name_invalidated_by(name, _mutating_call_target_names(stmt)):
value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, class_body_module_mutations):
value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
value = _INVALID
sticky_invalid = True
@@ -1905,6 +2043,8 @@ def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_e
for name, spec in values.items():
if not isinstance(name, str):
return None
if name in inputs:
return None
input_type = normalise_input_spec(spec)
if input_type is None:
return None