Files
ComfyUI-Ethanfel-Prompt-Bui…/tools/prompt_map_audit.py
T

452 lines
17 KiB
Python

#!/usr/bin/env python3
"""Print a lightweight audit for the prompt routing map.
This intentionally avoids importing the ComfyUI node package. It parses Python
and JSON files directly, so it can run in a plain shell without ComfyUI loaded.
"""
from __future__ import annotations
import ast
import json
import re
import sys
from pathlib import Path
from typing import Any
# Repository root: the directory above this script's ``tools/`` folder. It is
# put at the front of ``sys.path`` so sibling top-level modules (such as the
# template metadata policy imported next) load without the ComfyUI package.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import category_template_metadata as template_metadata_policy # noqa: E402
# Top-level JSON sections that define named pools; each maps a pool name to a
# list of values.
POOL_DEFINITION_KEYS: tuple[str, ...] = ("scene_pools", "expression_pools", "composition_pools")

# Keys that reference pools anywhere in a category file, mapped to the
# definition section whose names they must resolve against. Both the singular
# and the plural spelling of each section name are accepted.
POOL_REFERENCE_KEYS: dict[str, str] = {
    alias: pool_key
    for pool_key in POOL_DEFINITION_KEYS
    for alias in (pool_key[:-1], pool_key)
}

# ``{identifier}`` placeholders inside item templates; group 1 captures the
# identifier, which must have a matching ``item_axes`` entry.
TEMPLATE_TOKEN_RE = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")

# (module file under ROOT, smoke case name expected in tools/prompt_smoke.py).
# Each module must exist and be mentioned in both routing documents; an empty
# smoke case name would skip only the smoke-coverage check.
CRITICAL_ROUTE_MODULES: tuple[tuple[str, str], ...] = (
    ("builder_prompt_route.py", "builder_prompt_route_policy"),
    ("builder_config_route.py", "builder_config_route_policy"),
    ("krea_format_route.py", "krea_format_route_policy"),
    ("sdxl_format_route.py", "sdxl_format_route_policy"),
    ("caption_format_route.py", "caption_format_route_policy"),
    ("pair_builder.py", "pair_builder_policy"),
    ("row_assembly.py", "row_assembly_policy"),
    ("row_category_route.py", "row_category_route_policy"),
    ("row_prompt_axes.py", "row_prompt_axes_policy"),
    ("row_route_metadata.py", "row_route_metadata_policy"),
    ("row_subject_route.py", "row_subject_route_policy"),
    ("caption_metadata_routes.py", "caption_metadata_routes"),
    ("sdxl_tag_routes.py", "sdxl_tag_routes"),
)

# Lines that must appear verbatim in docs/prompt-pool-routing-map.md, linking
# each public entry point to the route module that handles it.
ENTRY_ROUTE_SNIPPETS: tuple[str, ...] = (
    "`build_prompt` -> `builder_prompt_route.py`",
    "`build_prompt_from_configs` -> `builder_config_route.py`",
    "`format_krea2_prompt` -> `krea_format_route.py`",
    "`format_sdxl_prompt` -> `sdxl_format_route.py`",
    "`naturalize_caption` -> `caption_format_route.py`",
)
def _literal_or_none(node: ast.AST) -> Any:
try:
return ast.literal_eval(node)
except Exception:
return None
def _assignment_dict(path: Path, name: str) -> dict[str, Any]:
    """Return the literal dict bound to the module-level variable *name* in *path*.

    The file is parsed, never imported. Both plain (``NAME = {...}``) and
    annotated (``NAME: dict[str, str] = {...}``) top-level assignments are
    recognised; only the first matching one is considered.

    Returns:
        The dict value, or ``{}`` when *name* is never assigned at module level,
        when the assigned value is not a statically evaluable literal (e.g. it is
        built from other names or a function call), or when the literal is not
        a dict.
    """
    tree = ast.parse(path.read_text(encoding="utf-8"))
    for node in tree.body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # Annotated assignment; a bare annotation without a value is skipped.
            targets = [node.target]
        else:
            continue
        if not any(isinstance(target, ast.Name) and target.id == name for target in targets):
            continue
        value = _literal_or_none(node.value)
        return value if isinstance(value, dict) else {}
    return {}
