diff --git a/__init__.py b/__init__.py index 58614ca..fe42bda 100644 --- a/__init__.py +++ b/__init__.py @@ -397,10 +397,6 @@ try: from .loop_nodes import ( LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS, - accumulator_delete_entries, - accumulator_list_entries, - accumulator_move_entry, - accumulator_save_entries, ) from .node_camera import ( NODE_CLASS_MAPPINGS as CAMERA_NODE_CLASS_MAPPINGS, @@ -438,17 +434,17 @@ try: NODE_CLASS_MAPPINGS as SEED_RESOLUTION_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as SEED_RESOLUTION_NODE_DISPLAY_NAME_MAPPINGS, ) - from .prompt_builder import ( - save_character_profile_payload, + from .server_routes import ( + accumulator_delete_payload, + accumulator_list_payload, + accumulator_move_payload, + accumulator_save_payload, + profile_save_cached_payload, ) except ImportError: from loop_nodes import ( LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS, - accumulator_delete_entries, - accumulator_list_entries, - accumulator_move_entry, - accumulator_save_entries, ) from node_camera import ( NODE_CLASS_MAPPINGS as CAMERA_NODE_CLASS_MAPPINGS, @@ -486,8 +482,12 @@ except ImportError: NODE_CLASS_MAPPINGS as SEED_RESOLUTION_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as SEED_RESOLUTION_NODE_DISPLAY_NAME_MAPPINGS, ) - from prompt_builder import ( - save_character_profile_payload, + from server_routes import ( + accumulator_delete_payload, + accumulator_list_payload, + accumulator_move_payload, + accumulator_save_payload, + profile_save_cached_payload, ) @@ -496,10 +496,7 @@ if PromptServer is not None and web is not None: async def sxcp_save_cached_profile(request): try: payload = await request.json() - result = save_character_profile_payload( - profile_name=str(payload.get("profile_name") or ""), - profile_json=payload.get("profile_json") or "", - ) + result = profile_save_cached_payload(payload) return web.json_response(result) except Exception as exc: return web.json_response({"error": str(exc)}, status=400) @@ -508,10 +505,7 @@ if PromptServer is not None and web is not None: async def sxcp_accumulator_list(request): try: payload = await request.json() - result = accumulator_list_entries( - str(payload.get("store_key") or ""), - preview_limit=int(payload.get("preview_limit") or 0), - ) + result = accumulator_list_payload(payload) return web.json_response(result) except Exception as exc: return web.json_response({"error": str(exc)}, status=400) @@ -520,14 +514,7 @@ if PromptServer is not None and web is not None: async def sxcp_accumulator_delete(request): try: payload = await request.json() - result = accumulator_delete_entries( - store_key=str(payload.get("store_key") or ""), - preview_key=str(payload.get("preview_key") or ""), - entry_id=str(payload.get("entry_id") or ""), - index=int(payload.get("index") or 0), - clear=bool(payload.get("clear")), - preview_limit=int(payload.get("preview_limit") or 0), - ) + result = accumulator_delete_payload(payload) return web.json_response(result) except Exception as exc: return web.json_response({"error": str(exc)}, status=400) @@ -536,13 +523,7 @@ if PromptServer is not None and web is not None: async def sxcp_accumulator_save(request): try: payload = await request.json() - result = accumulator_save_entries( - store_key=str(payload.get("store_key") or ""), - save_path=str(payload.get("save_path") or "sxcp_accumulator"), - filename_prefix=str(payload.get("filename_prefix") or "sxcp_accum"), - clear_after_save=bool(payload.get("clear_after_save")), - preview_limit=int(payload.get("preview_limit") or 0), - ) + result = accumulator_save_payload(payload) return web.json_response(result) except Exception as exc: return web.json_response({"error": str(exc)}, status=400) @@ -551,15 +532,7 @@ if PromptServer is not None and web is not None: async def sxcp_accumulator_move(request): try: payload = await request.json() - result = accumulator_move_entry( - store_key=str(payload.get("store_key") or ""), - preview_key=str(payload.get("preview_key") or ""), - entry_id=str(payload.get("entry_id") or ""), - index=int(payload.get("index") or 0), - direction=str(payload.get("direction") or "up"), - target_index=int(payload.get("target_index") or 0), - preview_limit=int(payload.get("preview_limit") or 0), - ) + result = accumulator_move_payload(payload) return web.json_response(result) except Exception as exc: return web.json_response({"error": str(exc)}, status=400) diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index aef4b10..c00c471 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -416,13 +416,17 @@ Already isolated: by `__init__.py`. - generation profile, advanced filter, and ethnicity list utility nodes live in `node_profile_filter.py`, with registration maps imported by `__init__.py`. +- profile-save and accumulator server payload handling lives in + `server_routes.py`; `__init__.py` only wires those pure handlers to ComfyUI + JSON responses, and `tools/prompt_smoke.py` covers the handlers without + importing ComfyUI. Improve later: - split remaining large node classes into files by family; - keep node display names, return names, and docs in sync through the audit helper; -- add small endpoint tests for profile/accumulator/index-switch routes. +- add more endpoint tests when new server routes are introduced. ## Path-Specific Improvements diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 640bc49..e03c927 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -98,6 +98,7 @@ Core helper ownership: | `prompt_hygiene.py` | Generic prompt, caption, and negative-prompt cleanup. | | `row_normalization.py` | Final prompt-row and pair metadata normalization: trigger prepending, extra-positive append, negative merge/dedupe, caption-part joining, embedded soft/hard row output and side-metadata synchronization, and embedded row sanitation. | | `formatter_input.py` | Shared formatter input parsing: text cleanup, metadata/source JSON detection, trigger-prefix stripping, shared prompt field-label inventory, fallback field-label stripping, `Avoid:` splitting, prompt-field extraction, and metadata row-value fallback. | +| `server_routes.py` | Pure payload handlers for profile-save and accumulator server endpoints, used by ComfyUI routes and smoke tests without importing ComfyUI. | | `sdxl_presets.py` | SDXL formatter profiles, style presets, quality presets, default negative prompt, and metadata-family tag hints used by the SDXL formatter and node choice lists. | | `caption_policy.py` | Caption naturalizer policy data and helpers: caption profiles, style tails, item labels, metadata-family caption labels, detail/style-policy normalization, clothing cleanup, and composition cleanup. | @@ -813,6 +814,9 @@ pair metadata through the core Python APIs, then verifies: - static formatter metadata fixtures keep source-provided action families stable across Krea2 prose, SDXL tags, and natural captions even when raw item text contains distracting wording. +- profile-save and accumulator endpoint payload handlers are smoke-tested + without importing ComfyUI, and the reversible index switch keeps pick/input + and route/output behavior stable. ## Editing Cheatsheet diff --git a/server_routes.py b/server_routes.py new file mode 100644 index 0000000..5c3f928 --- /dev/null +++ b/server_routes.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any + +try: + from .loop_nodes import ( + accumulator_delete_entries, + accumulator_list_entries, + accumulator_move_entry, + accumulator_save_entries, + ) + from .prompt_builder import save_character_profile_payload +except ImportError: # Allows local smoke tests from the repository root. + from loop_nodes import ( + accumulator_delete_entries, + accumulator_list_entries, + accumulator_move_entry, + accumulator_save_entries, + ) + from prompt_builder import save_character_profile_payload + + +def _payload(payload: Any) -> dict[str, Any]: + return payload if isinstance(payload, dict) else {} + + +def profile_save_cached_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return save_character_profile_payload( + profile_name=str(data.get("profile_name") or ""), + profile_json=data.get("profile_json") or "", + ) + + +def accumulator_list_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return accumulator_list_entries( + str(data.get("store_key") or ""), + preview_limit=int(data.get("preview_limit") or 0), + ) + + +def accumulator_delete_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return accumulator_delete_entries( + store_key=str(data.get("store_key") or ""), + preview_key=str(data.get("preview_key") or ""), + entry_id=str(data.get("entry_id") or ""), + index=int(data.get("index") or 0), + clear=bool(data.get("clear")), + preview_limit=int(data.get("preview_limit") or 0), + ) + + +def accumulator_save_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return accumulator_save_entries( + store_key=str(data.get("store_key") or ""), + save_path=str(data.get("save_path") or "sxcp_accumulator"), + filename_prefix=str(data.get("filename_prefix") or "sxcp_accum"), + clear_after_save=bool(data.get("clear_after_save")), + preview_limit=int(data.get("preview_limit") or 0), + ) + + +def accumulator_move_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return accumulator_move_entry( + store_key=str(data.get("store_key") or ""), + preview_key=str(data.get("preview_key") or ""), + entry_id=str(data.get("entry_id") or ""), + index=int(data.get("index") or 0), + direction=str(data.get("direction") or "up"), + target_index=int(data.get("target_index") or 0), + preview_limit=int(data.get("preview_limit") or 0), + ) diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 3acffd1..762e29b 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -14,6 +14,7 @@ import json import random import re import sys +import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable @@ -38,9 +39,11 @@ import generation_profile_config # noqa: E402 import krea_cast # noqa: E402 import krea_formatter # noqa: E402 import location_config # noqa: E402 +import loop_nodes # noqa: E402 import prompt_builder as pb # noqa: E402 import row_normalization # noqa: E402 import route_metadata # noqa: E402 +import server_routes # noqa: E402 import sdxl_formatter # noqa: E402 import sdxl_presets # noqa: E402 import seed_config # noqa: E402 @@ -2643,6 +2646,74 @@ def smoke_node_utility_registration() -> None: _expect(krea_config.get("width") == krea_width and krea_config.get("height") == krea_height, "Krea2 config_json dimensions mismatch") +def smoke_server_route_payload_policy() -> None: + switch = loop_nodes.SxCPIndexSwitch() + picked = switch.switch( + 2, + "pick_input", + "one_based", + "fallback", + input_1="first", + input_2="second", + fallback="fallback", + ) + _expect(picked[0] == "second", "Index Switch pick_input did not select the requested input") + _expect(picked[1] == 2, "Index Switch pick_input selected_index changed") + _expect("selected=input_2" in picked[2], "Index Switch pick_input status lost selected input") + + routed = switch.switch(3, "route_output", "one_based", "fallback", route_value="routed") + _expect(routed[0] == "routed", "Index Switch route_output primary value changed") + _expect(routed[1] == 3, "Index Switch route_output selected_index changed") + _expect(routed[5] == "routed", "Index Switch route_output did not route value to output_3") + + key = "smoke_route_payload" + loop_nodes._ACCUMULATOR_STORES[key] = [ + {"id": "first", "value": "alpha", "_sxcp_preview_key": "first-key"}, + {"id": "second", "value": "beta", "_sxcp_preview_key": "second-key"}, + ] + try: + listed = server_routes.accumulator_list_payload({"store_key": key, "preview_limit": "0"}) + _expect(listed.get("count") == 2, "Accumulator list payload lost stored entries") + _expect(listed["entries"][0].get("value") == "alpha", "Accumulator list payload lost value summary") + + moved = server_routes.accumulator_move_payload({"store_key": key, "entry_id": "second", "target_index": "1"}) + _expect(moved.get("moved") is True, "Accumulator move payload did not report movement") + _expect(moved.get("from_index") == 2 and moved.get("to_index") == 1, "Accumulator move payload changed indices") + _expect(moved["entries"][0].get("id") == "second", "Accumulator move payload did not reorder entries") + + deleted = server_routes.accumulator_delete_payload({"store_key": key, "preview_key": "first-key"}) + _expect(deleted.get("removed") == 1, "Accumulator delete payload did not remove by preview key") + _expect(deleted.get("count") == 1, "Accumulator delete payload count changed") + + cleared = server_routes.accumulator_delete_payload({"store_key": key, "clear": True}) + _expect(cleared.get("removed") == 1 and cleared.get("count") == 0, "Accumulator clear payload changed") + finally: + loop_nodes._ACCUMULATOR_STORES.pop(key, None) + + with tempfile.TemporaryDirectory() as tmpdir: + previous_profile_dir = character_profile.PROFILE_DIR + character_profile.PROFILE_DIR = Path(tmpdir) + try: + profile = character_profile.build_character_profile_json( + profile_name="route source", + source="manual", + subject_type="woman", + age="28-year-old adult", + body="slim", + hair="long black hair", + save_now=False, + ) + saved = server_routes.profile_save_cached_payload( + {"profile_name": "Route Save!*", "profile_json": profile["profile_json"]} + ) + saved_path = Path(saved.get("saved_path") or "") + _expect(saved.get("status") == "saved", "Profile save payload did not save") + _expect(saved.get("profile_name") == "Route_Save", "Profile save payload did not sanitize requested name") + _expect(saved_path.exists(), "Profile save payload did not write profile file") + finally: + character_profile.PROFILE_DIR = previous_profile_dir + + def smoke_seed_config_policy() -> None: _expect(pb.SEED_AXIS_SALTS is seed_config.SEED_AXIS_SALTS, "prompt_builder seed salts should delegate to seed_config") _expect(pb.seed_mode_choices() == seed_config.seed_mode_choices(), "seed mode choices drifted from seed_config") @@ -3326,6 +3397,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("expression_disabled", smoke_no_expression_fallback), ("formatter_metadata_fixtures", smoke_formatter_metadata_fixtures), ("node_utility_registration", smoke_node_utility_registration), + ("server_route_payload_policy", smoke_server_route_payload_policy), ("seed_config_policy", smoke_seed_config_policy), ("node_camera_registration", smoke_node_camera_registration), ("node_route_config_registration", smoke_node_route_config_registration),