Extract builder config route
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptFromConfigsRequest:
|
||||
row_number: int
|
||||
start_index: int
|
||||
seed: int
|
||||
category_config: str | dict[str, Any] | None = ""
|
||||
cast_config: str | dict[str, Any] | None = ""
|
||||
generation_profile: str | dict[str, Any] | None = ""
|
||||
filter_config: str | dict[str, Any] | None = ""
|
||||
seed_config: str | dict[str, Any] | None = ""
|
||||
camera_config: str | dict[str, Any] | None = ""
|
||||
character_profile: str | dict[str, Any] | None = ""
|
||||
character_cast: str | dict[str, Any] | list[Any] | None = ""
|
||||
hardcore_position_config: str | dict[str, Any] | None = ""
|
||||
location_config: str | dict[str, Any] | None = ""
|
||||
composition_config: str | dict[str, Any] | None = ""
|
||||
extra_positive: str = ""
|
||||
extra_negative: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptFromConfigsRoute:
|
||||
row: dict[str, Any]
|
||||
category: str
|
||||
subcategory: str
|
||||
cast: dict[str, Any]
|
||||
profile: dict[str, Any]
|
||||
filters: dict[str, Any]
|
||||
build_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptFromConfigsDependencies:
|
||||
parse_category_config: Callable[[str | dict[str, Any] | None], tuple[str, str]]
|
||||
parse_cast_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
parse_generation_profile: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
parse_filter_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
|
||||
build_prompt: Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
def build_prompt_from_configs_result(
|
||||
request: PromptFromConfigsRequest,
|
||||
deps: PromptFromConfigsDependencies,
|
||||
) -> PromptFromConfigsRoute:
|
||||
category, subcategory = deps.parse_category_config(request.category_config)
|
||||
cast = deps.parse_cast_config(request.cast_config)
|
||||
profile = deps.parse_generation_profile(request.generation_profile)
|
||||
filters = deps.parse_filter_config(request.filter_config)
|
||||
build_kwargs: dict[str, Any] = {
|
||||
"category": category,
|
||||
"subcategory": subcategory,
|
||||
"row_number": request.row_number,
|
||||
"start_index": request.start_index,
|
||||
"seed": request.seed,
|
||||
"clothing": profile["clothing"],
|
||||
"ethnicity": filters["ethnicity"],
|
||||
"poses": profile["poses"],
|
||||
"expression_enabled": profile["expression_enabled"],
|
||||
"expression_intensity": profile["expression_intensity"],
|
||||
"backside_bias": profile["backside_bias"],
|
||||
"figure": filters["figure"],
|
||||
"no_plus_women": filters["no_plus_women"],
|
||||
"no_black": filters["no_black"],
|
||||
"women_count": int(cast["women_count"]),
|
||||
"men_count": int(cast["men_count"]),
|
||||
"minimal_clothing_ratio": profile["minimal_clothing_ratio"],
|
||||
"standard_pose_ratio": profile["standard_pose_ratio"],
|
||||
"trigger": profile["trigger"],
|
||||
"prepend_trigger_to_prompt": profile["prepend_trigger_to_prompt"],
|
||||
"extra_positive": request.extra_positive or "",
|
||||
"extra_negative": request.extra_negative or "",
|
||||
"seed_config": request.seed_config or "",
|
||||
"camera_config": request.camera_config or "",
|
||||
"character_profile": request.character_profile or "",
|
||||
"character_cast": request.character_cast or "",
|
||||
"hardcore_position_config": request.hardcore_position_config or "",
|
||||
"location_config": request.location_config or "",
|
||||
"composition_config": request.composition_config or "",
|
||||
}
|
||||
return PromptFromConfigsRoute(
|
||||
row=deps.build_prompt(**build_kwargs),
|
||||
category=category,
|
||||
subcategory=subcategory,
|
||||
cast=dict(cast),
|
||||
profile=dict(profile),
|
||||
filters=dict(filters),
|
||||
build_kwargs=build_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def build_prompt_from_configs(
|
||||
request: PromptFromConfigsRequest,
|
||||
deps: PromptFromConfigsDependencies,
|
||||
) -> dict[str, Any]:
|
||||
return build_prompt_from_configs_result(request, deps).row
|
||||
Reference in New Issue
Block a user