#!/usr/bin/env python3
"""Print a lightweight audit for the prompt routing map.

This intentionally avoids importing the ComfyUI node package. It parses Python
and JSON files directly, so it can run in a plain shell without ComfyUI loaded.
"""

from __future__ import annotations

import ast
import json
import re
import sys
from pathlib import Path
from typing import Any

# Directory one level above the folder that contains this script. It is put on
# sys.path so the sibling-module import below resolves when run as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import category_template_metadata as template_metadata_policy  # noqa: E402

# Top-level JSON keys under which named pools are defined.
POOL_DEFINITION_KEYS = ("scene_pools", "expression_pools", "composition_pools")

# JSON keys treated as pool references, mapped to the definition key they must
# resolve against.
POOL_REFERENCE_KEYS = {
    "scene_pool": "scene_pools",
    "scene_pools": "scene_pools",
    "expression_pool": "expression_pools",
    "expression_pools": "expression_pools",
    "composition_pool": "composition_pools",
    "composition_pools": "composition_pools",
}

# Matches ``{placeholder}`` tokens inside item templates.
TEMPLATE_TOKEN_RE = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)}")

# (module file name, text that must appear in tools/prompt_smoke.py) pairs.
CRITICAL_ROUTE_MODULES: tuple[tuple[str, str], ...] = (
    ("builder_prompt_route.py", "builder_prompt_route_policy"),
    ("builder_config_route.py", "builder_config_route_policy"),
    ("krea_format_route.py", "krea_format_route_policy"),
    ("sdxl_format_route.py", "sdxl_format_route_policy"),
    ("caption_format_route.py", "caption_format_route_policy"),
    ("pair_builder.py", "pair_builder_policy"),
    ("row_assembly.py", "row_assembly_policy"),
    ("row_category_route.py", "row_category_route_policy"),
    ("row_prompt_axes.py", "row_prompt_axes_policy"),
    ("row_route_metadata.py", "row_route_metadata_policy"),
    ("row_subject_route.py", "row_subject_route_policy"),
    ("caption_metadata_routes.py", "caption_metadata_routes"),
    ("sdxl_tag_routes.py", "sdxl_tag_routes"),
)

# Exact snippets that must appear in docs/prompt-pool-routing-map.md.
ENTRY_ROUTE_SNIPPETS: tuple[str, ...] = (
    "`build_prompt` -> `builder_prompt_route.py`",
    "`build_prompt_from_configs` -> `builder_config_route.py`",
    "`format_krea2_prompt` -> `krea_format_route.py`",
    "`format_sdxl_prompt` -> `sdxl_format_route.py`",
    "`naturalize_caption` -> `caption_format_route.py`",
)


def _literal_or_none(node: ast.AST) -> Any:
    """Return the Python literal encoded by *node*, or ``None`` if it is not one."""
    try:
        return ast.literal_eval(node)
    except Exception:
        # Deliberate best effort: anything that is not a plain literal is
        # reported to the caller as "no value" rather than aborting the audit.
        return None


def _assigned_value(node: ast.stmt, name: str) -> ast.expr | None:
    """Return the value expression if *node* assigns to the bare name *name*.

    Handles both plain (``X = ...``) and annotated (``X: T = ...``)
    assignments. Returns ``None`` for any other statement, for assignments to
    other names, and for bare annotations that carry no value.
    """
    if isinstance(node, ast.Assign):
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == name:
                return node.value
        return None
    if isinstance(node, ast.AnnAssign):
        target = node.target
        if isinstance(target, ast.Name) and target.id == name:
            return node.value
    return None


def _assignment_dict(path: Path, name: str) -> dict[str, Any]:
    """Return the dict literal assigned to module-level *name* in *path*.

    Only the first matching top-level assignment is considered. If its value
    is not a literal dict, or no assignment matches, an empty dict is returned.
    The file is parsed with ``ast`` and never imported.
    """
    tree = ast.parse(path.read_text(encoding="utf-8"))
    for statement in tree.body:
        value_node = _assigned_value(statement, name)
        if value_node is None:
            continue
        value = _literal_or_none(value_node)
        return value if isinstance(value, dict) else {}
    return {}


