104 lines
4.0 KiB
Python
104 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable
|
|
|
|
|
|
@dataclass(frozen=True)
class PromptFromConfigsRequest:
    """Immutable bundle of inputs for building one prompt row from configs.

    Each config field accepts a raw string, an already-decoded mapping, or
    ``None`` (``character_cast`` additionally accepts a list). Presumably a
    string is a serialized config; confirm against the parser implementations.

    NOTE: the class is frozen, so dataclasses generates a field-based
    ``__hash__``. Hashing an instance whose config fields hold a dict or list
    therefore raises ``TypeError``.
    """

    # Row identity and randomness; forwarded verbatim to the prompt builder.
    row_number: int
    start_index: int
    seed: int

    # Configs decoded by the parser callables on PromptFromConfigsDependencies.
    category_config: str | dict[str, Any] | None = ""
    cast_config: str | dict[str, Any] | None = ""
    generation_profile: str | dict[str, Any] | None = ""
    filter_config: str | dict[str, Any] | None = ""

    # Configs forwarded to the prompt builder without parsing; the router
    # substitutes "" for any falsy value.
    seed_config: str | dict[str, Any] | None = ""
    camera_config: str | dict[str, Any] | None = ""
    character_profile: str | dict[str, Any] | None = ""
    character_cast: str | dict[str, Any] | list[Any] | None = ""
    hardcore_position_config: str | dict[str, Any] | None = ""
    location_config: str | dict[str, Any] | None = ""
    composition_config: str | dict[str, Any] | None = ""
    style_config: str | dict[str, Any] | None = ""

    # Free-form additions to the positive / negative prompt text.
    extra_positive: str = ""
    extra_negative: str = ""
|
|
|
|
|
|
@dataclass(frozen=True)
class PromptFromConfigsRoute:
    """Generated prompt row together with the intermediate data that produced it.

    ``build_prompt_from_configs_result`` fills this in. Alongside the row it
    keeps the parsed category pair, shallow copies of the parsed cast, profile
    and filter dicts, and the keyword arguments handed to the prompt builder.

    NOTE: the class is frozen, but its dict fields stay mutable, and hashing
    an instance raises ``TypeError`` because dicts are unhashable.
    """

    row: dict[str, Any]  # value returned by deps.build_prompt
    category: str  # first element from parse_category_config
    subcategory: str  # second element from parse_category_config
    cast: dict[str, Any]  # parsed cast config
    profile: dict[str, Any]  # parsed generation profile
    filters: dict[str, Any]  # parsed filter config
    build_kwargs: dict[str, Any]  # keyword arguments passed to build_prompt
|
|
|
|
|
|
@dataclass(frozen=True)
class PromptFromConfigsDependencies:
    """Injectable collaborators used by ``build_prompt_from_configs_result``.

    Swapping these callables changes how configs are decoded and how the row
    is produced, without touching the routing code.
    """

    # Returns a (category, subcategory) pair.
    parse_category_config: Callable[[str | dict[str, Any] | None], tuple[str, str]]
    # The router reads "women_count" and "men_count" from the result and
    # passes each through int().
    parse_cast_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    # The router reads "clothing", "poses", "expression_enabled",
    # "expression_intensity", "backside_bias", "minimal_clothing_ratio",
    # "standard_pose_ratio", "trigger" and "prepend_trigger_to_prompt".
    parse_generation_profile: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    # The router reads "ethnicity", "figure", "no_plus_women" and "no_black".
    parse_filter_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    # Called with keyword arguments only; its return value becomes the row.
    build_prompt: Callable[..., dict[str, Any]]
|
|
|
|
|
|
# Request fields forwarded to build_prompt without parsing, in the order they
# are appended to the builder kwargs. Falsy values are replaced by "".
_PASSTHROUGH_FIELDS = (
    "extra_positive",
    "extra_negative",
    "seed_config",
    "camera_config",
    "character_profile",
    "character_cast",
    "hardcore_position_config",
    "location_config",
    "composition_config",
    "style_config",
)