def _class_return_names(path: Path, prefix: str = "SxCP") -> dict[str, tuple[str, ...]]:
    """Map node classes in *path* to their statically declared ``RETURN_NAMES``.

    Only top-level classes are inspected, and the file is parsed rather than
    imported. A class without a literal ``RETURN_NAMES`` of its own inherits the
    first non-empty value found by walking its bases depth-first. Only bases
    written as plain names and defined in the same file can be followed.

    Args:
        path: Python source file to parse.
        prefix: Only classes whose name starts with this prefix are reported.

    Returns:
        ``{class_name: return_names}`` for each matching class that resolves to
        a non-empty tuple of strings.
    """
    tree = ast.parse(path.read_text(encoding="utf-8"))
    classes: dict[str, tuple[list[str], tuple[str, ...]]] = {}
    for node in tree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
        return_names: tuple[str, ...] = ()
        for item in node.body:
            if isinstance(item, ast.Assign):
                targets = item.targets
            elif isinstance(item, ast.AnnAssign) and item.value is not None:
                targets = [item.target]
            else:
                continue
            if not any(isinstance(target, ast.Name) and target.id == "RETURN_NAMES" for target in targets):
                continue
            value = _literal_or_none(item.value)
            # The last literal sequence of strings in the class body wins.
            if isinstance(value, (tuple, list)) and all(isinstance(part, str) for part in value):
                return_names = tuple(value)
        classes[node.name] = (bases, return_names)

    def resolve(class_name: str, seen: set[str] | None = None) -> tuple[str, ...]:
        # ``seen`` stops infinite recursion on cyclic or self-referencing base names.
        if seen is None:
            seen = set()
        if class_name in seen:
            return ()
        seen.add(class_name)
        bases, own_names = classes.get(class_name, ([], ()))
        if own_names:
            return own_names
        for base_name in bases:
            inherited = resolve(base_name, seen)
            if inherited:
                return inherited
        return ()

    result: dict[str, tuple[str, ...]] = {}
    for class_name in classes:
        if not class_name.startswith(prefix):
            continue
        resolved = resolve(class_name)
        if resolved:
            result[class_name] = resolved
    return result
def _category_summary(path: Path) -> dict[str, Any]:
data = json.loads(path.read_text(encoding="utf-8"))
categories = data.get("categories") or []
subcategory_count = 0
item_template_count = 0
for category in categories:
subcategories = category.get("subcategories") or []
subcategory_count += len(subcategories)
for subcategory in subcategories:
item_template_count += len(subcategory.get("item_templates") or [])
for item in subcategory.get("items") or []:
if isinstance(item, dict):
item_template_count += len(item.get("item_templates") or [])
return {
"categories": len(categories),
"subcategories": subcategory_count,
"item_templates": item_template_count,
"scene_pools": len(data.get("scene_pools") or {}),
"expression_pools": len(data.get("expression_pools") or {}),
"composition_pools": len(data.get("composition_pools") or {}),
"pool_extensions": len(data.get("pool_extensions") or {}),
}
def _pool_names(path: Path, key: str) -> list[str]:
data = json.loads(path.read_text(encoding="utf-8"))
pools = data.get(key) or {}
return sorted(pools) if isinstance(pools, dict) else []
def _category_json_paths() -> list[Path]:
    """List every ``*.json`` file directly inside ``ROOT/categories``, sorted by path."""
    category_dir = ROOT / "categories"
    return sorted(category_dir.glob("*.json"))
def _node_python_paths() -> list[Path]:
    """Return the existing Python files scanned for node classes and display maps.

    ``__init__.py`` and ``loop_nodes.py`` come first, followed by every
    ``node_*.py`` module under ROOT in sorted order. Paths that do not exist are
    dropped.
    """
    fixed_entries = (ROOT / "__init__.py", ROOT / "loop_nodes.py")
    candidates = [*fixed_entries, *sorted(ROOT.glob("node_*.py"))]
    return [candidate for candidate in candidates if candidate.exists()]
def _load_category_json(path: Path) -> dict[str, Any]:
data = json.loads(path.read_text(encoding="utf-8"))
return data if isinstance(data, dict) else {}
def _all_pool_names(paths: list[Path]) -> dict[str, set[str]]:
    """Collect every named pool defined across *paths*, grouped by pool section.

    The result has one entry per ``POOL_DEFINITION_KEYS`` element, holding the
    union of the pool names found under that section in all files. Sections that
    are not objects are skipped, and names that are blank after stripping are
    ignored (non-blank names are kept unstripped).
    """
    collected: dict[str, set[str]] = {pool_key: set() for pool_key in POOL_DEFINITION_KEYS}
    for json_path in paths:
        document = _load_category_json(json_path)
        for pool_key, bucket in collected.items():
            section = document.get(pool_key)
            if not isinstance(section, dict):
                continue
            for raw_name in section:
                label = str(raw_name)
                if label.strip():
                    bucket.add(label)
    return collected
