Add builder generation trace metadata
This commit is contained in:
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import seed_config as seed_policy
|
||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||
import seed_config as seed_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptBuildRequest:
|
||||
@@ -77,6 +82,54 @@ class PromptBuildDependencies:
|
||||
normalize_prompt_row: Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
def _generation_trace(
|
||||
*,
|
||||
row: dict[str, Any],
|
||||
request: PromptBuildRequest,
|
||||
row_number: int,
|
||||
start_index: int,
|
||||
seed: int,
|
||||
category: str,
|
||||
subcategory: str,
|
||||
branch: str,
|
||||
parsed_seed_config: dict[str, Any],
|
||||
clothing: str,
|
||||
poses: str,
|
||||
figure: str,
|
||||
expression_enabled: bool,
|
||||
expression_intensity: float,
|
||||
expression_intensity_source: str,
|
||||
exact_custom_subcategory: bool,
|
||||
) -> dict[str, Any]:
|
||||
trace = {
|
||||
"builder": "prompt_builder",
|
||||
"branch": branch,
|
||||
"source": row.get("source", ""),
|
||||
"category_input": request.category,
|
||||
"subcategory_input": request.subcategory,
|
||||
"category": category,
|
||||
"subcategory": row.get("subcategory") or subcategory,
|
||||
"category_slug": row.get("category_slug", ""),
|
||||
"subcategory_slug": row.get("subcategory_slug", ""),
|
||||
"exact_custom_subcategory": bool(exact_custom_subcategory),
|
||||
"row_number": row_number,
|
||||
"start_index": start_index,
|
||||
"seed": seed,
|
||||
"seed_axes": seed_policy.axis_seed_trace(parsed_seed_config, seed, row_number),
|
||||
"content_seed_axis": row.get("content_seed_axis") or ("pose" if row.get("position_family") else "content"),
|
||||
"clothing": clothing,
|
||||
"poses": poses,
|
||||
"figure": figure,
|
||||
"expression_enabled": bool(expression_enabled),
|
||||
"expression_intensity": expression_intensity,
|
||||
"expression_intensity_source": expression_intensity_source,
|
||||
"trigger": row.get("trigger", ""),
|
||||
}
|
||||
if row.get("cast_count_adjustment"):
|
||||
trace["cast_count_adjustment"] = row.get("cast_count_adjustment")
|
||||
return trace
|
||||
|
||||
|
||||
def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependencies) -> PromptBuildRoute:
|
||||
deps.apply_pool_extensions()
|
||||
row_number = max(1, int(request.row_number))
|
||||
@@ -204,6 +257,24 @@ def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependenci
|
||||
)
|
||||
row.setdefault("expression_intensity", expression_intensity)
|
||||
row.setdefault("expression_intensity_source", expression_intensity_source)
|
||||
row["generation_trace"] = _generation_trace(
|
||||
row=row,
|
||||
request=request,
|
||||
row_number=row_number,
|
||||
start_index=start_index,
|
||||
seed=seed,
|
||||
category=category,
|
||||
subcategory=subcategory,
|
||||
branch=branch,
|
||||
parsed_seed_config=parsed_seed_config,
|
||||
clothing=clothing,
|
||||
poses=poses,
|
||||
figure=figure,
|
||||
expression_enabled=expression_enabled,
|
||||
expression_intensity=expression_intensity,
|
||||
expression_intensity_source=expression_intensity_source,
|
||||
exact_custom_subcategory=exact_custom_subcategory,
|
||||
)
|
||||
return PromptBuildRoute(
|
||||
row=row,
|
||||
category=category,
|
||||
|
||||
@@ -221,6 +221,10 @@ Common trap: `row_number` participates in `seed_config.axis_rng`. If two
|
||||
workflows have the same seeds but different `row_number`, they are not expected
|
||||
to match.
|
||||
|
||||
Each generated row stores `generation_trace.seed_axes` in `metadata_json`.
|
||||
Use it to verify whether an axis followed the main seed or a configured seed,
|
||||
and to compare the exact per-axis RNG seed used for the row.
|
||||
|
||||
## Category Sources
|
||||
|
||||
There are two category systems.
|
||||
@@ -513,6 +517,7 @@ plain prompt text. When debugging, inspect these fields before editing pools.
|
||||
| Field | Owner | Consumed by | Meaning |
|
||||
| --- | --- | --- | --- |
|
||||
| `source` | `build_prompt` / row builder | All formatters | Usually `json_category` or `built_in_generator`; tells which route created the row. |
|
||||
| `generation_trace` | `builder_prompt_route.build_prompt_result` | Debug | Compact generation route trace containing builder branch, input/resolved category, row seed, per-axis seed sources/RNG seeds, effective clothing/pose/figure choices, expression route, and content seed axis. |
|
||||
| `main_category`, `subcategory` | `row_category_route.select_category_item_route` | All formatters and debug | Human-readable selected category route. |
|
||||
| `category_slug`, `subcategory_slug` | `row_category_route.select_category_item_route` | Debug/filtering | Stable-ish machine labels for selected category route. |
|
||||
| `content_seed_axis` | `row_category_route.select_category_item_route` | Debug | Shows whether the item/action was driven by `content` or `pose`. Critical for hardcore pose categories. |
|
||||
|
||||
@@ -216,3 +216,24 @@ def axis_rng(seed_config: dict[str, int], axis: str, base_seed: int, row_number:
|
||||
if configured is None:
|
||||
return random.Random(row_seed(base_seed, row_number, salt))
|
||||
return random.Random(row_seed(configured, row_number, salt))
|
||||
|
||||
|
||||
def axis_seed_trace(
|
||||
seed_config: str | dict[str, Any] | None,
|
||||
base_seed: int,
|
||||
row_number: int,
|
||||
axes: Iterable[str] = SEED_LOCK_AXES,
|
||||
) -> dict[str, dict[str, int | str]]:
|
||||
parsed = parse_seed_config(seed_config)
|
||||
trace: dict[str, dict[str, int | str]] = {}
|
||||
for axis in axes:
|
||||
configured = configured_axis_seed(parsed, axis)
|
||||
seed_value = int(configured) if configured is not None else int(base_seed)
|
||||
source = "configured" if configured is not None else "main"
|
||||
salt = SEED_AXIS_SALTS.get(axis, 0)
|
||||
trace[axis] = {
|
||||
"source": source,
|
||||
"seed": seed_value,
|
||||
"rng_seed": row_seed(seed_value, row_number, salt),
|
||||
}
|
||||
return trace
|
||||
|
||||
@@ -936,6 +936,14 @@ def smoke_builder_prompt_route_policy() -> None:
|
||||
_expect(typed_route.subcategory == "Casual clothes / Smart casual", "Builder prompt route changed subcategory")
|
||||
_expect(typed_route.branch == "custom", "Builder prompt route should use custom branch for category JSON route")
|
||||
_expect(typed_route.parsed_seed_config.get("content_seed") == 3502, "Builder prompt route lost seed config")
|
||||
custom_trace = typed_route.row.get("generation_trace")
|
||||
_expect(isinstance(custom_trace, dict), "Builder custom route lost generation_trace")
|
||||
_expect(custom_trace.get("branch") == "custom", "Builder custom generation_trace lost branch")
|
||||
_expect(custom_trace.get("source") == "json_category", "Builder custom generation_trace lost source")
|
||||
_expect(custom_trace.get("category_slug") == "casual_clothes", "Builder custom generation_trace lost category slug")
|
||||
_expect(custom_trace.get("content_seed_axis") == "content", "Builder custom generation_trace lost content axis")
|
||||
_expect(custom_trace.get("seed_axes", {}).get("content", {}).get("source") == "configured", "Builder custom generation_trace lost configured content seed")
|
||||
_expect(custom_trace.get("seed_axes", {}).get("content", {}).get("seed") == 3502, "Builder custom generation_trace lost content seed value")
|
||||
_expect("typed builder route marker" in typed_route.row.get("prompt", ""), "Builder prompt route lost extra positive")
|
||||
_expect("typed builder negative marker" in typed_route.row.get("negative_prompt", ""), "Builder prompt route lost extra negative")
|
||||
_expect(
|
||||
@@ -970,6 +978,11 @@ def smoke_builder_prompt_route_policy() -> None:
|
||||
_expect(built_in_route.row == legacy_from_request(built_in_request), "Builder built-in route should match public wrapper")
|
||||
_expect(built_in_route.branch == "built_in", "Builder prompt route lost built-in branch")
|
||||
_expect(built_in_route.row.get("source") == "built_in_generator", "Builder built-in branch changed source")
|
||||
built_in_trace = built_in_route.row.get("generation_trace")
|
||||
_expect(isinstance(built_in_trace, dict), "Builder built-in route lost generation_trace")
|
||||
_expect(built_in_trace.get("branch") == "built_in", "Builder built-in generation_trace lost branch")
|
||||
_expect(built_in_trace.get("source") == "built_in_generator", "Builder built-in generation_trace lost source")
|
||||
_expect(built_in_trace.get("seed_axes", {}).get("person", {}).get("source") == "main", "Builder built-in generation_trace should follow main seed")
|
||||
_expect(built_in_route.row.get("expression_disabled") is True, "Builder built-in branch lost expression disable")
|
||||
_expect("built-in route marker" in built_in_route.row.get("prompt", ""), "Builder built-in branch lost extra positive")
|
||||
|
||||
@@ -1000,6 +1013,11 @@ def smoke_builder_prompt_route_policy() -> None:
|
||||
_expect(auto_route.row == legacy_from_request(auto_weighted_request), "Builder auto-weighted route should match public wrapper")
|
||||
_expect(auto_route.branch == "auto_weighted", "Builder prompt route lost auto-weighted branch")
|
||||
_expect(auto_route.parsed_seed_config.get("person_seed") == 3505, "Builder auto-weighted branch lost person seed lock")
|
||||
auto_trace = auto_route.row.get("generation_trace")
|
||||
_expect(isinstance(auto_trace, dict), "Builder auto-weighted route lost generation_trace")
|
||||
_expect(auto_trace.get("branch") == "auto_weighted", "Builder auto-weighted generation_trace lost branch")
|
||||
_expect(auto_trace.get("seed_axes", {}).get("person", {}).get("source") == "configured", "Builder auto-weighted trace lost configured person seed")
|
||||
_expect(auto_trace.get("seed_axes", {}).get("person", {}).get("seed") == 3505, "Builder auto-weighted trace lost person seed")
|
||||
_expect("auto route marker" in auto_route.row.get("prompt", ""), "Builder auto-weighted branch lost extra positive")
|
||||
|
||||
|
||||
@@ -1049,6 +1067,11 @@ def smoke_builder_config_route_policy() -> None:
|
||||
_expect(typed_route.cast["women_count"] == 1 and typed_route.cast["men_count"] == 0, "Config route lost cast preset")
|
||||
_expect(typed_route.profile["trigger"] == "sxcpinup_coloredpencil", "Config route lost generation profile trigger")
|
||||
_expect(typed_route.filters["ethnicity"] == "french_european", "Config route lost filter ethnicity")
|
||||
config_trace = typed_route.row.get("generation_trace")
|
||||
_expect(isinstance(config_trace, dict), "Config route row lost generation_trace")
|
||||
_expect(config_trace.get("branch") == "custom", "Config route generation_trace lost builder branch")
|
||||
_expect(config_trace.get("seed_axes", {}).get("scene", {}).get("source") == "configured", "Config route generation_trace lost scene seed lock")
|
||||
_expect(config_trace.get("seed_axes", {}).get("scene", {}).get("seed") == 3402, "Config route generation_trace lost scene reroll seed")
|
||||
kwargs = typed_route.build_kwargs
|
||||
_expect(kwargs["category"] == typed_route.category, "Config route build kwargs category drifted")
|
||||
_expect(kwargs["subcategory"] == typed_route.subcategory, "Config route build kwargs subcategory drifted")
|
||||
@@ -6611,6 +6634,15 @@ def smoke_seed_config_policy() -> None:
|
||||
_expect(locked["content_seed"] == 999, "content_pose reroll should alter content seed")
|
||||
_expect(locked["pose_seed"] == 999 and locked["role_seed"] == 999, "content_pose reroll should alter pose and role seeds")
|
||||
_expect(locked["scene_seed"] == 100, "content_pose reroll should leave scene locked")
|
||||
axis_trace = seed_config.axis_seed_trace({"content_seed": 44}, 99, 3, axes=("content", "scene"))
|
||||
_expect(axis_trace["content"]["source"] == "configured", "Seed axis trace lost configured source")
|
||||
_expect(axis_trace["content"]["seed"] == 44, "Seed axis trace lost configured seed")
|
||||
_expect(axis_trace["scene"]["source"] == "main", "Seed axis trace lost main source")
|
||||
_expect(axis_trace["scene"]["seed"] == 99, "Seed axis trace lost main seed")
|
||||
_expect(
|
||||
axis_trace["content"]["rng_seed"] == seed_config.row_seed(44, 3, seed_config.SEED_AXIS_SALTS["content"]),
|
||||
"Seed axis trace lost content RNG seed",
|
||||
)
|
||||
|
||||
rng_a = pb._axis_rng({"content_seed": 123}, "content", 999, 7)
|
||||
rng_b = seed_config.axis_rng({"content_seed": 123}, "content", 999, 7)
|
||||
@@ -7326,6 +7358,10 @@ def smoke_node_builder_registration() -> None:
|
||||
_expect_row_base(direct_row, "node_builder.direct_row")
|
||||
_expect(direct_output[0] == direct_row.get("prompt"), "Prompt Builder prompt output drifted from metadata")
|
||||
_expect(direct_output[4] == direct_row.get("main_category"), "Prompt Builder category output drifted from metadata")
|
||||
direct_trace = direct_row.get("generation_trace")
|
||||
_expect(isinstance(direct_trace, dict), "Prompt Builder metadata lost generation_trace")
|
||||
_expect(direct_trace.get("branch") == "built_in", "Prompt Builder metadata generation_trace lost branch")
|
||||
_expect(direct_trace.get("seed_axes", {}).get("content", {}).get("source") == "main", "Prompt Builder metadata trace lost content seed source")
|
||||
_expect_trigger_once("node_builder.direct_prompt", direct_output[0], Trigger)
|
||||
|
||||
config_node = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPPromptBuilderFromConfigs"]
|
||||
@@ -7343,6 +7379,9 @@ def smoke_node_builder_registration() -> None:
|
||||
_expect_row_base(config_row, "node_builder.config_row")
|
||||
_expect(config_output[0] == config_row.get("prompt"), "Prompt Builder From Configs prompt output drifted from metadata")
|
||||
_expect(config_output[4] == config_row.get("main_category"), "Prompt Builder From Configs category output drifted from metadata")
|
||||
config_trace = config_row.get("generation_trace")
|
||||
_expect(isinstance(config_trace, dict), "Prompt Builder From Configs metadata lost generation_trace")
|
||||
_expect(config_trace.get("builder") == "prompt_builder", "Prompt Builder From Configs trace lost builder label")
|
||||
_expect_text("node_builder.config_caption", config_output[2], 20)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user