Extract row text field resolution
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from string import Formatter
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import category_library as category_policy
|
||||
from . import generate_prompt_batches as g
|
||||
from . import row_camera as row_camera_policy
|
||||
except ImportError: # Allows local smoke tests from the repository root.
|
||||
import category_library as category_policy
|
||||
import generate_prompt_batches as g
|
||||
import row_camera as row_camera_policy
|
||||
|
||||
|
||||
@@ -14,6 +19,17 @@ GENERIC_POSITIVE_SUFFIX = (
|
||||
"pastel skin tones, muted blues and pinks, warm sensual lighting, and tactile textured paper."
|
||||
)
|
||||
|
||||
DEFAULT_STYLE = "sexy but tasteful adult pin-up coloured-pencil comic illustration"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RowTextFields:
|
||||
negative_prompt: str
|
||||
positive_suffix: str
|
||||
style: str
|
||||
item_label: str
|
||||
|
||||
|
||||
SINGLE_TEMPLATE = (
|
||||
"A {subject}: {style}, {age}, {body_phrase}, {skin}, {hair}, {eyes}. "
|
||||
"{item_label}: {item}. Scene: {scene}. Pose: {pose}. Facial expression: {expression}. "
|
||||
@@ -56,6 +72,19 @@ def format_template(template: str, context: dict[str, Any]) -> str:
|
||||
return template.format_map(safe_context)
|
||||
|
||||
|
||||
def resolve_row_text_fields(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> RowTextFields:
|
||||
return RowTextFields(
|
||||
negative_prompt=str(
|
||||
category_policy.merged_field(category, subcategory, item, "negative_prompt", g.NEGATIVE_PROMPT)
|
||||
),
|
||||
positive_suffix=str(
|
||||
category_policy.merged_field(category, subcategory, item, "positive_suffix", GENERIC_POSITIVE_SUFFIX)
|
||||
),
|
||||
style=str(category_policy.merged_field(category, subcategory, item, "style", DEFAULT_STYLE)),
|
||||
item_label=str(category_policy.merged_field(category, subcategory, item, "item_label", category["name"])),
|
||||
)
|
||||
|
||||
|
||||
def default_prompt_template(subject_type: str) -> str:
|
||||
if subject_type in ("woman", "man"):
|
||||
return SINGLE_TEMPLATE
|
||||
|
||||
Reference in New Issue
Block a user