def _assemble_build_kwargs(
    request: PromptFromConfigsRequest,
    category: str,
    subcategory: str,
    cast: dict[str, Any],
    profile: dict[str, Any],
    filters: dict[str, Any],
) -> dict[str, Any]:
    """Collect the keyword arguments for ``deps.build_prompt`` in a fixed order.

    Raises:
        KeyError: if a parsed dict lacks one of the keys read below.
        ValueError / TypeError: if a cast count cannot be converted by int().
    """
    return {
        "category": category,
        "subcategory": subcategory,
        "row_number": request.row_number,
        "start_index": request.start_index,
        "seed": request.seed,
        # Profile, filter and cast keys stay interleaved in this exact order
        # so the kwargs dict's insertion order (and which missing key is
        # reported first) matches the established layout.
        "clothing": profile["clothing"],
        "ethnicity": filters["ethnicity"],
        "poses": profile["poses"],
        "expression_enabled": profile["expression_enabled"],
        "expression_intensity": profile["expression_intensity"],
        "backside_bias": profile["backside_bias"],
        "figure": filters["figure"],
        "no_plus_women": filters["no_plus_women"],
        "no_black": filters["no_black"],
        "women_count": int(cast["women_count"]),
        "men_count": int(cast["men_count"]),
        "minimal_clothing_ratio": profile["minimal_clothing_ratio"],
        "standard_pose_ratio": profile["standard_pose_ratio"],
        "trigger": profile["trigger"],
        "prepend_trigger_to_prompt": profile["prepend_trigger_to_prompt"],
        **{name: getattr(request, name) or "" for name in _PASSTHROUGH_FIELDS},
    }


def build_prompt_from_configs_result(
    request: PromptFromConfigsRequest,
    deps: PromptFromConfigsDependencies,
) -> PromptFromConfigsRoute:
    """Parse the request's configs, build one prompt row, and report how.

    The category, cast, generation-profile and filter configs are decoded
    with the parsers in ``deps``, in that order. Selected values from the
    parsed results are then combined with the request's row metadata and
    pass-through configs into the kwargs for ``deps.build_prompt``.

    Args:
        request: Raw inputs for the row.
        deps: Parser callables and the prompt builder.

    Returns:
        A ``PromptFromConfigsRoute`` holding the built row, the parsed
        category pair, shallow copies of the parsed dicts, and the exact
        kwargs passed to the builder.

    Raises:
        KeyError: if a parsed dict is missing a key the router reads.
        Any exception raised by the parsers or by ``deps.build_prompt``
        propagates unchanged.
    """
    category, subcategory = deps.parse_category_config(request.category_config)
    cast = deps.parse_cast_config(request.cast_config)
    profile = deps.parse_generation_profile(request.generation_profile)
    filters = deps.parse_filter_config(request.filter_config)

    build_kwargs: dict[str, Any] = _assemble_build_kwargs(
        request, category, subcategory, cast, profile, filters
    )
    generated_row = deps.build_prompt(**build_kwargs)

    # dict(...) gives the route its own top-level dict objects, independent
    # of the objects the parsers returned (nested values are still shared).
    return PromptFromConfigsRoute(
        row=generated_row,
        category=category,
        subcategory=subcategory,
        cast=dict(cast),
        profile=dict(profile),
        filters=dict(filters),
        build_kwargs=build_kwargs,
    )
|
|
|
|
|
|
def build_prompt_from_configs(
    request: PromptFromConfigsRequest,
    deps: PromptFromConfigsDependencies,
) -> dict[str, Any]:
    """Build one prompt row and discard the routing metadata.

    Convenience wrapper over ``build_prompt_from_configs_result`` for callers
    that only need the row returned by ``deps.build_prompt``.
    """
    route = build_prompt_from_configs_result(request, deps)
    return route.row
|