Extract row text field resolution

This commit is contained in:
2026-06-27 10:18:26 +02:00
parent a5b648eb98
commit 09eaafc8f6
5 changed files with 78 additions and 17 deletions
+29
View File
@@ -1,11 +1,16 @@
from __future__ import annotations
from dataclasses import dataclass
from string import Formatter
from typing import Any
try:
from . import category_library as category_policy
from . import generate_prompt_batches as g
from . import row_camera as row_camera_policy
except ImportError: # Allows local smoke tests from the repository root.
import category_library as category_policy
import generate_prompt_batches as g
import row_camera as row_camera_policy
@@ -14,6 +19,17 @@ GENERIC_POSITIVE_SUFFIX = (
"pastel skin tones, muted blues and pinks, warm sensual lighting, and tactile textured paper."
)
DEFAULT_STYLE = "sexy but tasteful adult pin-up coloured-pencil comic illustration"
@dataclass(frozen=True)
class RowTextFields:
negative_prompt: str
positive_suffix: str
style: str
item_label: str
SINGLE_TEMPLATE = (
"A {subject}: {style}, {age}, {body_phrase}, {skin}, {hair}, {eyes}. "
"{item_label}: {item}. Scene: {scene}. Pose: {pose}. Facial expression: {expression}. "
@@ -56,6 +72,19 @@ def format_template(template: str, context: dict[str, Any]) -> str:
return template.format_map(safe_context)
def resolve_row_text_fields(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> RowTextFields:
return RowTextFields(
negative_prompt=str(
category_policy.merged_field(category, subcategory, item, "negative_prompt", g.NEGATIVE_PROMPT)
),
positive_suffix=str(
category_policy.merged_field(category, subcategory, item, "positive_suffix", GENERIC_POSITIVE_SUFFIX)
),
style=str(category_policy.merged_field(category, subcategory, item, "style", DEFAULT_STYLE)),
item_label=str(category_policy.merged_field(category, subcategory, item, "item_label", category["name"])),
)
def default_prompt_template(subject_type: str) -> str:
if subject_type in ("woman", "man"):
return SINGLE_TEMPLATE