Extract builder prompt route
This commit is contained in:
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptBuildRequest:
|
||||
category: str
|
||||
subcategory: str
|
||||
row_number: int
|
||||
start_index: int
|
||||
seed: int
|
||||
clothing: str
|
||||
ethnicity: str
|
||||
poses: str
|
||||
backside_bias: float
|
||||
figure: str
|
||||
no_plus_women: bool
|
||||
no_black: bool
|
||||
minimal_clothing_ratio: float
|
||||
standard_pose_ratio: float
|
||||
trigger: str
|
||||
prepend_trigger_to_prompt: bool
|
||||
extra_positive: str
|
||||
extra_negative: str
|
||||
seed_config: str | dict[str, Any] | None = None
|
||||
women_count: int = 1
|
||||
men_count: int = 1
|
||||
camera_config: str | dict[str, Any] | None = None
|
||||
expression_intensity: float = 0.5
|
||||
character_profile: str | dict[str, Any] | None = None
|
||||
character_cast: str | dict[str, Any] | list[Any] | None = None
|
||||
expression_enabled: bool = True
|
||||
expression_phase: str = ""
|
||||
hardcore_position_config: str | dict[str, Any] | None = None
|
||||
location_config: str | dict[str, Any] | None = None
|
||||
composition_config: str | dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptBuildRoute:
|
||||
row: dict[str, Any]
|
||||
category: str
|
||||
subcategory: str
|
||||
branch: str
|
||||
parsed_seed_config: dict[str, Any]
|
||||
expression_intensity: float
|
||||
expression_intensity_source: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptBuildDependencies:
|
||||
default_trigger: str
|
||||
default_negative: str
|
||||
random_subcategory: str
|
||||
apply_pool_extensions: Callable[[], Any]
|
||||
normalize_ethnicity_filter: Callable[[Any, str], str]
|
||||
is_false: Callable[[Any], bool]
|
||||
ratio_or_none: Callable[[Any], float | None]
|
||||
parse_seed_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
parse_location_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
parse_composition_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
axis_rng: Callable[[dict[str, Any], str, int, int], Any]
|
||||
pick_clothing_mode: Callable[[Any, str, float | None], str]
|
||||
pick_pose_mode: Callable[[Any, str, float | None], str]
|
||||
pick_figure_bias: Callable[[Any, str], str]
|
||||
pick_expression_intensity: Callable[[Any, Any], tuple[float, str]]
|
||||
auto_full_choice: Callable[[dict[str, Any], int, int], str]
|
||||
build_auto_weighted_row: Callable[..., dict[str, Any]]
|
||||
build_direct_builtin_row: Callable[..., dict[str, Any]]
|
||||
build_custom_row: Callable[..., dict[str, Any]]
|
||||
apply_location_config_to_legacy_row: Callable[..., dict[str, Any]]
|
||||
apply_composition_config_to_legacy_row: Callable[..., dict[str, Any]]
|
||||
disable_row_expression: Callable[[dict[str, Any], str], dict[str, Any]]
|
||||
apply_camera_config: Callable[[dict[str, Any], str | dict[str, Any] | None], dict[str, Any]]
|
||||
normalize_prompt_row: Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependencies) -> PromptBuildRoute:
|
||||
deps.apply_pool_extensions()
|
||||
row_number = max(1, int(request.row_number))
|
||||
start_index = max(1, int(request.start_index))
|
||||
seed = int(request.seed)
|
||||
category = request.category
|
||||
subcategory = request.subcategory
|
||||
ethnicity = deps.normalize_ethnicity_filter(request.ethnicity, "any")
|
||||
expression_enabled = not deps.is_false(request.expression_enabled)
|
||||
minimal_ratio = deps.ratio_or_none(request.minimal_clothing_ratio)
|
||||
pose_ratio = deps.ratio_or_none(request.standard_pose_ratio)
|
||||
parsed_seed_config = deps.parse_seed_config(request.seed_config)
|
||||
parsed_location_config = deps.parse_location_config(request.location_config)
|
||||
parsed_composition_config = deps.parse_composition_config(request.composition_config)
|
||||
content_rng = deps.axis_rng(parsed_seed_config, "content", seed, row_number)
|
||||
pose_axis_rng = deps.axis_rng(parsed_seed_config, "pose", seed, row_number)
|
||||
person_rng = deps.axis_rng(parsed_seed_config, "person", seed, row_number)
|
||||
expression_rng = deps.axis_rng(parsed_seed_config, "expression", seed, row_number)
|
||||
clothing = request.clothing if request.clothing in ("full", "minimal", "random") else "full"
|
||||
poses = request.poses if request.poses in ("standard", "evocative", "random") else "standard"
|
||||
figure = request.figure if request.figure in ("curvy", "balanced", "bombshell", "random") else "curvy"
|
||||
clothing = deps.pick_clothing_mode(content_rng, clothing, minimal_ratio)
|
||||
poses = deps.pick_pose_mode(pose_axis_rng, poses, pose_ratio)
|
||||
figure = deps.pick_figure_bias(person_rng, figure)
|
||||
minimal_ratio = None
|
||||
pose_ratio = None
|
||||
expression_intensity, expression_intensity_source = deps.pick_expression_intensity(
|
||||
expression_rng,
|
||||
request.expression_intensity,
|
||||
)
|
||||
|
||||
exact_custom_subcategory = bool(
|
||||
subcategory and subcategory != deps.random_subcategory and " / " in subcategory
|
||||
)
|
||||
|
||||
if category == "auto_full" and not exact_custom_subcategory:
|
||||
category = deps.auto_full_choice(parsed_seed_config, seed, row_number)
|
||||
|
||||
branch = "custom"
|
||||
if category == "auto_weighted" and not exact_custom_subcategory:
|
||||
branch = "auto_weighted"
|
||||
row = deps.build_auto_weighted_row(
|
||||
row_number,
|
||||
start_index,
|
||||
clothing,
|
||||
ethnicity,
|
||||
poses,
|
||||
float(request.backside_bias),
|
||||
figure,
|
||||
bool(request.no_plus_women),
|
||||
bool(request.no_black),
|
||||
minimal_ratio,
|
||||
pose_ratio,
|
||||
seed,
|
||||
)
|
||||
elif category in ("woman", "man", "couple", "group_or_layout") and not exact_custom_subcategory:
|
||||
branch = "built_in"
|
||||
row = deps.build_direct_builtin_row(
|
||||
category,
|
||||
row_number,
|
||||
start_index,
|
||||
clothing,
|
||||
ethnicity,
|
||||
poses,
|
||||
float(request.backside_bias),
|
||||
figure,
|
||||
bool(request.no_plus_women),
|
||||
bool(request.no_black),
|
||||
minimal_ratio,
|
||||
pose_ratio,
|
||||
seed,
|
||||
)
|
||||
else:
|
||||
row = deps.build_custom_row(
|
||||
category,
|
||||
subcategory,
|
||||
row_number,
|
||||
start_index,
|
||||
ethnicity,
|
||||
poses,
|
||||
figure,
|
||||
bool(request.no_plus_women),
|
||||
bool(request.no_black),
|
||||
int(request.women_count),
|
||||
int(request.men_count),
|
||||
seed,
|
||||
parsed_seed_config,
|
||||
expression_enabled,
|
||||
expression_intensity,
|
||||
expression_intensity_source,
|
||||
request.character_profile,
|
||||
request.character_cast,
|
||||
request.expression_phase,
|
||||
request.hardcore_position_config,
|
||||
parsed_location_config,
|
||||
parsed_composition_config,
|
||||
)
|
||||
|
||||
if row.get("source") == "built_in_generator":
|
||||
row = deps.apply_location_config_to_legacy_row(
|
||||
row,
|
||||
parsed_location_config,
|
||||
parsed_seed_config,
|
||||
seed,
|
||||
row_number,
|
||||
)
|
||||
row = deps.apply_composition_config_to_legacy_row(
|
||||
row,
|
||||
parsed_composition_config,
|
||||
parsed_seed_config,
|
||||
seed,
|
||||
row_number,
|
||||
)
|
||||
if not expression_enabled:
|
||||
row = deps.disable_row_expression(row, "disabled")
|
||||
row = deps.apply_camera_config(row, request.camera_config)
|
||||
active_trigger = request.trigger.strip() or deps.default_trigger
|
||||
row = deps.normalize_prompt_row(
|
||||
row,
|
||||
active_trigger=active_trigger,
|
||||
prepend_trigger_to_prompt=bool(request.prepend_trigger_to_prompt),
|
||||
extra_positive=request.extra_positive,
|
||||
extra_negative=request.extra_negative,
|
||||
default_negative=deps.default_negative,
|
||||
)
|
||||
row.setdefault("expression_intensity", expression_intensity)
|
||||
row.setdefault("expression_intensity_source", expression_intensity_source)
|
||||
return PromptBuildRoute(
|
||||
row=row,
|
||||
category=category,
|
||||
subcategory=subcategory,
|
||||
branch=branch,
|
||||
parsed_seed_config=dict(parsed_seed_config),
|
||||
expression_intensity=expression_intensity,
|
||||
expression_intensity_source=expression_intensity_source,
|
||||
)
|
||||
|
||||
|
||||
def build_prompt(request: PromptBuildRequest, deps: PromptBuildDependencies) -> dict[str, Any]:
|
||||
return build_prompt_result(request, deps).row
|
||||
Reference in New Issue
Block a user