def _pool_reference_values(value: Any) -> list[str]:
if isinstance(value, str):
return [value] if value.strip() else []
if isinstance(value, list):
return [str(item) for item in value if str(item).strip()]
return []
def _path_child(path: str, key: str, value: Any) -> str:
label = key
if isinstance(value, dict):
name = str(value.get("name") or value.get("slug") or "").strip()
if name:
label = f"{key}({name})"
return f"{path}.{label}" if path else label
def _path_index(path: str, index: int, value: Any) -> str:
label = f"[{index}]"
if isinstance(value, dict):
name = str(value.get("name") or value.get("slug") or "").strip()
if name:
label = f"[{index}:{name}]"
return f"{path}{label}"
def _template_axis_errors(path: str, node: dict[str, Any]) -> list[tuple[str, str]]:
    """Validate the ``item_templates`` / ``item_axes`` pair declared on one JSON object.

    Nothing is checked unless *node* has an ``item_templates`` list. Then:

    * each template must be a string, or an object whose first truthy
      ``template``/``prompt``/``text``/``description``/``name`` field supplies
      its text. The text must be non-blank after stripping. This applies to
      strings too, so a whitespace-only template is reported rather than
      silently accepted;
    * object templates are additionally checked with
      ``template_metadata_policy.template_metadata_errors``;
    * every ``{placeholder}`` in the text needs a key in ``item_axes``;
    * when ``item_axes`` is an object, each of its values must be a non-empty list.

    Args:
        path: JSON path of *node*, used as the prefix for reported locations.
        node: The JSON object to inspect.

    Returns:
        ``(location, issue)`` pairs in discovery order.
    """
    templates = node.get("item_templates")
    if not isinstance(templates, list):
        return []
    axes = node.get("item_axes")
    axis_names = set(axes) if isinstance(axes, dict) else set()
    errors: list[tuple[str, str]] = []
    for index, template in enumerate(templates):
        template_path = f"{path}.item_templates[{index}]"
        if isinstance(template, dict):
            template_text = str(
                template.get("template")
                or template.get("prompt")
                or template.get("text")
                or template.get("description")
                or template.get("name")
                or ""
            ).strip()
            metadata = template_metadata_policy.template_metadata(template)
            for issue in template_metadata_policy.template_metadata_errors(metadata):
                errors.append((template_path, issue))
        elif isinstance(template, str):
            # Strip like the object branch so blank string templates are reported too.
            template_text = template.strip()
        else:
            template_text = ""
        if not template_text:
            errors.append((template_path, "template must be a string or object with template/text"))
            continue
        tokens = set(TEMPLATE_TOKEN_RE.findall(template_text))
        missing = sorted(token for token in tokens if token not in axis_names)
        if missing:
            errors.append(
                (
                    template_path,
                    "missing item_axes for placeholders: " + ", ".join(missing),
                )
            )
    if isinstance(axes, dict):
        for axis_name, values in axes.items():
            if not isinstance(values, list) or not values:
                errors.append((f"{path}.item_axes.{axis_name}", "axis must be a non-empty list"))
    return errors
def _walk_json_references(
    value: Any,
    *,
    file_name: str,
    path: str,
    defined_pools: dict[str, set[str]],
    errors: list[tuple[str, str, str]],
    at_root: bool = False,
) -> None:
    """Recursively validate pool references and item templates in a JSON tree.

    Appends ``(file_name, json_path, issue)`` tuples to *errors* in place. At the
    document root (*at_root*), a pool definition section that is an object is
    checked to contain only non-empty lists and is not descended into. Every
    other key listed in ``POOL_REFERENCE_KEYS`` must name pools present in
    *defined_pools*. Every object visited is also run through
    ``_template_axis_errors``.
    """
    if isinstance(value, list):
        for position, element in enumerate(value):
            _walk_json_references(
                element,
                file_name=file_name,
                path=_path_index(path, position, element),
                defined_pools=defined_pools,
                errors=errors,
            )
        return
    if not isinstance(value, dict):
        return

    errors.extend(
        (file_name, issue_path, issue) for issue_path, issue in _template_axis_errors(path, value)
    )
    for key, child in value.items():
        if at_root and key in POOL_DEFINITION_KEYS and isinstance(child, dict):
            # Root-level pool definitions: every named pool must be a non-empty list.
            errors.extend(
                (file_name, f"{key}.{pool_name}", "pool must be a non-empty list")
                for pool_name, pool_values in child.items()
                if not isinstance(pool_values, list) or not pool_values
            )
            continue
        child_path = _path_child(path, key, child)
        pool_type = POOL_REFERENCE_KEYS.get(key)
        if pool_type:
            refs = _pool_reference_values(child)
            if child and not refs:
                errors.append((file_name, child_path, "pool reference must be a string or list"))
            errors.extend(
                (file_name, child_path, f"unknown {pool_type[:-1]} reference: {ref}")
                for ref in refs
                if ref not in defined_pools[pool_type]
            )
        _walk_json_references(
            child,
            file_name=file_name,
            path=child_path,
            defined_pools=defined_pools,
            errors=errors,
        )
