Fail closed on no-arg arbitrary calls

This commit is contained in:
2026-07-02 21:09:10 +02:00
parent 126f5db959
commit ee8496174f
2 changed files with 139 additions and 1 deletions
+63 -1
View File
@@ -663,6 +663,48 @@ def _arbitrary_call_observed_names(stmt):
return names
def _has_arbitrary_call(stmt):
found = False
class ArbitraryCallPresenceVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
self.visit(node.args)
if node.returns is not None:
self.visit(node.returns)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
def visit_FunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
for child in node.body:
self.visit(child)
def visit_Lambda(self, node):
self.visit(node.args)
def visit_Call(self, node):
nonlocal found
found = True
ArbitraryCallPresenceVisitor().visit(stmt)
return found
def _definition_time_referenced_names(stmt):
names = set()
@@ -1010,6 +1052,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
for name in observed_names:
if name in env and _is_mutable_static_value(env[name]):
_invalidate_env_name(env, name)
if _has_arbitrary_call(stmt):
for name, value in list(env.items()):
if _is_mutable_static_value(value):
_invalidate_env_name(env, name)
if class_bindings is not None and not isinstance(
stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete)
):
class_bindings.clear()
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
if _is_trivially_safe_class_def(stmt):
@@ -1236,6 +1286,9 @@ def _class_attr(cls, name, env):
mutating_targets = _mutating_call_target_names(stmt)
observed_targets = _arbitrary_call_observed_names(stmt)
expression_references = _class_body_expression_referenced_names(stmt)
has_arbitrary_call = _has_arbitrary_call(stmt)
if value not in (_MISSING, _INVALID) and has_arbitrary_call:
value = _INVALID
if aliases.intersection(mutating_targets):
value = _INVALID
if name in mutating_targets:
@@ -1376,9 +1429,11 @@ def _input_types(cls, env, decorator_env):
observed_targets = _arbitrary_call_observed_names(stmt)
definition_references = _definition_time_referenced_names(stmt)
expression_references = _class_body_expression_referenced_names(stmt)
has_arbitrary_call = _has_arbitrary_call(stmt)
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
input_types_invalidated = (
"INPUT_TYPES" in mutating_targets
(value not in (_MISSING, _INVALID) and has_arbitrary_call)
or "INPUT_TYPES" in mutating_targets
or bool(aliases.intersection(mutating_targets))
or "INPUT_TYPES" in observed_targets
or bool(aliases.intersection(observed_targets))
@@ -1389,6 +1444,10 @@ def _input_types(cls, env, decorator_env):
value = _INVALID
sticky_invalid = True
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if has_arbitrary_call:
value = _INVALID
sticky_invalid = True
continue
if input_types_invalidated or sticky_invalid:
continue
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
@@ -2053,6 +2112,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
value = _INVALID
sticky_invalid = True
if value not in (_MISSING, _INVALID) and _has_arbitrary_call(stmt):
value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)):
value = _INVALID
sticky_invalid = True