Fail closed on final input bindings and transitive aliases
This commit is contained in:
@@ -35,22 +35,25 @@ _MUTATING_METHODS = {
|
||||
}
|
||||
|
||||
|
||||
def _literal(node, env):
|
||||
def _literal(node, env, allow_mutable_env=True):
|
||||
if isinstance(node, ast.Constant):
|
||||
return node.value
|
||||
if isinstance(node, ast.List):
|
||||
return [_literal(item, env) for item in node.elts]
|
||||
return [_literal(item, env, allow_mutable_env=False) for item in node.elts]
|
||||
if isinstance(node, ast.Tuple):
|
||||
return tuple(_literal(item, env) for item in node.elts)
|
||||
return tuple(_literal(item, env, allow_mutable_env=False) for item in node.elts)
|
||||
if isinstance(node, ast.Dict):
|
||||
result = {}
|
||||
for key, value in zip(node.keys, node.values):
|
||||
if key is None:
|
||||
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
||||
result[_literal(key, env)] = _literal(value, env)
|
||||
result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False)
|
||||
return result
|
||||
if isinstance(node, ast.Name) and node.id in env:
|
||||
return env[node.id]
|
||||
value = env[node.id]
|
||||
if not allow_mutable_env and _is_mutable_static_value(value):
|
||||
raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported")
|
||||
return value
|
||||
raise UnsupportedStaticExpression(type(node).__name__)
|
||||
|
||||
|
||||
@@ -413,7 +416,7 @@ def _class_attr(cls, name, env):
|
||||
and isinstance(stmt.targets[0], ast.Name)
|
||||
and stmt.targets[0].id != name
|
||||
and isinstance(stmt.value, ast.Name)
|
||||
and stmt.value.id == name
|
||||
and (stmt.value.id == name or stmt.value.id in aliases)
|
||||
):
|
||||
aliases.add(stmt.targets[0].id)
|
||||
continue
|
||||
@@ -439,7 +442,7 @@ def _class_attr(cls, name, env):
|
||||
isinstance(stmt.target, ast.Name)
|
||||
and stmt.target.id != name
|
||||
and isinstance(stmt.value, ast.Name)
|
||||
and stmt.value.id == name
|
||||
and (stmt.value.id == name or stmt.value.id in aliases)
|
||||
):
|
||||
aliases.add(stmt.target.id)
|
||||
continue
|
||||
@@ -505,17 +508,47 @@ def _class_attr(cls, name, env):
|
||||
|
||||
|
||||
def _input_types(cls, env):
|
||||
value = _MISSING
|
||||
for stmt in cls.body:
|
||||
if not isinstance(stmt, ast.FunctionDef) or stmt.name != "INPUT_TYPES":
|
||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
|
||||
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
|
||||
value = _INVALID
|
||||
continue
|
||||
try:
|
||||
candidate = _literal(stmt.body[0].value, env)
|
||||
except UnsupportedStaticExpression:
|
||||
value = _INVALID
|
||||
continue
|
||||
value = candidate if isinstance(candidate, dict) else _INVALID
|
||||
continue
|
||||
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
|
||||
return None
|
||||
try:
|
||||
value = _literal(stmt.body[0].value, env)
|
||||
except UnsupportedStaticExpression:
|
||||
return None
|
||||
return value if isinstance(value, dict) else None
|
||||
return None
|
||||
if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
|
||||
if "INPUT_TYPES" in _assignment_target_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, ast.Delete):
|
||||
if "INPUT_TYPES" in _delete_target_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if "INPUT_TYPES" in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||
if "INPUT_TYPES" in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
if _has_wildcard_import_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
continue
|
||||
if "INPUT_TYPES" in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
if value in (_MISSING, _INVALID):
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _mapping_value_name(value):
|
||||
|
||||
Reference in New Issue
Block a user