Fail closed on dynamic patches and displays

This commit is contained in:
2026-07-02 16:27:48 +02:00
parent 9752248ee9
commit 7e4e85a0bd
2 changed files with 148 additions and 38 deletions
+88 -35
View File
@@ -59,12 +59,26 @@ def _literal(node, env, allow_mutable_env=True):
return result
if isinstance(node, ast.Name) and node.id in env:
value = env[node.id]
if value is _INVALID:
raise UnsupportedStaticExpression(f"unsupported env reference {node.id!r}")
if not allow_mutable_env and _is_mutable_static_value(value):
raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported")
return value
raise UnsupportedStaticExpression(type(node).__name__)
def _invalidate_env_name(env, name):
if name == "classmethod":
env[name] = _INVALID
else:
env.pop(name, None)
def _invalidate_env_names(env, names):
for name in names:
_invalidate_env_name(env, name)
def _is_mutable_static_value(value):
return isinstance(value, (dict, list, set))
@@ -108,6 +122,24 @@ def _attribute_target_base_names(target):
return set()
def _setattr_delattr_target_names(node):
if not isinstance(node, ast.Call):
return set()
if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}:
return set()
if len(node.args) < 2:
return set()
attr = node.args[1]
if (
isinstance(attr, ast.Constant)
and isinstance(attr.value, str)
and attr.value not in _CLASS_SIGNATURE_ATTRS
):
return set()
name = _root_name(node.args[0])
return {name} if name else set()
def _class_attribute_mutation_target_names(stmt):
names = set()
@@ -161,6 +193,7 @@ def _class_attribute_mutation_target_names(stmt):
names.update(_attribute_target_base_names(target))
def visit_Call(self, node):
names.update(_setattr_delattr_target_names(node))
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_attribute_target_base_names(node.func.value))
self.generic_visit(node)
@@ -328,6 +361,7 @@ def _mutating_call_target_names(stmt):
self.visit(node.args)
def visit_Call(self, node):
names.update(_setattr_delattr_target_names(node))
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_target_names(node.func.value))
self.generic_visit(node)
@@ -451,15 +485,14 @@ def _invalidate_class_bindings(class_bindings, names):
def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
if stmt.decorator_list:
class_bindings.pop(stmt.name, None)
else:
class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None)
_invalidate_env_name(env, stmt.name)
return
if isinstance(stmt, ast.Assign):
names = _assignment_target_names(stmt)
@@ -469,7 +502,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
subscript_root = _mutable_env_subscript_root(stmt.value, env)
if subscript_root is not None:
env.pop(subscript_root, None)
env.pop(name, None)
_invalidate_env_name(env, name)
return
if (
isinstance(stmt.value, ast.Name)
@@ -477,15 +510,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
_invalidate_env_name(env, name)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
_invalidate_env_name(env, name)
else:
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt)
@@ -497,7 +529,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
subscript_root = _mutable_env_subscript_root(stmt.value, env)
if subscript_root is not None:
env.pop(subscript_root, None)
env.pop(name, None)
_invalidate_env_name(env, name)
return
if (
isinstance(stmt.value, ast.Name)
@@ -505,33 +537,29 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
_invalidate_env_name(env, name)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
_invalidate_env_name(env, name)
else:
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if isinstance(stmt, ast.Delete):
names = _delete_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if isinstance(stmt, ast.Expr):
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if isinstance(stmt, _CONTROL_FLOW_TYPES):
if _has_wildcard_import_in_control_flow(stmt):
@@ -541,8 +569,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
return
names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
return
if _has_wildcard_import(stmt):
env.clear()
@@ -551,8 +578,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
return
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
_invalidate_env_names(env, names)
def _collect_module_env(tree, class_bindings=None):
@@ -586,8 +612,13 @@ def _mutable_env_subscript_root(node, env):
return None
def _input_types_decorators_are_supported(decorators):
return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators)
def _input_types_decorators_are_supported(decorators, classmethod_shadowed):
for decorator in decorators:
if not isinstance(decorator, ast.Name) or decorator.id != "classmethod":
return False
if classmethod_shadowed:
return False
return True
def _class_attr_alias_sources(value, name, aliases):
@@ -721,13 +752,14 @@ def _class_attr(cls, name, env):
return value
def _input_types(cls, env):
def _input_types(cls, env, decorator_env):
value = _MISSING
classmethod_shadowed = "classmethod" in decorator_env
for stmt in cls.body:
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
value = _INVALID
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if not _input_types_decorators_are_supported(stmt.decorator_list):
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
value = _INVALID
continue
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
@@ -743,6 +775,13 @@ def _input_types(cls, env):
if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
value = _INVALID
continue
if "classmethod" in (
_assignment_target_names(stmt)
| _delete_target_names(stmt)
| _bound_names(stmt)
| _mutating_call_target_names(stmt)
):
classmethod_shadowed = True
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
if "INPUT_TYPES" in _assignment_target_names(stmt):
value = _INVALID
@@ -894,7 +933,16 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases[stmt.target.id] = sources
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases):
names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
return names
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING
env = {}
@@ -905,18 +953,14 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
def advance_module_state(stmt):
_invalidate_class_bindings(
class_bindings,
_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases),
_module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases),
)
_apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
for stmt in tree.body:
class_attr_names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
class_attr_names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
if (
value not in (_MISSING, _INVALID)
and class_attr_names
@@ -986,6 +1030,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
if name in _bound_names(stmt):
value = _INVALID
advance_module_state(stmt)
if return_state:
return value
if value in (_MISSING, _INVALID):
return {}
return value
@@ -1022,12 +1068,17 @@ def _display_mappings(tree):
tree,
"NODE_DISPLAY_NAME_MAPPINGS",
lambda value, env, _class_bindings: _literal(value, env),
return_state=True,
)
if displays is _MISSING:
return {}
if displays is _INVALID:
return _INVALID
return {str(k): str(v) for k, v in displays.items()}
def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_env):
input_types = _input_types(cls, input_env)
input_types = _input_types(cls, input_env, class_env)
return_types = _class_attr(cls, "RETURN_TYPES", class_env)
return_names = _class_attr(cls, "RETURN_NAMES", class_env)
if return_types is _INVALID or return_names is _INVALID:
@@ -1097,6 +1148,8 @@ def extract_repo_signatures(repo_dir, pack_meta):
env = _collect_module_env(tree)
mappings = _node_class_mappings(tree)
displays = _display_mappings(tree)
if displays is _INVALID:
continue
for node_type, binding in sorted(mappings.items()):
cls, class_env = binding
sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)