def _class_return_names(path: Path) -> dict[str, tuple[str, ...]]:
    """Map ``SxCP*`` classes in *path* to their effective ``RETURN_NAMES``.

    A class without its own literal ``RETURN_NAMES`` tuple of strings inherits
    the first non-empty one found by walking its bases depth-first. Only bases
    written as bare names and defined in the same file are followed. Classes
    that resolve to no names are omitted from the result.
    """
    tree = ast.parse(path.read_text(encoding="utf-8"))

    # class name -> (bare-name bases, RETURN_NAMES declared directly on the class)
    classes: dict[str, tuple[list[str], tuple[str, ...]]] = {}
    for node in tree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
        return_names: tuple[str, ...] = ()
        for item in node.body:
            value_node = _assigned_value(item, "RETURN_NAMES")
            if value_node is None:
                continue
            value = _literal_or_none(value_node)
            if isinstance(value, tuple) and all(isinstance(part, str) for part in value):
                # No break: a later valid assignment in the class body wins.
                return_names = value
        classes[node.name] = (bases, return_names)

    def resolve(class_name: str, seen: set[str] | None = None) -> tuple[str, ...]:
        """Return own or inherited names for *class_name*; ``seen`` breaks cycles."""
        if seen is None:
            seen = set()
        if class_name in seen:
            return ()
        seen.add(class_name)
        bases, own_names = classes.get(class_name, ([], ()))
        if own_names:
            return own_names
        for base_name in bases:
            inherited = resolve(base_name, seen)
            if inherited:
                return inherited
        return ()

    result: dict[str, tuple[str, ...]] = {}
    for class_name in classes:
        if not class_name.startswith("SxCP"):
            continue
        resolved = resolve(class_name)
        if resolved:
            result[class_name] = resolved
    return result


def _coerce(value: Any, kind: type) -> Any:
    """Return *value* if it is an instance of *kind*, otherwise an empty *kind*."""
    return value if isinstance(value, kind) else kind()


def _category_summary(path: Path) -> dict[str, Any]:
    """Return count statistics for one category JSON file.

    Counts categories, subcategories, item templates (on subcategories and on
    their dict items) and the entries of each pool table. Values with an
    unexpected JSON type are counted as empty instead of raising, so a
    malformed file still yields a summary row.
    """
    data = _coerce(json.loads(path.read_text(encoding="utf-8")), dict)
    categories = _coerce(data.get("categories"), list)

    subcategory_count = 0
    item_template_count = 0
    for category in categories:
        if not isinstance(category, dict):
            continue
        subcategories = _coerce(category.get("subcategories"), list)
        subcategory_count += len(subcategories)
        for subcategory in subcategories:
            if not isinstance(subcategory, dict):
                continue
            item_template_count += len(_coerce(subcategory.get("item_templates"), list))
            for item in _coerce(subcategory.get("items"), list):
                if isinstance(item, dict):
                    item_template_count += len(_coerce(item.get("item_templates"), list))

    return {
        "categories": len(categories),
        "subcategories": subcategory_count,
        "item_templates": item_template_count,
        "scene_pools": len(_coerce(data.get("scene_pools"), dict)),
        "expression_pools": len(_coerce(data.get("expression_pools"), dict)),
        "composition_pools": len(_coerce(data.get("composition_pools"), dict)),
        "pool_extensions": len(_coerce(data.get("pool_extensions"), dict)),
    }


