from __future__ import annotations from dataclasses import dataclass from string import Formatter from typing import Any try: from . import category_library as category_policy from . import generate_prompt_batches as g from . import row_camera as row_camera_policy from . import style_config as style_config_policy except ImportError: # Allows local smoke tests from the repository root. import category_library as category_policy import generate_prompt_batches as g import row_camera as row_camera_policy import style_config as style_config_policy GENERIC_POSITIVE_SUFFIX = ( "Use coherent anatomy, readable body placement, natural light response, " "clear material texture, stable spatial depth, and polished visual detail." ) DEFAULT_STYLE = "realistic adult scene with natural camera realism" @dataclass(frozen=True) class RowTextFields: negative_prompt: str positive_suffix: str style: str item_label: str SINGLE_TEMPLATE = ( "A {subject}: {style}, {age}, {body_phrase}, {skin}, {hair}, {eyes}. " "{item_label}: {item}. Scene: {scene}. Pose: {pose}. Facial expression: {expression}. " "Composition: {composition_prompt}. {positive_suffix} Avoid: {negative_prompt}." ) COUPLE_TEMPLATE = ( "{subject_phrase}: {style}. Ages: {age}. Body types: {body}. {item_label}: {item}. " "Scene: {scene}. Pose: {pose}. Facial expressions: {expression}. " "Composition: {composition_prompt}. {positive_suffix} Avoid: {negative_prompt}." ) GROUP_TEMPLATE = ( "{subject_phrase}: {style}, ages {age}, diverse adult body types. {item_label}: {item}. " "Scene: {scene}. Facial expressions: {expression}. Composition: {composition_prompt}. " "{positive_suffix} Avoid: {negative_prompt}." ) LAYOUT_TEMPLATE = ( "{item}: {style}, adults only, clean designed composition. Scene: {scene}. " "Facial expression: {expression}. Composition: {composition}. {positive_suffix} " "Avoid: {negative_prompt}. Use no readable text unless the layout naturally needs small decorative placeholder marks." ) DEFAULT_CAPTION_TEMPLATE = ( "{trigger}, {subject_phrase}, {age}, {item}, {scene}, {composition}" ) class SafeFormatDict(dict): def __missing__(self, key: str) -> str: return "{" + key + "}" def format_template(template: str, context: dict[str, Any]) -> str: fields = {key for _, key, _, _ in Formatter().parse(template) if key} safe_context = SafeFormatDict({key: str(value) for key, value in context.items()}) for field in fields: safe_context.setdefault(field, "{" + field + "}") return template.format_map(safe_context) def resolve_row_text_fields( category: dict[str, Any], subcategory: dict[str, Any], item: Any, style_config: str | dict[str, Any] | None = None, ) -> RowTextFields: base_negative = str(category_policy.merged_field(category, subcategory, item, "negative_prompt", g.NEGATIVE_PROMPT)) base_suffix = str(category_policy.merged_field(category, subcategory, item, "positive_suffix", GENERIC_POSITIVE_SUFFIX)) base_style = str(category_policy.merged_field(category, subcategory, item, "style", DEFAULT_STYLE)) style, positive_suffix = style_config_policy.resolve_style_fields(base_style, base_suffix, style_config) return RowTextFields( negative_prompt=style_config_policy.merge_negative_prompt(base_negative, style_config), positive_suffix=positive_suffix, style=style, item_label=str(category_policy.merged_field(category, subcategory, item, "item_label", category["name"])), ) def default_prompt_template(subject_type: str) -> str: if subject_type in ("woman", "man"): return SINGLE_TEMPLATE if subject_type == "couple": return COUPLE_TEMPLATE if subject_type == "group": return GROUP_TEMPLATE return LAYOUT_TEMPLATE def prompt_template_for(item: Any, subcategory: dict[str, Any], category: dict[str, Any], subject_type: str) -> str: if isinstance(item, dict) and "prompt_template" in item: return str(item["prompt_template"]) template = str(subcategory.get("prompt_template") or category.get("prompt_template") or "") return template or default_prompt_template(subject_type) def caption_template_for(item: Any, subcategory: dict[str, Any], category: dict[str, Any]) -> str: return str( (item.get("caption_template") if isinstance(item, dict) else None) or subcategory.get("caption_template") or category.get("caption_template") or DEFAULT_CAPTION_TEMPLATE ) def render_prompt_caption( *, item: Any, subcategory: dict[str, Any], category: dict[str, Any], subject_type: str, context: dict[str, Any], cast_descriptor_text: str = "", pov_prompt_directive: str = "", ) -> dict[str, str]: prompt_template = prompt_template_for(item, subcategory, category, subject_type) caption_template = caption_template_for(item, subcategory, category) prompt = format_template(prompt_template, context) if subject_type == "configured_cast" and cast_descriptor_text and "{cast_descriptors}" not in prompt_template: prompt = row_camera_policy.insert_positive_directive(prompt, f"Characters: {cast_descriptor_text}.") if subject_type == "configured_cast" and pov_prompt_directive: prompt = row_camera_policy.insert_positive_directive(prompt, pov_prompt_directive) caption = format_template(caption_template, context) if subject_type == "configured_cast" and cast_descriptor_text and "{cast_descriptors}" not in caption_template: caption = f"{caption.rstrip()}, {cast_descriptor_text}" return { "prompt": prompt, "caption": caption, "prompt_template": prompt_template, "caption_template": caption_template, }