Fail closed on final input bindings and transitive aliases

This commit is contained in:
2026-07-02 14:18:38 +02:00
parent 99d2bb25da
commit 6c2653c803
2 changed files with 133 additions and 16 deletions
@@ -590,6 +590,34 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_nested_mutable_env_literal_skips_static_node(self):
source = '''
REQ = {
"image": ("IMAGE",),
}
INPUTS = {
"required": REQ,
}
REQ.clear()
class NestedMutableEnvLiteralNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return INPUTS
NODE_CLASS_MAPPINGS = {
"NestedMutableEnvLiteralNode": NestedMutableEnvLiteralNode,
}
'''
result = self._extract_source(source, "nested-mutable-env-literal-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_post_class_input_reassignment_skips_static_node(self): def test_post_class_input_reassignment_skips_static_node(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
@@ -650,6 +678,36 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("LiteralInputTypesAfterEnvChangeNode", result["nodes"]) self.assertIn("LiteralInputTypesAfterEnvChangeNode", result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_later_dynamic_input_types_binding_skips_node(self):
source = '''
def build_inputs():
return {"required": {"mask": ("MASK",)}}
class LaterDynamicInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
def INPUT_TYPES(cls):
return build_inputs()
NODE_CLASS_MAPPINGS = {
"LaterDynamicInputTypesNode": LaterDynamicInputTypesNode,
}
'''
result = self._extract_source(source, "later-dynamic-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_dynamic_return_types_reassignment_skips_node(self): def test_dynamic_return_types_reassignment_skips_node(self):
source = ''' source = '''
def build_outputs(): def build_outputs():
@@ -827,6 +885,32 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_return_types_transitive_alias_mutation_skips_node(self):
source = '''
class TransitiveAliasMutatedReturnTypesNode:
RETURN_TYPES = ["IMAGE"]
A = RETURN_TYPES
B = A
B.clear()
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"TransitiveAliasMutatedReturnTypesNode": TransitiveAliasMutatedReturnTypesNode,
}
'''
result = self._extract_source(source, "transitive-alias-mutated-return-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_class_return_types_uses_definition_time_module_env(self): def test_class_return_types_uses_definition_time_module_env(self):
source = ''' source = '''
RETURNS = ("IMAGE",) RETURNS = ("IMAGE",)
+49 -16
View File
@@ -35,22 +35,25 @@ _MUTATING_METHODS = {
} }
def _literal(node, env): def _literal(node, env, allow_mutable_env=True):
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return node.value return node.value
if isinstance(node, ast.List): if isinstance(node, ast.List):
return [_literal(item, env) for item in node.elts] return [_literal(item, env, allow_mutable_env=False) for item in node.elts]
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
return tuple(_literal(item, env) for item in node.elts) return tuple(_literal(item, env, allow_mutable_env=False) for item in node.elts)
if isinstance(node, ast.Dict): if isinstance(node, ast.Dict):
result = {} result = {}
for key, value in zip(node.keys, node.values): for key, value in zip(node.keys, node.values):
if key is None: if key is None:
raise UnsupportedStaticExpression("dict unpacking is not supported") raise UnsupportedStaticExpression("dict unpacking is not supported")
result[_literal(key, env)] = _literal(value, env) result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False)
return result return result
if isinstance(node, ast.Name) and node.id in env: if isinstance(node, ast.Name) and node.id in env:
return env[node.id] value = env[node.id]
if not allow_mutable_env and _is_mutable_static_value(value):
raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported")
return value
raise UnsupportedStaticExpression(type(node).__name__) raise UnsupportedStaticExpression(type(node).__name__)
@@ -413,7 +416,7 @@ def _class_attr(cls, name, env):
and isinstance(stmt.targets[0], ast.Name) and isinstance(stmt.targets[0], ast.Name)
and stmt.targets[0].id != name and stmt.targets[0].id != name
and isinstance(stmt.value, ast.Name) and isinstance(stmt.value, ast.Name)
and stmt.value.id == name and (stmt.value.id == name or stmt.value.id in aliases)
): ):
aliases.add(stmt.targets[0].id) aliases.add(stmt.targets[0].id)
continue continue
@@ -439,7 +442,7 @@ def _class_attr(cls, name, env):
isinstance(stmt.target, ast.Name) isinstance(stmt.target, ast.Name)
and stmt.target.id != name and stmt.target.id != name
and isinstance(stmt.value, ast.Name) and isinstance(stmt.value, ast.Name)
and stmt.value.id == name and (stmt.value.id == name or stmt.value.id in aliases)
): ):
aliases.add(stmt.target.id) aliases.add(stmt.target.id)
continue continue
@@ -505,17 +508,47 @@ def _class_attr(cls, name, env):
def _input_types(cls, env): def _input_types(cls, env):
value = _MISSING
for stmt in cls.body: for stmt in cls.body:
if not isinstance(stmt, ast.FunctionDef) or stmt.name != "INPUT_TYPES": if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
value = _INVALID
continue
try:
candidate = _literal(stmt.body[0].value, env)
except UnsupportedStaticExpression:
value = _INVALID
continue
value = candidate if isinstance(candidate, dict) else _INVALID
continue continue
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
return None value = _INVALID
try: continue
value = _literal(stmt.body[0].value, env) if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
except UnsupportedStaticExpression: if "INPUT_TYPES" in _assignment_target_names(stmt):
return None value = _INVALID
return value if isinstance(value, dict) else None continue
return None if isinstance(stmt, ast.Delete):
if "INPUT_TYPES" in _delete_target_names(stmt):
value = _INVALID
continue
if isinstance(stmt, ast.Expr):
if "INPUT_TYPES" in _mutating_call_target_names(stmt):
value = _INVALID
if "INPUT_TYPES" in _bound_names(stmt):
value = _INVALID
continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if "INPUT_TYPES" in _assigned_names_in_control_flow(stmt):
value = _INVALID
if _has_wildcard_import_in_control_flow(stmt):
value = _INVALID
continue
if "INPUT_TYPES" in _bound_names(stmt):
value = _INVALID
if value in (_MISSING, _INVALID):
return None
return value
def _mapping_value_name(value): def _mapping_value_name(value):