def _pool_names(path: Path, key: str) -> list[str]:
    """Return the sorted pool names stored under top-level *key* in *path*.

    Returns an empty list when the file's top level is not an object or the
    value under *key* is not an object.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    pools = data.get(key) if isinstance(data, dict) else None
    return sorted(pools) if isinstance(pools, dict) else []
def _category_json_paths() -> list[Path]:
    """Return every ``*.json`` file directly under ``ROOT/categories``, sorted."""
    return sorted((ROOT / "categories").glob("*.json"))


def _node_python_paths() -> list[Path]:
    """Return the existing node source files to scan.

    Candidates are ``__init__.py``, ``loop_nodes.py`` and every ``node_*.py``
    directly under ``ROOT``; files that do not exist are dropped.
    """
    candidates = [
        ROOT / "__init__.py",
        ROOT / "loop_nodes.py",
        *sorted(ROOT.glob("node_*.py")),
    ]
    return [candidate for candidate in candidates if candidate.exists()]


def _load_category_json(path: Path) -> dict[str, Any]:
    """Parse *path* as JSON and return it, or ``{}`` if the top level is not an object."""
    data = json.loads(path.read_text(encoding="utf-8"))
    return data if isinstance(data, dict) else {}


def _all_pool_names(paths: list[Path]) -> dict[str, set[str]]:
    """Collect the pool names defined across *paths*, grouped by definition key.

    Names that are blank after stripping are ignored.
    """
    names: dict[str, set[str]] = {key: set() for key in POOL_DEFINITION_KEYS}
    for path in paths:
        data = _load_category_json(path)
        for key in POOL_DEFINITION_KEYS:
            pools = data.get(key)
            if isinstance(pools, dict):
                names[key].update(str(name) for name in pools if str(name).strip())
    return names


def _pool_reference_values(value: Any) -> list[str]:
    """Normalize a pool reference value to a list of non-blank names.

    A string yields a one-element list (or none if blank), a list yields its
    non-blank items as strings, and any other type yields an empty list.
    """
    if isinstance(value, str):
        return [value] if value.strip() else []
    if isinstance(value, list):
        return [str(item) for item in value if str(item).strip()]
    return []


def _display_name(value: Any) -> str:
    """Return the stripped ``name`` (or else ``slug``) of a dict node, else ``""``."""
    if not isinstance(value, dict):
        return ""
    return str(value.get("name") or value.get("slug") or "").strip()


def _path_child(path: str, key: str, value: Any) -> str:
    """Return the diagnostic path of dict entry *key* under *path*.

    The key is rendered as ``key(name)`` when *value* is a dict with a
    non-blank ``name`` or ``slug``.
    """
    name = _display_name(value)
    label = f"{key}({name})" if name else key
    return f"{path}.{label}" if path else label


def _path_index(path: str, index: int, value: Any) -> str:
    """Return the diagnostic path of list element *index* under *path*.

    The index is rendered as ``[index:name]`` when *value* is a dict with a
    non-blank ``name`` or ``slug``.
    """
    name = _display_name(value)
    label = f"[{index}:{name}]" if name else f"[{index}]"
    return f"{path}{label}"


def _template_axis_errors(path: str, node: dict[str, Any]) -> list[tuple[str, str]]:
    """Validate ``item_templates`` against ``item_axes`` on one JSON object.

    Returns ``(path, issue)`` pairs. Nothing is checked (and nothing returned)
    when ``item_templates`` is not a list. For each template the text is taken
    from the string itself or, for objects, from the first non-empty of
    ``template``/``prompt``/``text``/``description``/``name``; object templates
    are additionally checked through ``template_metadata_policy``. Every
    ``{placeholder}`` must name an axis, and every axis must be a non-empty
    list.
    """
    templates = node.get("item_templates")
    if not isinstance(templates, list):
        return []

    axes = node.get("item_axes")
    axis_names = set(axes) if isinstance(axes, dict) else set()
    errors: list[tuple[str, str]] = []

    for index, template in enumerate(templates):
        template_path = f"{path}.item_templates[{index}]"
        if isinstance(template, dict):
            template_text = str(
                template.get("template")
                or template.get("prompt")
                or template.get("text")
                or template.get("description")
                or template.get("name")
                or ""
            ).strip()
            metadata = template_metadata_policy.template_metadata(template)
            errors.extend(
                (template_path, issue)
                for issue in template_metadata_policy.template_metadata_errors(metadata)
            )
        elif isinstance(template, str):
            template_text = template
        else:
            template_text = ""

        if not template_text:
            errors.append((template_path, "template must be a string or object with template/text"))
            continue

        tokens = set(TEMPLATE_TOKEN_RE.findall(template_text))
        missing = sorted(token for token in tokens if token not in axis_names)
        if missing:
            errors.append(
                (
                    template_path,
                    "missing item_axes for placeholders: " + ", ".join(missing),
                )
            )

    if isinstance(axes, dict):
        for axis_name, values in axes.items():
            if not isinstance(values, list) or not values:
                errors.append((f"{path}.item_axes.{axis_name}", "axis must be a non-empty list"))
    return errors


def _walk_json_references(
    value: Any,
    *,
    file_name: str,
    path: str,
    defined_pools: dict[str, set[str]],
    errors: list[tuple[str, str, str]],
    at_root: bool = False,
) -> None:
    """Recursively validate pool references and template axes under *value*.

    Findings are appended to *errors* as ``(file_name, path, issue)`` tuples.

    Args:
        value: JSON value to inspect; only dicts and lists are descended into.
        file_name: File label recorded on every finding.
        path: Diagnostic path of *value* inside the file.
        defined_pools: Known pool names per definition key.
        errors: Output list, mutated in place.
        at_root: True only for the file's top-level object, where the pool
            definition keys hold pool tables rather than references.
    """
    if isinstance(value, dict):
        errors.extend(
            (file_name, item_path, issue)
            for item_path, issue in _template_axis_errors(path, value)
        )
        for key, child in value.items():
            if at_root and key in POOL_DEFINITION_KEYS and isinstance(child, dict):
                for pool_name, pool_values in child.items():
                    if not isinstance(pool_values, list) or not pool_values:
                        errors.append((file_name, f"{key}.{pool_name}", "pool must be a non-empty list"))
                # A root pool table is a definition, not a reference: do not
                # run the reference check on it or descend into it.
                continue

            child_path = _path_child(path, key, child)
            pool_type = POOL_REFERENCE_KEYS.get(key)
            if pool_type:
                refs = _pool_reference_values(child)
                if child and not refs:
                    errors.append((file_name, child_path, "pool reference must be a string or list"))
                for ref in refs:
                    if ref not in defined_pools[pool_type]:
                        errors.append(
                            (
                                file_name,
                                child_path,
                                f"unknown {pool_type[:-1]} reference: {ref}",
                            )
                        )
            _walk_json_references(
                child,
                file_name=file_name,
                path=child_path,
                defined_pools=defined_pools,
                errors=errors,
            )
    elif isinstance(value, list):
        for index, child in enumerate(value):
            _walk_json_references(
                child,
                file_name=file_name,
                path=_path_index(path, index, child),
                defined_pools=defined_pools,
                errors=errors,
            )


def _json_reference_errors(paths: list[Path]) -> list[tuple[str, str, str]]:
    """Validate pool references and template axes across all category files.

    Pool names are gathered from every file first, so a reference may resolve
    to a pool defined in a different file. A pool type with no definitions at
    all is reported once under the ``(all)`` file label.
    """
    defined_pools = _all_pool_names(paths)
    errors: list[tuple[str, str, str]] = [
        ("(all)", pool_type, "no pools defined")
        for pool_type, names in defined_pools.items()
        if not names
    ]
    for path in paths:
        _walk_json_references(
            _load_category_json(path),
            file_name=path.name,
            path="$",
            defined_pools=defined_pools,
            errors=errors,
            at_root=True,
        )
    return errors


def _routing_doc_errors() -> list[tuple[str, str, str]]:
    """Check that critical route modules exist, are documented and smoke-tested.

    Returns ``(subject, location, issue)`` tuples. For each entry of
    ``CRITICAL_ROUTE_MODULES`` the module file must exist under ``ROOT``, its
    file name must appear in both routing documents, and its smoke-case text
    must appear in ``tools/prompt_smoke.py``. Every ``ENTRY_ROUTE_SNIPPETS``
    entry must appear in the routing map. A missing document or smoke file is
    treated as empty text, so each dependent check reports an error instead of
    the function raising.
    """
    route_map_name = "docs/prompt-pool-routing-map.md"
    docs = {
        route_map_name: ROOT / "docs" / "prompt-pool-routing-map.md",
        "docs/prompt-architecture-improvement-plan.md": ROOT / "docs" / "prompt-architecture-improvement-plan.md",
    }
    # Read each document once up front instead of once per critical module.
    doc_texts = {
        doc_name: doc_path.read_text(encoding="utf-8") if doc_path.exists() else ""
        for doc_name, doc_path in docs.items()
    }
    smoke_path = ROOT / "tools" / "prompt_smoke.py"
    smoke_text = smoke_path.read_text(encoding="utf-8") if smoke_path.exists() else ""

    errors: list[tuple[str, str, str]] = []
    for module_name, smoke_case in CRITICAL_ROUTE_MODULES:
        if not (ROOT / module_name).exists():
            errors.append((module_name, "module", "critical route module is missing"))
        for doc_name, doc_text in doc_texts.items():
            if module_name not in doc_text:
                errors.append((module_name, doc_name, "critical route module is not documented"))
        if smoke_case and smoke_case not in smoke_text:
            errors.append((module_name, "tools/prompt_smoke.py", f"missing smoke case: {smoke_case}"))

    route_map_text = doc_texts[route_map_name]
    for snippet in ENTRY_ROUTE_SNIPPETS:
        if snippet not in route_map_text:
            errors.append(("(entry route)", route_map_name, f"missing entry snippet: {snippet}"))
    return errors


def print_table(headers: tuple[str, ...], rows: list[tuple[Any, ...]]) -> None:
    """Print *rows* under *headers* as a pipe-delimited, column-aligned table.

    Cells are converted with ``str`` and left-justified to the widest cell of
    their column. With no rows, only the header and separator lines are
    printed.
    """
    widths = [len(header) for header in headers]
    for row in rows:
        for index, value in enumerate(row):
            widths[index] = max(widths[index], len(str(value)))

    def render(cells: Any) -> str:
        """Join already-padded cell strings into one table line."""
        return "| " + " | ".join(cells) + " |"

    print(render(header.ljust(widths[index]) for index, header in enumerate(headers)))
    print(render("-" * width for width in widths))
    for row in rows:
        print(render(str(value).ljust(widths[index]) for index, value in enumerate(row)))


def main() -> int:
    """Print the audit report and return the process exit code.

    Prints the node display map, the per-file category summary and the named
    pool inventory, then runs JSON reference validation followed by routing
    documentation validation. Returns 1 as soon as one validation reports
    errors (later validations are then not run), otherwise 0.
    """
    category_paths = _category_json_paths()

    display: dict[str, Any] = {}
    returns: dict[str, tuple[str, ...]] = {}
    for path in _node_python_paths():
        display.update(_assignment_dict(path, "NODE_DISPLAY_NAME_MAPPINGS"))
        display.update(_assignment_dict(path, "LOOP_NODE_DISPLAY_NAME_MAPPINGS"))
        returns.update(_class_return_names(path))

    print("# Node Display Map")
    node_rows = []
    for class_name, display_name in sorted(display.items(), key=lambda item: str(item[1])):
        return_names = ", ".join(returns.get(class_name, ()))
        node_rows.append((display_name, class_name, return_names or "(dynamic or unnamed)"))
    print_table(("Display name", "Class", "Return names"), node_rows)

    print("\n# Category JSON Summary")
    category_rows = []
    for path in category_paths:
        summary = _category_summary(path)
        category_rows.append(
            (
                path.name,
                summary["categories"],
                summary["subcategories"],
                summary["item_templates"],
                summary["scene_pools"],
                summary["expression_pools"],
                summary["composition_pools"],
                summary["pool_extensions"],
            )
        )
    print_table(
        (
            "File",
            "Categories",
            "Subcategories",
            "Item templates",
            "Scene pools",
            "Expression pools",
            "Composition pools",
            "Extensions",
        ),
        category_rows,
    )

    print("\n# Named Pool Inventory")
    pool_rows = []
    for path in category_paths:
        for key in POOL_DEFINITION_KEYS:
            names = _pool_names(path, key)
            if not names:
                continue
            # Show at most eight names, marking truncation with an ellipsis.
            preview = ", ".join(names[:8]) + (" ..." if len(names) > 8 else "")
            pool_rows.append((path.name, key, len(names), preview))
    print_table(("File", "Pool type", "Count", "First names"), pool_rows)

    print("\n# JSON Reference Validation")
    reference_errors = _json_reference_errors(category_paths)
    if reference_errors:
        print_table(("File", "Path", "Issue"), reference_errors)
        return 1
    print("OK: all JSON pool references and item template axes resolve.")

    print("\n# Routing Documentation Validation")
    routing_doc_errors = _routing_doc_errors()
    if routing_doc_errors:
        print_table(("Module", "Location", "Issue"), routing_doc_errors)
        return 1
    print("OK: critical route modules are documented and covered by smoke cases.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())