def _json_reference_errors(paths: list[Path]) -> list[tuple[str, str, str]]:
    """Validate pool references and item templates across all category JSON files.

    Returns ``(file, json_path, issue)`` rows. Any pool section with no definitions
    in any file is reported first, once, under the pseudo-file ``"(all)"``. The
    per-file issues follow in *paths* order.
    """
    defined_pools = _all_pool_names(paths)
    errors: list[tuple[str, str, str]] = [
        ("(all)", pool_type, "no pools defined")
        for pool_type, names in defined_pools.items()
        if not names
    ]
    for json_path in paths:
        _walk_json_references(
            _load_category_json(json_path),
            file_name=json_path.name,
            path="$",
            defined_pools=defined_pools,
            errors=errors,
            at_root=True,
        )
    return errors
def _smoke_case_names(path: Path) -> set[str]:
if not path.exists():
return set()
tree = ast.parse(path.read_text(encoding="utf-8"))
for node in tree.body:
value: ast.AST | None = None
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == "SMOKE_CASES":
value = node.value
elif isinstance(node, ast.Assign) and any(
isinstance(target, ast.Name) and target.id == "SMOKE_CASES" for target in node.targets
):
value = node.value
if value is None:
continue
if not isinstance(value, (ast.List, ast.Tuple)):
return set()
names: set[str] = set()
for item in value.elts:
if not isinstance(item, (ast.List, ast.Tuple)) or not item.elts:
continue
first = item.elts[0]
if isinstance(first, ast.Constant) and isinstance(first.value, str):
names.add(first.value)
return names
return set()
def _routing_doc_errors() -> list[tuple[str, str, str]]:
    """Check that critical route modules exist, are documented, and have smoke cases.

    Each routing document is read once. A missing document is reported as its
    own row and then treated as empty, so the per-module and entry-snippet
    checks still run. Previously a missing route map stopped the audit with an
    unhandled ``FileNotFoundError``.

    Returns:
        ``(subject, location, issue)`` rows; empty when everything is covered.
    """
    route_map_name = "docs/prompt-pool-routing-map.md"
    docs = {
        route_map_name: ROOT / "docs" / "prompt-pool-routing-map.md",
        "docs/prompt-architecture-improvement-plan.md": ROOT / "docs" / "prompt-architecture-improvement-plan.md",
    }
    errors: list[tuple[str, str, str]] = []

    smoke_cases = _smoke_case_names(ROOT / "tools" / "prompt_smoke.py")
    if not smoke_cases:
        errors.append(("tools/prompt_smoke.py", "SMOKE_CASES", "no registered smoke cases found"))

    # Read every document once, instead of once per critical module.
    doc_texts: dict[str, str] = {}
    for doc_name, doc_path in docs.items():
        if doc_path.exists():
            doc_texts[doc_name] = doc_path.read_text(encoding="utf-8")
        else:
            doc_texts[doc_name] = ""
            errors.append(("(routing doc)", doc_name, "routing document is missing"))

    for module_name, smoke_case in CRITICAL_ROUTE_MODULES:
        if not (ROOT / module_name).exists():
            errors.append((module_name, "module", "critical route module is missing"))
        for doc_name, doc_text in doc_texts.items():
            if module_name not in doc_text:
                errors.append((module_name, doc_name, "critical route module is not documented"))
        if smoke_case and smoke_case not in smoke_cases:
            errors.append((module_name, "tools/prompt_smoke.py", f"missing registered smoke case: {smoke_case}"))

    route_map_text = doc_texts[route_map_name]
    for snippet in ENTRY_ROUTE_SNIPPETS:
        if snippet not in route_map_text:
            errors.append(("(entry route)", route_map_name, f"missing entry snippet: {snippet}"))
    return errors
def print_table(headers: tuple[str, ...], rows: list[tuple[Any, ...]]) -> None:
    """Print *rows* as a column-aligned Markdown table under *headers*.

    Cells are rendered with ``str()``. Literal pipe characters are escaped with a
    backslash so they cannot split a Markdown cell. Rows shorter than *headers*
    are padded with empty cells.

    Raises:
        ValueError: if a row has more cells than there are headers. The row is
            named in the message, rather than surfacing as a bare IndexError.
    """

    def cell(value: Any) -> str:
        return str(value).replace("|", "\\|")

    column_count = len(headers)
    header_cells = [cell(header) for header in headers]
    body: list[list[str]] = []
    for row in rows:
        if len(row) > column_count:
            raise ValueError(f"table row has {len(row)} cells but only {column_count} headers: {row!r}")
        cells = [cell(value) for value in row]
        cells.extend([""] * (column_count - len(cells)))
        body.append(cells)

    widths = [
        max([len(header_cells[index])] + [len(cells[index]) for cells in body])
        for index in range(column_count)
    ]

    def render(cells: list[str]) -> str:
        return "| " + " | ".join(text.ljust(width) for text, width in zip(cells, widths)) + " |"

    print(render(header_cells))
    print("| " + " | ".join("-" * width for width in widths) + " |")
    for cells in body:
        print(render(cells))
def main() -> int:
    """Print the prompt-map audit report and return a process exit code.

    The report has these sections, in order:

    * node display map;
    * per-file category JSON summary;
    * named pool inventory;
    * JSON reference validation;
    * routing documentation validation.

    Both validation sections always run, so one invocation reports every
    problem. Previously a JSON reference failure returned early and hid the
    routing-documentation results.

    Returns:
        ``0`` when both validations pass, ``1`` when either reports issues.
    """
    category_paths = _category_json_paths()

    # Display names and return names may be spread over several node modules.
    display: dict[str, Any] = {}
    returns: dict[str, tuple[str, ...]] = {}
    for path in _node_python_paths():
        display.update(_assignment_dict(path, "NODE_DISPLAY_NAME_MAPPINGS"))
        display.update(_assignment_dict(path, "LOOP_NODE_DISPLAY_NAME_MAPPINGS"))
        returns.update(_class_return_names(path))

    print("# Node Display Map")
    node_rows = []
    for class_name, display_name in sorted(display.items(), key=lambda item: str(item[1])):
        return_names = ", ".join(returns.get(class_name, ()))
        node_rows.append((display_name, class_name, return_names or "(dynamic or unnamed)"))
    print_table(("Display name", "Class", "Return names"), node_rows)

    print("\n# Category JSON Summary")
    category_rows = []
    for path in category_paths:
        summary = _category_summary(path)
        category_rows.append(
            (
                path.name,
                summary["categories"],
                summary["subcategories"],
                summary["item_templates"],
                summary["scene_pools"],
                summary["expression_pools"],
                summary["composition_pools"],
                summary["pool_extensions"],
            )
        )
    print_table(
        (
            "File",
            "Categories",
            "Subcategories",
            "Item templates",
            "Scene pools",
            "Expression pools",
            "Composition pools",
            "Extensions",
        ),
        category_rows,
    )

    print("\n# Named Pool Inventory")
    pool_rows = []
    for path in category_paths:
        # Use the shared definition-key constant rather than a duplicated literal.
        for key in POOL_DEFINITION_KEYS:
            names = _pool_names(path, key)
            if names:
                preview = ", ".join(names[:8]) + (" ..." if len(names) > 8 else "")
                pool_rows.append((path.name, key, len(names), preview))
    print_table(("File", "Pool type", "Count", "First names"), pool_rows)

    print("\n# JSON Reference Validation")
    reference_errors = _json_reference_errors(category_paths)
    if reference_errors:
        print_table(("File", "Path", "Issue"), reference_errors)
    else:
        print("OK: all JSON pool references and item template axes resolve.")

    print("\n# Routing Documentation Validation")
    routing_doc_errors = _routing_doc_errors()
    if routing_doc_errors:
        print_table(("Module", "Location", "Issue"), routing_doc_errors)
    else:
        print("OK: critical route modules are documented and covered by smoke cases.")

    return 1 if reference_errors or routing_doc_errors else 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())