Track mapping and class attribute aliases

This commit is contained in:
2026-07-02 16:47:39 +02:00
parent f26e441e03
commit d1f49e7c95
2 changed files with 177 additions and 3 deletions
+99 -3
View File
@@ -894,7 +894,9 @@ def _expanded_class_attribute_names(names, class_aliases):
return expanded
def _class_attribute_alias_sources(value, class_aliases, class_bindings):
def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings):
if isinstance(value, ast.Name):
return set(class_attribute_aliases.get(value.id, ()))
if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS:
return set()
name = _root_name(value.value)
@@ -905,6 +907,30 @@ def _class_attribute_alias_sources(value, class_aliases, class_bindings):
return set()
def _update_class_attribute_alias_from_unpack(
target,
value,
class_attribute_aliases,
class_aliases,
class_bindings,
):
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
return
if len(target.elts) != len(value.elts):
return
for target_item, value_item in zip(target.elts, value.elts):
if not isinstance(target_item, ast.Name):
continue
sources = _class_attribute_alias_sources(
value_item,
class_attribute_aliases,
class_aliases,
class_bindings,
)
if sources:
class_attribute_aliases[target_item.id] = sources
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
names = (
_mutating_call_target_names(stmt)
@@ -924,11 +950,29 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases.pop(name, None)
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
sources = _class_attribute_alias_sources(
stmt.value,
class_attribute_aliases,
class_aliases,
class_bindings,
)
if sources:
class_attribute_aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
_update_class_attribute_alias_from_unpack(
stmt.targets[0],
stmt.value,
class_attribute_aliases,
class_aliases,
class_bindings,
)
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
sources = _class_attribute_alias_sources(
stmt.value,
class_attribute_aliases,
class_aliases,
class_bindings,
)
if sources:
class_attribute_aliases[stmt.target.id] = sources
@@ -942,6 +986,54 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu
return names
def _module_dict_alias_sources(value, name, aliases):
if not isinstance(value, ast.Name):
return set()
if value.id == name:
return {name}
return set(aliases.get(value.id, ()))
def _update_module_dict_alias_from_unpack(target, value, name, aliases):
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
return
if len(target.elts) != len(value.elts):
return
for target_item, value_item in zip(target.elts, value.elts):
if not isinstance(target_item, ast.Name):
continue
sources = _module_dict_alias_sources(value_item, name, aliases)
if sources:
aliases[target_item.id] = sources
def _module_dict_alias_invalidated(stmt, aliases):
names = (
_mutating_call_target_names(stmt)
| _assignment_target_names(stmt)
| _delete_target_names(stmt)
| _bound_names(stmt)
)
return any(name in aliases for name in names)
def _update_module_dict_aliases(stmt, name, aliases):
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
for rebound_name in rebound_names:
aliases.pop(rebound_name, None)
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
sources = _module_dict_alias_sources(stmt.value, name, aliases)
if sources:
aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
_update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases)
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
sources = _module_dict_alias_sources(stmt.value, name, aliases)
if sources:
aliases[stmt.target.id] = sources
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING
@@ -949,6 +1041,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
class_bindings = {}
class_aliases = {}
class_attribute_aliases = {}
module_dict_aliases = {}
def advance_module_state(stmt):
_invalidate_class_bindings(
@@ -958,6 +1051,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
_apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
_update_module_dict_aliases(stmt, name, module_dict_aliases)
for stmt in tree.body:
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
@@ -969,6 +1063,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
value = _INVALID
if name in _mutating_call_target_names(stmt):
value = _INVALID
if _module_dict_alias_invalidated(stmt, module_dict_aliases):
value = _INVALID
if isinstance(stmt, ast.Assign):
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: