Fail closed on no-arg arbitrary calls
This commit is contained in:
@@ -2458,6 +2458,30 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_class_body_arbitrary_call_after_return_types_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class NoArgClassBodyCallReturnTypesNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
mutate()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NoArgClassBodyCallReturnTypesNode": NoArgClassBodyCallReturnTypesNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-class-body-call-return-types-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_return_types_alias_arbitrary_call_skips_node(self):
|
def test_return_types_alias_arbitrary_call_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class AliasArbitraryCallReturnTypesNode:
|
class AliasArbitraryCallReturnTypesNode:
|
||||||
@@ -3705,6 +3729,34 @@ mutate(ObserveAttributeCallNode.RETURN_TYPES)
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_arbitrary_call_after_mapping_skips_mapped_class_signature(self):
|
||||||
|
source = '''
|
||||||
|
class NoArgMappedClassMutationNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def mutate():
|
||||||
|
NoArgMappedClassMutationNode.RETURN_TYPES.clear()
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NoArgMappedClassMutationNode": NoArgMappedClassMutationNode,
|
||||||
|
}
|
||||||
|
mutate()
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-mapped-class-mutation-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self):
|
def test_module_class_attribute_alias_arbitrary_call_after_mapping_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class ObserveAttributeAliasCallNode:
|
class ObserveAttributeAliasCallNode:
|
||||||
@@ -4308,6 +4360,30 @@ mutate(NODE_CLASS_MAPPINGS)
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_no_arg_arbitrary_call_after_node_mapping_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class NoArgMappingCallNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NoArgMappingCallNode": NoArgMappingCallNode,
|
||||||
|
}
|
||||||
|
mutate()
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "no-arg-mapping-call-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_alias_arbitrary_call_invalidates_static_node_mapping(self):
|
def test_alias_arbitrary_call_invalidates_static_node_mapping(self):
|
||||||
source = '''
|
source = '''
|
||||||
class AliasArbitraryCallMappingNode:
|
class AliasArbitraryCallMappingNode:
|
||||||
|
|||||||
@@ -663,6 +663,48 @@ def _arbitrary_call_observed_names(stmt):
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _has_arbitrary_call(stmt):
|
||||||
|
found = False
|
||||||
|
|
||||||
|
class ArbitraryCallPresenceVisitor(ast.NodeVisitor):
|
||||||
|
def _visit_function_definition_expressions(self, node):
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
self.visit(node.args)
|
||||||
|
if node.returns is not None:
|
||||||
|
self.visit(node.returns)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
|
def visit_AsyncFunctionDef(self, node):
|
||||||
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node):
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
for base in node.bases:
|
||||||
|
self.visit(base)
|
||||||
|
for keyword in node.keywords:
|
||||||
|
self.visit(keyword.value)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
for child in node.body:
|
||||||
|
self.visit(child)
|
||||||
|
|
||||||
|
def visit_Lambda(self, node):
|
||||||
|
self.visit(node.args)
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
nonlocal found
|
||||||
|
found = True
|
||||||
|
|
||||||
|
ArbitraryCallPresenceVisitor().visit(stmt)
|
||||||
|
return found
|
||||||
|
|
||||||
|
|
||||||
def _definition_time_referenced_names(stmt):
|
def _definition_time_referenced_names(stmt):
|
||||||
names = set()
|
names = set()
|
||||||
|
|
||||||
@@ -1010,6 +1052,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
|
|||||||
for name in observed_names:
|
for name in observed_names:
|
||||||
if name in env and _is_mutable_static_value(env[name]):
|
if name in env and _is_mutable_static_value(env[name]):
|
||||||
_invalidate_env_name(env, name)
|
_invalidate_env_name(env, name)
|
||||||
|
if _has_arbitrary_call(stmt):
|
||||||
|
for name, value in list(env.items()):
|
||||||
|
if _is_mutable_static_value(value):
|
||||||
|
_invalidate_env_name(env, name)
|
||||||
|
if class_bindings is not None and not isinstance(
|
||||||
|
stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Delete)
|
||||||
|
):
|
||||||
|
class_bindings.clear()
|
||||||
if isinstance(stmt, ast.ClassDef):
|
if isinstance(stmt, ast.ClassDef):
|
||||||
if class_bindings is not None:
|
if class_bindings is not None:
|
||||||
if _is_trivially_safe_class_def(stmt):
|
if _is_trivially_safe_class_def(stmt):
|
||||||
@@ -1236,6 +1286,9 @@ def _class_attr(cls, name, env):
|
|||||||
mutating_targets = _mutating_call_target_names(stmt)
|
mutating_targets = _mutating_call_target_names(stmt)
|
||||||
observed_targets = _arbitrary_call_observed_names(stmt)
|
observed_targets = _arbitrary_call_observed_names(stmt)
|
||||||
expression_references = _class_body_expression_referenced_names(stmt)
|
expression_references = _class_body_expression_referenced_names(stmt)
|
||||||
|
has_arbitrary_call = _has_arbitrary_call(stmt)
|
||||||
|
if value not in (_MISSING, _INVALID) and has_arbitrary_call:
|
||||||
|
value = _INVALID
|
||||||
if aliases.intersection(mutating_targets):
|
if aliases.intersection(mutating_targets):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
if name in mutating_targets:
|
if name in mutating_targets:
|
||||||
@@ -1376,9 +1429,11 @@ def _input_types(cls, env, decorator_env):
|
|||||||
observed_targets = _arbitrary_call_observed_names(stmt)
|
observed_targets = _arbitrary_call_observed_names(stmt)
|
||||||
definition_references = _definition_time_referenced_names(stmt)
|
definition_references = _definition_time_referenced_names(stmt)
|
||||||
expression_references = _class_body_expression_referenced_names(stmt)
|
expression_references = _class_body_expression_referenced_names(stmt)
|
||||||
|
has_arbitrary_call = _has_arbitrary_call(stmt)
|
||||||
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
|
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
|
||||||
input_types_invalidated = (
|
input_types_invalidated = (
|
||||||
"INPUT_TYPES" in mutating_targets
|
(value not in (_MISSING, _INVALID) and has_arbitrary_call)
|
||||||
|
or "INPUT_TYPES" in mutating_targets
|
||||||
or bool(aliases.intersection(mutating_targets))
|
or bool(aliases.intersection(mutating_targets))
|
||||||
or "INPUT_TYPES" in observed_targets
|
or "INPUT_TYPES" in observed_targets
|
||||||
or bool(aliases.intersection(observed_targets))
|
or bool(aliases.intersection(observed_targets))
|
||||||
@@ -1389,6 +1444,10 @@ def _input_types(cls, env, decorator_env):
|
|||||||
value = _INVALID
|
value = _INVALID
|
||||||
sticky_invalid = True
|
sticky_invalid = True
|
||||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
|
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
|
||||||
|
if has_arbitrary_call:
|
||||||
|
value = _INVALID
|
||||||
|
sticky_invalid = True
|
||||||
|
continue
|
||||||
if input_types_invalidated or sticky_invalid:
|
if input_types_invalidated or sticky_invalid:
|
||||||
continue
|
continue
|
||||||
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
|
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
|
||||||
@@ -2053,6 +2112,9 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
|
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
sticky_invalid = True
|
sticky_invalid = True
|
||||||
|
if value not in (_MISSING, _INVALID) and _has_arbitrary_call(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
sticky_invalid = True
|
||||||
if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)):
|
if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
sticky_invalid = True
|
sticky_invalid = True
|
||||||
|
|||||||
Reference in New Issue
Block a user