Resolve node mappings at assignment time

This commit is contained in:
2026-07-02 14:40:16 +02:00
parent 45e3cbaad8
commit 65c3a57052
2 changed files with 193 additions and 109 deletions
@@ -1224,6 +1224,42 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_node_mapping_uses_assignment_time_class_binding(self):
source = '''
class Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"Node": Node,
}
class Node:
RETURN_TYPES = ("MASK",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
},
}
'''
result = self._extract_source(source, "assignment-time-class-binding-pack")
self.assertEqual(["IMAGE"], result["nodes"]["Node"]["outputs"])
self.assertEqual({"image": "IMAGE"}, result["nodes"]["Node"]["inputs"])
self.assertEqual("ok", result["pack"]["status"])
def test_conditional_class_mapping_skips_node(self):
source = '''
if True:
@@ -1271,6 +1307,35 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("TopLevelMappedNode", result["nodes"])
self.assertEqual("ok", result["pack"]["status"])
def test_node_mapping_key_uses_assignment_time_env(self):
source = '''
KEY = "Original"
class Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
KEY: Node,
}
KEY = "Wrong"
'''
result = self._extract_source(source, "assignment-time-key-pack")
self.assertIn("Original", result["nodes"])
self.assertNotIn("Wrong", result["nodes"])
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
self.assertEqual("ok", result["pack"]["status"])
def test_mutated_node_class_mapping_skips_node(self):
source = '''
class MutatedMappingNode:
+128 -109
View File
@@ -307,104 +307,104 @@ def _invalidate_class_bindings(class_bindings, names):
class_bindings.pop(name, None)
def _collect_module_env(tree, class_bindings=None):
env = {}
for stmt in tree.body:
names = _mutating_call_target_names(stmt)
def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None)
return
if isinstance(stmt, ast.Assign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
name = stmt.targets[0].id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
return
if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if stmt.value is None:
return
if isinstance(stmt.target, ast.Name):
name = stmt.target.id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
return
if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None)
continue
if isinstance(stmt, ast.Assign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
name = stmt.targets[0].id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
continue
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if stmt.value is None:
continue
if isinstance(stmt.target, ast.Name):
name = stmt.target.id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
continue
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.Delete):
names = _delete_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.Expr):
names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if _has_wildcard_import_in_control_flow(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
continue
names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if _has_wildcard_import(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
continue
return
if isinstance(stmt, ast.Delete):
names = _delete_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
return
if isinstance(stmt, ast.Expr):
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
return
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if _has_wildcard_import_in_control_flow(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
return
names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
return
if _has_wildcard_import(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
return
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
def _collect_module_env(tree, class_bindings=None):
env = {}
for stmt in tree.body:
_apply_module_stmt_to_env(stmt, env, class_bindings)
return env
@@ -588,22 +588,24 @@ def _name_is_assigned(stmt, name):
return name in _assignment_target_names(stmt)
def _module_dict_entries(node, env, value_converter):
def _module_dict_entries(node, env, class_bindings, value_converter):
if not isinstance(node, ast.Dict):
raise UnsupportedStaticExpression(type(node).__name__)
result = {}
for key, value in zip(node.keys, node.values):
if key is None:
raise UnsupportedStaticExpression("dict unpacking is not supported")
converted_value = value_converter(value)
converted_value = value_converter(value, env, class_bindings)
if converted_value is None:
continue
result[_literal(key, env)] = converted_value
return result
def _final_module_dict(tree, env, name, value_converter):
def _final_module_dict(tree, name, value_converter):
value = _MISSING
env = {}
class_bindings = {}
for stmt in tree.body:
if name in _mutating_call_target_names(stmt):
value = _INVALID
@@ -611,67 +613,88 @@ def _final_module_dict(tree, env, name, value_converter):
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
try:
value = _module_dict_entries(stmt.value, env, value_converter)
value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression:
value = _INVALID
else:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt, ast.AnnAssign):
if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt.target, ast.Name) and stmt.value is not None:
try:
value = _module_dict_entries(stmt.value, env, value_converter)
value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression:
value = _INVALID
else:
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt, ast.AugAssign):
if _name_is_assigned(stmt, name):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt, ast.Delete):
if name in _delete_target_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt):
value = _INVALID
if name in _bound_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if name in _assigned_names_in_control_flow(stmt):
value = _INVALID
if _has_wildcard_import_in_control_flow(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if _has_wildcard_import(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue
if name in _bound_names(stmt):
value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
if value in (_MISSING, _INVALID):
return {}
return value
def _node_class_mappings(tree, env):
def _mapping_value_binding(value, env, class_bindings):
class_name = _mapping_value_name(value)
if class_name is None:
return None
return class_bindings.get(class_name)
def _node_class_mappings(tree):
if _has_module_wildcard_import(tree):
return {}
mappings = _final_module_dict(tree, env, "NODE_CLASS_MAPPINGS", _mapping_value_name)
return {str(node_type): class_name for node_type, class_name in mappings.items() if node_type and class_name}
mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding)
return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None}
def _display_mappings(tree, env):
displays = _final_module_dict(tree, env, "NODE_DISPLAY_NAME_MAPPINGS", lambda value: _literal(value, env))
def _display_mappings(tree):
displays = _final_module_dict(
tree,
"NODE_DISPLAY_NAME_MAPPINGS",
lambda value, env, _class_bindings: _literal(value, env),
)
return {str(k): str(v) for k, v in displays.items()}
@@ -740,14 +763,10 @@ def extract_repo_signatures(repo_dir, pack_meta):
tree = _parse_python_file(path)
if tree is None:
continue
class_bindings = {}
env = _collect_module_env(tree, class_bindings)
mappings = _node_class_mappings(tree, env)
displays = _display_mappings(tree, env)
for node_type, class_name in sorted(mappings.items()):
binding = class_bindings.get(class_name)
if binding is None:
continue
env = _collect_module_env(tree)
mappings = _node_class_mappings(tree)
displays = _display_mappings(tree)
for node_type, binding in sorted(mappings.items()):
cls, class_env = binding
sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)
if sig is not None: