Fail closed on final input bindings and transitive aliases
This commit is contained in:
@@ -590,6 +590,34 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_nested_mutable_env_literal_skips_static_node(self):
|
||||||
|
source = '''
|
||||||
|
REQ = {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
}
|
||||||
|
INPUTS = {
|
||||||
|
"required": REQ,
|
||||||
|
}
|
||||||
|
REQ.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class NestedMutableEnvLiteralNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"NestedMutableEnvLiteralNode": NestedMutableEnvLiteralNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "nested-mutable-env-literal-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_post_class_input_reassignment_skips_static_node(self):
|
def test_post_class_input_reassignment_skips_static_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_inputs():
|
def build_inputs():
|
||||||
@@ -650,6 +678,36 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertIn("LiteralInputTypesAfterEnvChangeNode", result["nodes"])
|
self.assertIn("LiteralInputTypesAfterEnvChangeNode", result["nodes"])
|
||||||
self.assertEqual("ok", result["pack"]["status"])
|
self.assertEqual("ok", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_later_dynamic_input_types_binding_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
def build_inputs():
|
||||||
|
return {"required": {"mask": ("MASK",)}}
|
||||||
|
|
||||||
|
|
||||||
|
class LaterDynamicInputTypesNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return build_inputs()
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LaterDynamicInputTypesNode": LaterDynamicInputTypesNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "later-dynamic-input-types-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_dynamic_return_types_reassignment_skips_node(self):
|
def test_dynamic_return_types_reassignment_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_outputs():
|
def build_outputs():
|
||||||
@@ -827,6 +885,32 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_return_types_transitive_alias_mutation_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class TransitiveAliasMutatedReturnTypesNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
A = RETURN_TYPES
|
||||||
|
B = A
|
||||||
|
B.clear()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TransitiveAliasMutatedReturnTypesNode": TransitiveAliasMutatedReturnTypesNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "transitive-alias-mutated-return-types-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_class_return_types_uses_definition_time_module_env(self):
|
def test_class_return_types_uses_definition_time_module_env(self):
|
||||||
source = '''
|
source = '''
|
||||||
RETURNS = ("IMAGE",)
|
RETURNS = ("IMAGE",)
|
||||||
|
|||||||
@@ -35,22 +35,25 @@ _MUTATING_METHODS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _literal(node, env):
|
def _literal(node, env, allow_mutable_env=True):
|
||||||
if isinstance(node, ast.Constant):
|
if isinstance(node, ast.Constant):
|
||||||
return node.value
|
return node.value
|
||||||
if isinstance(node, ast.List):
|
if isinstance(node, ast.List):
|
||||||
return [_literal(item, env) for item in node.elts]
|
return [_literal(item, env, allow_mutable_env=False) for item in node.elts]
|
||||||
if isinstance(node, ast.Tuple):
|
if isinstance(node, ast.Tuple):
|
||||||
return tuple(_literal(item, env) for item in node.elts)
|
return tuple(_literal(item, env, allow_mutable_env=False) for item in node.elts)
|
||||||
if isinstance(node, ast.Dict):
|
if isinstance(node, ast.Dict):
|
||||||
result = {}
|
result = {}
|
||||||
for key, value in zip(node.keys, node.values):
|
for key, value in zip(node.keys, node.values):
|
||||||
if key is None:
|
if key is None:
|
||||||
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
||||||
result[_literal(key, env)] = _literal(value, env)
|
result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False)
|
||||||
return result
|
return result
|
||||||
if isinstance(node, ast.Name) and node.id in env:
|
if isinstance(node, ast.Name) and node.id in env:
|
||||||
return env[node.id]
|
value = env[node.id]
|
||||||
|
if not allow_mutable_env and _is_mutable_static_value(value):
|
||||||
|
raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported")
|
||||||
|
return value
|
||||||
raise UnsupportedStaticExpression(type(node).__name__)
|
raise UnsupportedStaticExpression(type(node).__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -413,7 +416,7 @@ def _class_attr(cls, name, env):
|
|||||||
and isinstance(stmt.targets[0], ast.Name)
|
and isinstance(stmt.targets[0], ast.Name)
|
||||||
and stmt.targets[0].id != name
|
and stmt.targets[0].id != name
|
||||||
and isinstance(stmt.value, ast.Name)
|
and isinstance(stmt.value, ast.Name)
|
||||||
and stmt.value.id == name
|
and (stmt.value.id == name or stmt.value.id in aliases)
|
||||||
):
|
):
|
||||||
aliases.add(stmt.targets[0].id)
|
aliases.add(stmt.targets[0].id)
|
||||||
continue
|
continue
|
||||||
@@ -439,7 +442,7 @@ def _class_attr(cls, name, env):
|
|||||||
isinstance(stmt.target, ast.Name)
|
isinstance(stmt.target, ast.Name)
|
||||||
and stmt.target.id != name
|
and stmt.target.id != name
|
||||||
and isinstance(stmt.value, ast.Name)
|
and isinstance(stmt.value, ast.Name)
|
||||||
and stmt.value.id == name
|
and (stmt.value.id == name or stmt.value.id in aliases)
|
||||||
):
|
):
|
||||||
aliases.add(stmt.target.id)
|
aliases.add(stmt.target.id)
|
||||||
continue
|
continue
|
||||||
@@ -505,17 +508,47 @@ def _class_attr(cls, name, env):
|
|||||||
|
|
||||||
|
|
||||||
def _input_types(cls, env):
|
def _input_types(cls, env):
|
||||||
|
value = _MISSING
|
||||||
for stmt in cls.body:
|
for stmt in cls.body:
|
||||||
if not isinstance(stmt, ast.FunctionDef) or stmt.name != "INPUT_TYPES":
|
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
|
||||||
|
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
candidate = _literal(stmt.body[0].value, env)
|
||||||
|
except UnsupportedStaticExpression:
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
value = candidate if isinstance(candidate, dict) else _INVALID
|
||||||
continue
|
continue
|
||||||
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
|
if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
|
||||||
return None
|
value = _INVALID
|
||||||
try:
|
continue
|
||||||
value = _literal(stmt.body[0].value, env)
|
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
|
||||||
except UnsupportedStaticExpression:
|
if "INPUT_TYPES" in _assignment_target_names(stmt):
|
||||||
return None
|
value = _INVALID
|
||||||
return value if isinstance(value, dict) else None
|
continue
|
||||||
return None
|
if isinstance(stmt, ast.Delete):
|
||||||
|
if "INPUT_TYPES" in _delete_target_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
if isinstance(stmt, ast.Expr):
|
||||||
|
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
if "INPUT_TYPES" in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
|
||||||
|
if "INPUT_TYPES" in _assigned_names_in_control_flow(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
if _has_wildcard_import_in_control_flow(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
continue
|
||||||
|
if "INPUT_TYPES" in _bound_names(stmt):
|
||||||
|
value = _INVALID
|
||||||
|
if value in (_MISSING, _INVALID):
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _mapping_value_name(value):
|
def _mapping_value_name(value):
|
||||||
|
|||||||
Reference in New Issue
Block a user