diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index b531ccd..23b9fd7 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1224,6 +1224,42 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_node_mapping_uses_assignment_time_class_binding(self): + source = ''' +class Node: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "Node": Node, +} + + +class Node: + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + }, + } +''' + result = self._extract_source(source, "assignment-time-class-binding-pack") + + self.assertEqual(["IMAGE"], result["nodes"]["Node"]["outputs"]) + self.assertEqual({"image": "IMAGE"}, result["nodes"]["Node"]["inputs"]) + self.assertEqual("ok", result["pack"]["status"]) + def test_conditional_class_mapping_skips_node(self): source = ''' if True: @@ -1271,6 +1307,35 @@ NODE_CLASS_MAPPINGS = { self.assertIn("TopLevelMappedNode", result["nodes"]) self.assertEqual("ok", result["pack"]["status"]) + def test_node_mapping_key_uses_assignment_time_env(self): + source = ''' +KEY = "Original" + + +class Node: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + KEY: Node, +} +KEY = "Wrong" +''' + result = self._extract_source(source, "assignment-time-key-pack") + + self.assertIn("Original", result["nodes"]) + self.assertNotIn("Wrong", result["nodes"]) + self.assertEqual(["IMAGE"], result["nodes"]["Original"]["outputs"]) + self.assertEqual("ok", result["pack"]["status"]) + def test_mutated_node_class_mapping_skips_node(self): source = ''' class MutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 9a657ae..77cd19d 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -307,104 +307,104 @@ def _invalidate_class_bindings(class_bindings, names): class_bindings.pop(name, None) -def _collect_module_env(tree, class_bindings=None): - env = {} - for stmt in tree.body: - names = _mutating_call_target_names(stmt) +def _apply_module_stmt_to_env(stmt, env, class_bindings=None): + names = _mutating_call_target_names(stmt) + _invalidate_class_bindings(class_bindings, names) + for name in names: + env.pop(name, None) + if isinstance(stmt, ast.ClassDef): + if class_bindings is not None: + class_bindings[stmt.name] = (stmt, dict(env)) + env.pop(stmt.name, None) + return + if isinstance(stmt, ast.Assign): + names = _assignment_target_names(stmt) + _invalidate_class_bindings(class_bindings, names) + if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + name = stmt.targets[0].id + if ( + isinstance(stmt.value, ast.Name) + and stmt.value.id in env + and _is_mutable_static_value(env[stmt.value.id]) + ): + env.pop(stmt.value.id, None) + env.pop(name, None) + return + try: + env[name] = _literal(stmt.value, env) + except UnsupportedStaticExpression: + env.pop(name, None) + else: + for name in names: + env.pop(name, None) + return + if isinstance(stmt, ast.AnnAssign): + names = _assignment_target_names(stmt) + _invalidate_class_bindings(class_bindings, names) + if stmt.value is None: + return + if isinstance(stmt.target, ast.Name): + name = stmt.target.id + if ( + isinstance(stmt.value, ast.Name) + and stmt.value.id in env + and _is_mutable_static_value(env[stmt.value.id]) + ): + env.pop(stmt.value.id, None) + env.pop(name, None) + return + try: + env[name] = _literal(stmt.value, env) + except UnsupportedStaticExpression: + env.pop(name, None) + else: + for name in names: + env.pop(name, None) + return + if isinstance(stmt, ast.AugAssign): + names = _assignment_target_names(stmt) _invalidate_class_bindings(class_bindings, names) for name in names: env.pop(name, None) - if isinstance(stmt, ast.ClassDef): - if class_bindings is not None: - class_bindings[stmt.name] = (stmt, dict(env)) - env.pop(stmt.name, None) - continue - if isinstance(stmt, ast.Assign): - names = _assignment_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): - name = stmt.targets[0].id - if ( - isinstance(stmt.value, ast.Name) - and stmt.value.id in env - and _is_mutable_static_value(env[stmt.value.id]) - ): - env.pop(stmt.value.id, None) - env.pop(name, None) - continue - try: - env[name] = _literal(stmt.value, env) - except UnsupportedStaticExpression: - env.pop(name, None) - else: - for name in names: - env.pop(name, None) - continue - if isinstance(stmt, ast.AnnAssign): - names = _assignment_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - if stmt.value is None: - continue - if isinstance(stmt.target, ast.Name): - name = stmt.target.id - if ( - isinstance(stmt.value, ast.Name) - and stmt.value.id in env - and _is_mutable_static_value(env[stmt.value.id]) - ): - env.pop(stmt.value.id, None) - env.pop(name, None) - continue - try: - env[name] = _literal(stmt.value, env) - except UnsupportedStaticExpression: - env.pop(name, None) - else: - for name in names: - env.pop(name, None) - continue - if isinstance(stmt, ast.AugAssign): - names = _assignment_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) - continue - if isinstance(stmt, ast.Delete): - names = _delete_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) - continue - if isinstance(stmt, ast.Expr): - names = _mutating_call_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) - names = _bound_names(stmt) - _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) - continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): - if _has_wildcard_import_in_control_flow(stmt): - env.clear() - if class_bindings is not None: - class_bindings.clear() - continue - names = _assigned_names_in_control_flow(stmt) - _invalidate_class_bindings(class_bindings, names) - for name in names: - env.pop(name, None) - continue - if _has_wildcard_import(stmt): - env.clear() - if class_bindings is not None: - class_bindings.clear() - continue + return + if isinstance(stmt, ast.Delete): + names = _delete_target_names(stmt) + _invalidate_class_bindings(class_bindings, names) + for name in names: + env.pop(name, None) + return + if isinstance(stmt, ast.Expr): names = _bound_names(stmt) _invalidate_class_bindings(class_bindings, names) for name in names: env.pop(name, None) + return + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if _has_wildcard_import_in_control_flow(stmt): + env.clear() + if class_bindings is not None: + class_bindings.clear() + return + names = _assigned_names_in_control_flow(stmt) + _invalidate_class_bindings(class_bindings, names) + for name in names: + env.pop(name, None) + return + if _has_wildcard_import(stmt): + env.clear() + if class_bindings is not None: + class_bindings.clear() + return + names = _bound_names(stmt) + _invalidate_class_bindings(class_bindings, names) + for name in names: + env.pop(name, None) + + +def _collect_module_env(tree, class_bindings=None): + env = {} + for stmt in tree.body: + _apply_module_stmt_to_env(stmt, env, class_bindings) return env @@ -588,22 +588,24 @@ def _name_is_assigned(stmt, name): return name in _assignment_target_names(stmt) -def _module_dict_entries(node, env, value_converter): +def _module_dict_entries(node, env, class_bindings, value_converter): if not isinstance(node, ast.Dict): raise UnsupportedStaticExpression(type(node).__name__) result = {} for key, value in zip(node.keys, node.values): if key is None: raise UnsupportedStaticExpression("dict unpacking is not supported") - converted_value = value_converter(value) + converted_value = value_converter(value, env, class_bindings) if converted_value is None: continue result[_literal(key, env)] = converted_value return result -def _final_module_dict(tree, env, name, value_converter): +def _final_module_dict(tree, name, value_converter): value = _MISSING + env = {} + class_bindings = {} for stmt in tree.body: if name in _mutating_call_target_names(stmt): value = _INVALID @@ -611,67 +613,88 @@ def _final_module_dict(tree, env, name, value_converter): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: - value = _module_dict_entries(stmt.value, env, value_converter) + value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) except UnsupportedStaticExpression: value = _INVALID else: value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt, ast.AnnAssign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt.target, ast.Name) and stmt.value is not None: try: - value = _module_dict_entries(stmt.value, env, value_converter) + value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) except UnsupportedStaticExpression: value = _INVALID else: value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt, ast.AugAssign): if _name_is_assigned(stmt, name): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt, ast.Delete): if name in _delete_target_names(stmt): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt, ast.Expr): if name in _mutating_call_target_names(stmt): value = _INVALID if name in _bound_names(stmt): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if name in _assigned_names_in_control_flow(stmt): value = _INVALID if _has_wildcard_import_in_control_flow(stmt): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if _has_wildcard_import(stmt): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) continue if name in _bound_names(stmt): value = _INVALID + _apply_module_stmt_to_env(stmt, env, class_bindings) if value in (_MISSING, _INVALID): return {} return value -def _node_class_mappings(tree, env): +def _mapping_value_binding(value, env, class_bindings): + class_name = _mapping_value_name(value) + if class_name is None: + return None + return class_bindings.get(class_name) + + +def _node_class_mappings(tree): if _has_module_wildcard_import(tree): return {} - mappings = _final_module_dict(tree, env, "NODE_CLASS_MAPPINGS", _mapping_value_name) - return {str(node_type): class_name for node_type, class_name in mappings.items() if node_type and class_name} + mappings = _final_module_dict(tree, "NODE_CLASS_MAPPINGS", _mapping_value_binding) + return {str(node_type): binding for node_type, binding in mappings.items() if node_type and binding is not None} -def _display_mappings(tree, env): - displays = _final_module_dict(tree, env, "NODE_DISPLAY_NAME_MAPPINGS", lambda value: _literal(value, env)) +def _display_mappings(tree): + displays = _final_module_dict( + tree, + "NODE_DISPLAY_NAME_MAPPINGS", + lambda value, env, _class_bindings: _literal(value, env), + ) return {str(k): str(v) for k, v in displays.items()} @@ -740,14 +763,10 @@ def extract_repo_signatures(repo_dir, pack_meta): tree = _parse_python_file(path) if tree is None: continue - class_bindings = {} - env = _collect_module_env(tree, class_bindings) - mappings = _node_class_mappings(tree, env) - displays = _display_mappings(tree, env) - for node_type, class_name in sorted(mappings.items()): - binding = class_bindings.get(class_name) - if binding is None: - continue + env = _collect_module_env(tree) + mappings = _node_class_mappings(tree) + displays = _display_mappings(tree) + for node_type, binding in sorted(mappings.items()): cls, class_env = binding sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env) if sig is not None: