From de1d23fb378a6bf76db3573b6e7425044b76b300 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 02:05:53 +0200 Subject: [PATCH] Extract item template metadata policy --- category_template_metadata.py | 88 ++++++++++++++++++++ docs/prompt-architecture-improvement-plan.md | 3 + docs/prompt-pool-routing-map.md | 1 + prompt_builder.py | 43 ++-------- tools/prompt_map_audit.py | 9 ++ tools/prompt_smoke.py | 18 ++++ 6 files changed, 128 insertions(+), 34 deletions(-) create mode 100644 category_template_metadata.py diff --git a/category_template_metadata.py b/category_template_metadata.py new file mode 100644 index 0000000..17cebd4 --- /dev/null +++ b/category_template_metadata.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import re +from typing import Any + +try: + from .hardcore_action_metadata import normalize_hardcore_action_family + from .hardcore_position_config import normalize_hardcore_position_family, normalize_hardcore_position_values +except ImportError: # Allows local smoke tests from the repository root. + from hardcore_action_metadata import normalize_hardcore_action_family + from hardcore_position_config import normalize_hardcore_position_family, normalize_hardcore_position_values + + +TEMPLATE_METADATA_KEYS = ( + "action_family", + "action_type", + "family", + "position_family", + "position_key", + "position_keys", + "formatter_hint", +) + + +def template_metadata(item: Any) -> dict[str, Any]: + if not isinstance(item, dict): + return {} + return {key: item[key] for key in TEMPLATE_METADATA_KEYS if key in item} + + +def template_position_family(metadata: dict[str, Any]) -> str: + return normalize_hardcore_position_family( + metadata.get("position_family") or metadata.get("family"), + "", + ) + + +def template_position_keys(metadata: dict[str, Any]) -> list[str]: + keys: list[Any] = [] + if metadata.get("position_keys") is not None: + raw_keys = metadata.get("position_keys") + keys.extend(raw_keys if isinstance(raw_keys, list) else [raw_keys]) + if metadata.get("position_key") is not None: + keys.append(metadata.get("position_key")) + return normalize_hardcore_position_values(keys) + + +def template_action_family(metadata: dict[str, Any]) -> str: + return normalize_hardcore_action_family(metadata.get("action_family") or metadata.get("action_type"), "") + + +def merge_position_keys(primary: list[str], fallback: list[str]) -> list[str]: + merged: list[str] = [] + for key in [*primary, *fallback]: + if key and key not in merged: + merged.append(key) + return merged + + +def _position_key_slug(value: Any) -> str: + return re.sub(r"[^a-z0-9]+", "_", str(value or "").strip().lower()).strip("_") + + +def template_metadata_errors(metadata: dict[str, Any]) -> list[str]: + errors: list[str] = [] + raw_action_family = metadata.get("action_family") or metadata.get("action_type") + if raw_action_family and not template_action_family(metadata): + errors.append(f"unknown action_family/action_type: {raw_action_family}") + raw_position_family = metadata.get("position_family") or metadata.get("family") + if raw_position_family and not template_position_family(metadata): + errors.append(f"unknown position_family/family: {raw_position_family}") + raw_position_keys = [] + if metadata.get("position_keys") is not None: + values = metadata.get("position_keys") + raw_position_keys.extend(values if isinstance(values, list) else [values]) + if metadata.get("position_key") is not None: + raw_position_keys.append(metadata.get("position_key")) + normalized_keys = template_position_keys(metadata) + invalid_keys = [ + str(value) + for value in raw_position_keys + if str(value or "").strip() + and str(value or "").strip() != "any" + and _position_key_slug(value) not in normalized_keys + ] + if invalid_keys: + errors.append("unknown position key(s): " + ", ".join(invalid_keys)) + return errors diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 524a2e0..d1f7dcb 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -124,6 +124,9 @@ Already isolated: - JSON category loading, subcategory normalization, named scene/expression/ composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging live in `category_library.py`. +- object-style item-template metadata extraction, action/position family + normalization, position-key normalization, and metadata audit errors live in + `category_template_metadata.py`. - category/cast route preset schemas, config JSON builders, choice lists, and parsers live in `category_cast_config.py`; `prompt_builder.py` keeps public delegate wrappers for existing nodes and tests. diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 378886d..9b7bfda 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -68,6 +68,7 @@ Core helper ownership: | Python module | What it owns | | --- | --- | | `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. | +| `category_template_metadata.py` | Object-style item-template metadata extraction, action/position family normalization, position-key normalization, key merging, and audit validation errors. | | `category_cast_config.py` | Category preset and cast preset schemas, category/cast config JSON builders, choice lists, and config parsers used by route nodes. | | `camera_config.py` | Camera option schema, direct/orbit/Qwen camera JSON builders, camera config parsing, plain camera directive text, and camera caption labels. | | `character_config.py` | Character choice lists, descriptor detail/presence/slot-seed normalization, characteristic-list JSON builders/parsers, eye labels, hair config builders/parsers, and hair phrase helpers. | diff --git a/prompt_builder.py b/prompt_builder.py index 6f223f3..8a62938 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -24,6 +24,7 @@ try: template_list as _template_list, ) from . import camera_config as camera_policy + from . import category_template_metadata as item_template_policy from . import character_config as character_policy from . import character_profile as character_profile_policy from . import category_cast_config as category_cast_policy @@ -45,7 +46,7 @@ try: sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values, sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors, ) - from .hardcore_action_metadata import normalize_hardcore_action_family, source_hardcore_action_family + from .hardcore_action_metadata import source_hardcore_action_family from .hardcore_role_graphs import build_hardcore_role_graph except ImportError: # Allows local smoke tests with `python -c`. from category_library import ( @@ -64,6 +65,7 @@ except ImportError: # Allows local smoke tests with `python -c`. template_list as _template_list, ) import camera_config as camera_policy + import category_template_metadata as item_template_policy import character_config as character_policy import character_profile as character_profile_policy import category_cast_config as category_cast_policy @@ -85,7 +87,7 @@ except ImportError: # Allows local smoke tests with `python -c`. sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values, sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors, ) - from hardcore_action_metadata import normalize_hardcore_action_family, source_hardcore_action_family + from hardcore_action_metadata import source_hardcore_action_family from hardcore_role_graphs import build_hardcore_role_graph @@ -301,50 +303,23 @@ def _item_name(item: Any) -> str: def _template_metadata(item: Any) -> dict[str, Any]: - if not isinstance(item, dict): - return {} - metadata: dict[str, Any] = {} - for key in ( - "action_family", - "action_type", - "family", - "position_family", - "position_key", - "position_keys", - "formatter_hint", - ): - if key in item: - metadata[key] = item[key] - return metadata + return item_template_policy.template_metadata(item) def _template_position_family(metadata: dict[str, Any]) -> str: - return _normalize_hardcore_position_family( - metadata.get("position_family") or metadata.get("family"), - "", - ) + return item_template_policy.template_position_family(metadata) def _template_position_keys(metadata: dict[str, Any]) -> list[str]: - keys: list[Any] = [] - if metadata.get("position_keys") is not None: - raw_keys = metadata.get("position_keys") - keys.extend(raw_keys if isinstance(raw_keys, list) else [raw_keys]) - if metadata.get("position_key") is not None: - keys.append(metadata.get("position_key")) - return _normalize_hardcore_position_values(keys) + return item_template_policy.template_position_keys(metadata) def _template_action_family(metadata: dict[str, Any]) -> str: - return normalize_hardcore_action_family(metadata.get("action_family") or metadata.get("action_type"), "") + return item_template_policy.template_action_family(metadata) def _merge_position_keys(primary: list[str], fallback: list[str]) -> list[str]: - merged: list[str] = [] - for key in [*primary, *fallback]: - if key and key not in merged: - merged.append(key) - return merged + return item_template_policy.merge_position_keys(primary, fallback) def _oral_acts_for_position(values: list[Any], position: str) -> list[Any]: diff --git a/tools/prompt_map_audit.py b/tools/prompt_map_audit.py index 85f9fe5..26939bc 100644 --- a/tools/prompt_map_audit.py +++ b/tools/prompt_map_audit.py @@ -10,11 +10,17 @@ from __future__ import annotations import ast import json import re +import sys from pathlib import Path from typing import Any ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +import category_template_metadata as template_metadata_policy # noqa: E402 + POOL_DEFINITION_KEYS = ("scene_pools", "expression_pools", "composition_pools") POOL_REFERENCE_KEYS = { "scene_pool": "scene_pools", @@ -187,6 +193,9 @@ def _template_axis_errors(path: str, node: dict[str, Any]) -> list[tuple[str, st or template.get("name") or "" ).strip() + metadata = template_metadata_policy.template_metadata(template) + for issue in template_metadata_policy.template_metadata_errors(metadata): + errors.append((template_path, issue)) elif isinstance(template, str): template_text = template else: diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 2cc44eb..d3e2497 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -25,6 +25,7 @@ if str(ROOT) not in sys.path: import caption_naturalizer # noqa: E402 import caption_policy # noqa: E402 +import category_template_metadata # noqa: E402 import character_config # noqa: E402 import character_profile # noqa: E402 import category_cast_config # noqa: E402 @@ -1212,6 +1213,23 @@ def smoke_hardcore_position_config_policy() -> None: _expect(pb._template_position_family(template_metadata) == "oral", "Template metadata route lost position family") _expect(pb._template_position_keys(template_metadata) == ["kneeling", "open_thighs"], "Template metadata route lost position keys") _expect(pb._template_action_family(template_metadata) == "oral", "Template metadata route lost normalized action family") + _expect( + pb._template_action_family(template_metadata) == category_template_metadata.template_action_family(template_metadata), + "Prompt builder template action policy should delegate", + ) + _expect( + category_template_metadata.template_metadata_errors(template_metadata) == [], + "Valid template metadata should not report audit errors", + ) + invalid_metadata = { + "action_family": "bad_action", + "position_family": "bad_family", + "position_keys": ["kneeling", "bad_position"], + } + invalid_errors = category_template_metadata.template_metadata_errors(invalid_metadata) + _expect(any("bad_action" in error for error in invalid_errors), "Template metadata validation missed bad action") + _expect(any("bad_family" in error for error in invalid_errors), "Template metadata validation missed bad family") + _expect(any("bad_position" in error for error in invalid_errors), "Template metadata validation missed bad position key") def smoke_category_library_route() -> None: