Detect mutating calls inside statements
This commit is contained in:
@@ -590,6 +590,33 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_rhs_mutating_call_invalidates_static_env_value(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
X = INPUTS.clear()
|
||||
|
||||
|
||||
class RhsMutatedInputEnvNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RhsMutatedInputEnvNode": RhsMutatedInputEnvNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "rhs-mutated-input-env-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_nested_mutable_env_literal_skips_static_node(self):
|
||||
source = '''
|
||||
REQ = {
|
||||
@@ -784,6 +811,30 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_rhs_mutating_call_to_return_types_skips_node(self):
|
||||
source = '''
|
||||
class RhsMutatedReturnTypesNode:
|
||||
RETURN_TYPES = ["IMAGE"]
|
||||
X = RETURN_TYPES.pop()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RhsMutatedReturnTypesNode": RhsMutatedReturnTypesNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "rhs-mutated-return-types-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_return_types_alias_mutation_skips_node(self):
|
||||
source = '''
|
||||
class AliasMutatedReturnTypesNode:
|
||||
@@ -1244,6 +1295,30 @@ NODE_CLASS_MAPPINGS.clear()
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_rhs_mutating_call_to_node_mapping_skips_node(self):
|
||||
source = '''
|
||||
class RhsMutatedMappingNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RhsMutatedMappingNode": RhsMutatedMappingNode,
|
||||
}
|
||||
X = NODE_CLASS_MAPPINGS.clear()
|
||||
'''
|
||||
result = self._extract_source(source, "rhs-mutated-mapping-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_annotated_alias_mutation_invalidates_static_node_mapping(self):
|
||||
source = '''
|
||||
class AnnotatedAliasMutatedMappingNode:
|
||||
|
||||
@@ -178,18 +178,32 @@ def _delete_target_names(stmt):
|
||||
|
||||
|
||||
def _mutating_call_target_names(stmt):
|
||||
if not isinstance(stmt, ast.Expr):
|
||||
return set()
|
||||
call = stmt.value
|
||||
if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Attribute):
|
||||
return set()
|
||||
if call.func.attr not in _MUTATING_METHODS:
|
||||
return set()
|
||||
return _target_names(call.func.value)
|
||||
names = set()
|
||||
|
||||
class MutatingCallVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
return None
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
return None
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
|
||||
names.update(_target_names(node.func.value))
|
||||
self.generic_visit(node)
|
||||
|
||||
MutatingCallVisitor().visit(stmt)
|
||||
return names
|
||||
|
||||
|
||||
def _assigned_names_in_control_flow(stmt):
|
||||
names = set()
|
||||
names = _mutating_call_target_names(stmt)
|
||||
|
||||
class AssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
@@ -296,6 +310,10 @@ def _invalidate_class_bindings(class_bindings, names):
|
||||
def _collect_module_env(tree, class_bindings=None):
|
||||
env = {}
|
||||
for stmt in tree.body:
|
||||
names = _mutating_call_target_names(stmt)
|
||||
_invalidate_class_bindings(class_bindings, names)
|
||||
for name in names:
|
||||
env.pop(name, None)
|
||||
if isinstance(stmt, ast.ClassDef):
|
||||
if class_bindings is not None:
|
||||
class_bindings[stmt.name] = (stmt, dict(env))
|
||||
@@ -409,6 +427,11 @@ def _class_attr(cls, name, env):
|
||||
value = _MISSING
|
||||
aliases = set()
|
||||
for stmt in cls.body:
|
||||
mutating_targets = _mutating_call_target_names(stmt)
|
||||
if aliases.intersection(mutating_targets):
|
||||
value = _INVALID
|
||||
if name in mutating_targets:
|
||||
value = _INVALID
|
||||
if isinstance(stmt, ast.Assign):
|
||||
target_names = _assignment_target_names(stmt)
|
||||
if (
|
||||
@@ -510,6 +533,8 @@ def _class_attr(cls, name, env):
|
||||
def _input_types(cls, env):
|
||||
value = _MISSING
|
||||
for stmt in cls.body:
|
||||
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
|
||||
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
|
||||
value = _INVALID
|
||||
@@ -580,6 +605,8 @@ def _module_dict_entries(node, env, value_converter):
|
||||
def _final_module_dict(tree, env, name, value_converter):
|
||||
value = _MISSING
|
||||
for stmt in tree.body:
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if not _name_is_assigned(stmt, name):
|
||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||
|
||||
Reference in New Issue
Block a user