Resolve node mappings at assignment time

This commit is contained in:
2026-07-02 14:40:16 +02:00
parent 45e3cbaad8
commit 65c3a57052
2 changed files with 193 additions and 109 deletions
@@ -1224,6 +1224,42 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_node_mapping_uses_assignment_time_class_binding(self):
source = '''
class Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"Node": Node,
}
class Node:
RETURN_TYPES = ("MASK",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
},
}
'''
result = self._extract_source(source, "assignment-time-class-binding-pack")
self.assertEqual(["IMAGE"], result["nodes"]["Node"]["outputs"])
self.assertEqual({"image": "IMAGE"}, result["nodes"]["Node"]["inputs"])
self.assertEqual("ok", result["pack"]["status"])
def test_conditional_class_mapping_skips_node(self): def test_conditional_class_mapping_skips_node(self):
source = ''' source = '''
if True: if True:
@@ -1271,6 +1307,35 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("TopLevelMappedNode", result["nodes"]) self.assertIn("TopLevelMappedNode", result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_node_mapping_key_uses_assignment_time_env(self):
source = '''
KEY = "Original"
class Node:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
KEY: Node,
}
KEY = "Wrong"
'''
result = self._extract_source(source, "assignment-time-key-pack")
self.assertIn("Original", result["nodes"])
self.assertNotIn("Wrong", result["nodes"])
self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"])
self.assertEqual("ok", result["pack"]["status"])
def test_mutated_node_class_mapping_skips_node(self): def test_mutated_node_class_mapping_skips_node(self):
source = ''' source = '''
class MutatedMappingNode: class MutatedMappingNode:
+128 -109
View File
@@ -307,104 +307,104 @@ def _invalidate_class_bindings(class_bindings, names):
class_bindings.pop(name, None) class_bindings.pop(name, None)
def _collect_module_env(tree, class_bindings=None): def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
env = {} names = _mutating_call_target_names(stmt)
for stmt in tree.body: _invalidate_class_bindings(class_bindings, names)
names = _mutating_call_target_names(stmt) for name in names:
env.pop(name, None)
if isinstance(stmt, ast.ClassDef):
if class_bindings is not None:
class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None)
return
if isinstance(stmt, ast.Assign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
name = stmt.targets[0].id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
return
if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if stmt.value is None:
return
if isinstance(stmt.target, ast.Name):
name = stmt.target.id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
return
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
return
if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: for name in names:
env.pop(name, None) env.pop(name, None)
if isinstance(stmt, ast.ClassDef): return
if class_bindings is not None: if isinstance(stmt, ast.Delete):
class_bindings[stmt.name] = (stmt, dict(env)) names = _delete_target_names(stmt)
env.pop(stmt.name, None) _invalidate_class_bindings(class_bindings, names)
continue for name in names:
if isinstance(stmt, ast.Assign): env.pop(name, None)
names = _assignment_target_names(stmt) return
_invalidate_class_bindings(class_bindings, names) if isinstance(stmt, ast.Expr):
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
name = stmt.targets[0].id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
continue
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
if stmt.value is None:
continue
if isinstance(stmt.target, ast.Name):
name = stmt.target.id
if (
isinstance(stmt.value, ast.Name)
and stmt.value.id in env
and _is_mutable_static_value(env[stmt.value.id])
):
env.pop(stmt.value.id, None)
env.pop(name, None)
continue
try:
env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression:
env.pop(name, None)
else:
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.Delete):
names = _delete_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, ast.Expr):
names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if _has_wildcard_import_in_control_flow(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
continue
names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
continue
if _has_wildcard_import(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
continue
names = _bound_names(stmt) names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: for name in names:
env.pop(name, None) env.pop(name, None)
return
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if _has_wildcard_import_in_control_flow(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
return
names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
return
if _has_wildcard_import(stmt):
env.clear()
if class_bindings is not None:
class_bindings.clear()
return
names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names)
for name in names:
env.pop(name, None)
def _collect_module_env(tree, class_bindings=None):
env = {}
for stmt in tree.body:
_apply_module_stmt_to_env(stmt, env, class_bindings)
return env return env
@@ -588,22 +588,24 @@ def _name_is_assigned(stmt, name):
return name in _assignment_target_names(stmt) return name in _assignment_target_names(stmt)
def _module_dict_entries(node, env, value_converter): def _module_dict_entries(node, env, class_bindings, value_converter):
if not isinstance(node, ast.Dict): if not isinstance(node, ast.Dict):
raise UnsupportedStaticExpression(type(node).__name__) raise UnsupportedStaticExpression(type(node).__name__)
result = {} result = {}
for key, value in zip(node.keys, node.values): for key, value in zip(node.keys, node.values):
if key is None: if key is None:
raise UnsupportedStaticExpression("dict unpacking is not supported") raise UnsupportedStaticExpression("dict unpacking is not supported")
converted_value = value_converter(value) converted_value = value_converter(value, env, class_bindings)
if converted_value is None: if converted_value is None:
continue continue
result[_literal(key, env)] = converted_value result[_literal(key, env)] = converted_value
return result return result
def _final_module_dict(tree, env, name, value_converter): def _final_module_dict(tree, name, value_converter):
value = _MISSING value = _MISSING
env = {}
class_bindings = {}
for stmt in tree.body: for stmt in tree.body:
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
@@ -611,67 +613,88 @@ def _final_module_dict(tree, env, name, value_converter):
if not _name_is_assigned(stmt, name): if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
try: try:
value = _module_dict_entries(stmt.value, env, value_converter) value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
value = _INVALID value = _INVALID
else: else:
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt, ast.AnnAssign): if isinstance(stmt, ast.AnnAssign):
if not _name_is_assigned(stmt, name): if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt.target, ast.Name) and stmt.value is not None: if isinstance(stmt.target, ast.Name) and stmt.value is not None:
try: try:
value = _module_dict_entries(stmt.value, env, value_converter) value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
value = _INVALID value = _INVALID
else: else:
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt, ast.AugAssign): if isinstance(stmt, ast.AugAssign):
if _name_is_assigned(stmt, name): if _name_is_assigned(stmt, name):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt, ast.Delete): if isinstance(stmt, ast.Delete):
if name in _delete_target_names(stmt): if name in _delete_target_names(stmt):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
if name in _bound_names(stmt): if name in _bound_names(stmt):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)):
if name in _assigned_names_in_control_flow(stmt): if name in _assigned_names_in_control_flow(stmt):
value = _INVALID value = _INVALID
if _has_wildcard_import_in_control_flow(stmt): if _has_wildcard_import_in_control_flow(stmt):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if _has_wildcard_import(stmt): if _has_wildcard_import(stmt):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
continue continue
if name in _bound_names(stmt): if name in _bound_names(stmt):
value = _INVALID value = _INVALID
_apply_module_stmt_to_env(stmt, env, class_bindings)
if value in (_MISSING, _INVALID): if value in (_MISSING, _INVALID):
return {} return {}
return value return value
def _node_class_mappings(tree, env): def _mapping_value_binding(value, env, class_bindings):
class_name = _mapping_value_name(value)
if class_name is None:
return None
return class_bindings.get(class_name)
def _node_class_mappings(tree):
if _has_module_wildcard_import(tree): if _has_module_wildcard_import(tree):
return {} return {}
mappings = _final_module_dict(tree, env, "NODE_CLASS_MAPPINGS", _mapping_value_name) mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding)
return {str(node_type): class_name for node_type, class_name in mappings.items() if node_type and class_name} return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None}
def _display_mappings(tree, env): def _display_mappings(tree):
displays = _final_module_dict(tree, env, "NODE_DISPLAY_NAME_MAPPINGS", lambda value: _literal(value, env)) displays = _final_module_dict(
tree,
"NODE_DISPLAY_NAME_MAPPINGS",
lambda value, env, _class_bindings: _literal(value, env),
)
return {str(k): str(v) for k, v in displays.items()} return {str(k): str(v) for k, v in displays.items()}
@@ -740,14 +763,10 @@ def extract_repo_signatures(repo_dir, pack_meta):
tree = _parse_python_file(path) tree = _parse_python_file(path)
if tree is None: if tree is None:
continue continue
class_bindings = {} env = _collect_module_env(tree)
env = _collect_module_env(tree, class_bindings) mappings = _node_class_mappings(tree)
mappings = _node_class_mappings(tree, env) displays = _display_mappings(tree)
displays = _display_mappings(tree, env) for node_type, binding in sorted(mappings.items()):
for node_type, class_name in sorted(mappings.items()):
binding = class_bindings.get(class_name)
if binding is None:
continue
cls, class_env = binding cls, class_env = binding
sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env) sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)
if sig is not None: if sig is not None: