Files
ComfyUI-Ethanfel-Prompt-Bui…/builder_prompt_route.py
T

296 lines
11 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
try:
from . import seed_config as seed_policy
except ImportError: # pragma: no cover - plain-script smoke tests
import seed_config as seed_policy
@dataclass(frozen=True)
class PromptBuildRequest:
    """Immutable bundle of raw prompt-builder inputs for a single row.

    The dataclass itself performs no validation; ``build_prompt_result``
    coerces the numeric fields (``int``/``float``/``bool``) and falls back to
    defaults for unrecognised ``clothing`` / ``poses`` / ``figure`` values.
    """

    # --- Routing: which builder branch to use and which row it produces.
    category: str
    subcategory: str
    row_number: int  # clamped to >= 1 by build_prompt_result
    start_index: int  # clamped to >= 1 by build_prompt_result
    seed: int

    # --- Person / outfit / pose selectors.
    clothing: str  # "full" | "minimal" | "random"; anything else -> "full"
    ethnicity: str
    poses: str  # "standard" | "evocative" | "random"; anything else -> "standard"
    backside_bias: float
    figure: str  # "curvy" | "balanced" | "bombshell" | "random"; else -> "curvy"
    no_plus_women: bool
    no_black: bool

    # --- Ratios handed (via deps.ratio_or_none) to the clothing / pose pickers.
    minimal_clothing_ratio: float
    standard_pose_ratio: float

    # --- Trigger word and free-form prompt additions.
    trigger: str  # blank -> deps.default_trigger
    prepend_trigger_to_prompt: bool
    extra_positive: str
    extra_negative: str

    # --- Optional extensions. Config-style fields accept a raw string, an
    # already-parsed dict, or None; parsing happens in build_prompt_result /
    # the injected dependencies.
    seed_config: str | dict[str, Any] | None = None
    women_count: int = 1
    men_count: int = 1
    camera_config: str | dict[str, Any] | None = None
    expression_intensity: float = 0.5
    character_profile: str | dict[str, Any] | None = None
    character_cast: str | dict[str, Any] | list[Any] | None = None
    expression_enabled: bool = True
    expression_phase: str = ""
    hardcore_position_config: str | dict[str, Any] | None = None
    location_config: str | dict[str, Any] | None = None
    composition_config: str | dict[str, Any] | None = None
    style_config: str | dict[str, Any] | None = None
@dataclass(frozen=True)
class PromptBuildRoute:
    """Outcome of ``build_prompt_result``: the finished row plus routing info."""

    # Final prompt row, including its "generation_trace" entry.
    row: dict[str, Any]
    # Category actually used; an input of "auto_full" is replaced by the
    # choice returned from deps.auto_full_choice.
    category: str
    subcategory: str
    # Which builder produced the row: "auto_weighted", "built_in" or "custom".
    branch: str
    # Shallow copy of the parsed seed config used for every axis RNG.
    parsed_seed_config: dict[str, Any]
    expression_intensity: float
    expression_intensity_source: str
@dataclass(frozen=True)
class PromptBuildDependencies:
    """Collaborators injected into ``build_prompt_result``.

    Every generator, parser and post-processor the router needs is passed in
    here, so this module only imports ``seed_config`` directly.
    """

    # --- Constants.
    default_trigger: str  # used when the request's trigger is blank
    default_negative: str  # forwarded to normalize_prompt_row
    random_subcategory: str  # subcategory value never treated as an exact custom one

    # --- Input preparation (called at the start of build_prompt_result).
    apply_pool_extensions: Callable[[], Any]
    normalize_ethnicity_filter: Callable[[Any, str], str]
    is_false: Callable[[Any], bool]
    ratio_or_none: Callable[[Any], float | None]
    parse_seed_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    parse_location_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    parse_composition_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]

    # --- Seeded selection: per-axis RNG factory and the pickers that consume it.
    axis_rng: Callable[[dict[str, Any], str, int, int], Any]
    pick_clothing_mode: Callable[[Any, str, float | None], str]
    pick_pose_mode: Callable[[Any, str, float | None], str]
    pick_figure_bias: Callable[[Any, str], str]
    pick_expression_intensity: Callable[[Any, Any], tuple[float, str]]
    auto_full_choice: Callable[[dict[str, Any], int, int], str]

    # --- Row builders, one per routing branch.
    build_auto_weighted_row: Callable[..., dict[str, Any]]
    build_direct_builtin_row: Callable[..., dict[str, Any]]
    build_custom_row: Callable[..., dict[str, Any]]

    # --- Post-processing applied to the built row.
    apply_location_config_to_legacy_row: Callable[..., dict[str, Any]]
    apply_composition_config_to_legacy_row: Callable[..., dict[str, Any]]
    disable_row_expression: Callable[[dict[str, Any], str], dict[str, Any]]
    apply_camera_config: Callable[[dict[str, Any], str | dict[str, Any] | None], dict[str, Any]]
    normalize_prompt_row: Callable[..., dict[str, Any]]
def _generation_trace(
    *,
    row: dict[str, Any],
    request: PromptBuildRequest,
    row_number: int,
    start_index: int,
    seed: int,
    category: str,
    subcategory: str,
    branch: str,
    parsed_seed_config: dict[str, Any],
    clothing: str,
    poses: str,
    figure: str,
    expression_enabled: bool,
    expression_intensity: float,
    expression_intensity_source: str,
    exact_custom_subcategory: bool,
) -> dict[str, Any]:
    """Return a flat dict recording the inputs and decisions behind *row*.

    Raw request values (``category_input`` / ``subcategory_input``) are stored
    next to the resolved ones. Row-provided values (source, slugs, trigger,
    subcategory) take precedence where present. ``cast_count_adjustment`` is
    only included when the row carries a truthy value for it.
    """
    content_axis = row.get("content_seed_axis")
    if not content_axis:
        # Fallback label: "pose" when the row names a position_family,
        # otherwise "content".
        content_axis = "pose" if row.get("position_family") else "content"

    trace = dict(
        builder="prompt_builder",
        branch=branch,
        source=row.get("source", ""),
        category_input=request.category,
        subcategory_input=request.subcategory,
        category=category,
        subcategory=row.get("subcategory") or subcategory,
        category_slug=row.get("category_slug", ""),
        subcategory_slug=row.get("subcategory_slug", ""),
        exact_custom_subcategory=bool(exact_custom_subcategory),
        row_number=row_number,
        start_index=start_index,
        seed=seed,
        seed_axes=seed_policy.axis_seed_trace(parsed_seed_config, seed, row_number),
        content_seed_axis=content_axis,
        clothing=clothing,
        poses=poses,
        figure=figure,
        expression_enabled=bool(expression_enabled),
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
        trigger=row.get("trigger", ""),
    )

    adjustment = row.get("cast_count_adjustment")
    if adjustment:
        trace["cast_count_adjustment"] = adjustment
    return trace
def _route_row(
    request: PromptBuildRequest,
    deps: PromptBuildDependencies,
    *,
    category: str,
    subcategory: str,
    exact_custom_subcategory: bool,
    row_number: int,
    start_index: int,
    seed: int,
    ethnicity: str,
    clothing: str,
    poses: str,
    figure: str,
    parsed_seed_config: dict[str, Any],
    parsed_location_config: dict[str, Any],
    parsed_composition_config: dict[str, Any],
    expression_enabled: bool,
    expression_intensity: float,
    expression_intensity_source: str,
) -> tuple[str, dict[str, Any]]:
    """Build the raw row for the branch selected by *category*; return ``(branch, row)``.

    A subcategory flagged as exact-custom (it contains " / ") always goes to
    the custom builder, regardless of *category*.
    """
    if not exact_custom_subcategory:
        if category == "auto_weighted":
            row = deps.build_auto_weighted_row(
                row_number,
                start_index,
                clothing,
                ethnicity,
                poses,
                float(request.backside_bias),
                figure,
                bool(request.no_plus_women),
                bool(request.no_black),
                # The clothing / pose ratios were already applied by the mode
                # pickers in build_prompt_result; the builders always get None.
                None,
                None,
                seed,
                seed_config=parsed_seed_config,
            )
            return "auto_weighted", row
        if category in ("woman", "man", "couple", "group_or_layout"):
            row = deps.build_direct_builtin_row(
                category,
                row_number,
                start_index,
                clothing,
                ethnicity,
                poses,
                float(request.backside_bias),
                figure,
                bool(request.no_plus_women),
                bool(request.no_black),
                None,  # minimal clothing ratio (already applied)
                None,  # standard pose ratio (already applied)
                seed,
                seed_config=parsed_seed_config,
            )
            return "built_in", row

    row = deps.build_custom_row(
        category,
        subcategory,
        row_number,
        start_index,
        ethnicity,
        poses,
        figure,
        bool(request.no_plus_women),
        bool(request.no_black),
        int(request.women_count),
        int(request.men_count),
        seed,
        parsed_seed_config,
        expression_enabled,
        expression_intensity,
        expression_intensity_source,
        request.character_profile,
        request.character_cast,
        request.expression_phase,
        request.hardcore_position_config,
        parsed_location_config,
        parsed_composition_config,
        request.style_config,
    )
    return "custom", row


def _finalize_row(
    row: dict[str, Any],
    request: PromptBuildRequest,
    deps: PromptBuildDependencies,
    *,
    parsed_seed_config: dict[str, Any],
    parsed_location_config: dict[str, Any],
    parsed_composition_config: dict[str, Any],
    seed: int,
    row_number: int,
    expression_enabled: bool,
) -> dict[str, Any]:
    """Apply legacy location/composition, expression, camera and trigger post-processing."""
    # Only rows from the built-in generator get location / composition applied
    # here; other sources are left as built.
    if row.get("source") == "built_in_generator":
        row = deps.apply_location_config_to_legacy_row(
            row,
            parsed_location_config,
            parsed_seed_config,
            seed,
            row_number,
        )
        row = deps.apply_composition_config_to_legacy_row(
            row,
            parsed_composition_config,
            parsed_seed_config,
            seed,
            row_number,
        )
    if not expression_enabled:
        row = deps.disable_row_expression(row, "disabled")
    row = deps.apply_camera_config(row, request.camera_config)
    # A blank (or missing) trigger falls back to the default trigger word.
    active_trigger = (request.trigger or "").strip() or deps.default_trigger
    return deps.normalize_prompt_row(
        row,
        active_trigger=active_trigger,
        prepend_trigger_to_prompt=bool(request.prepend_trigger_to_prompt),
        extra_positive=request.extra_positive,
        extra_negative=request.extra_negative,
        default_negative=deps.default_negative,
    )


def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependencies) -> PromptBuildRoute:
    """Resolve *request* into a finished prompt row and describe how it was routed.

    Steps:
      1. Normalise inputs and parse the seed / location / composition configs.
      2. Resolve clothing, pose, figure and expression intensity using one RNG
         per seed axis.
      3. Route to the auto-weighted, built-in or custom row builder
         ("auto_full" is first resolved via ``deps.auto_full_choice``).
      4. Post-process the row (legacy location/composition, expression toggle,
         camera, trigger/extras) and attach ``row["generation_trace"]``.

    Returns:
        PromptBuildRoute holding the row, the resolved category, the branch
        name and a shallow copy of the parsed seed config.
    """
    deps.apply_pool_extensions()

    row_number = max(1, int(request.row_number))
    start_index = max(1, int(request.start_index))
    seed = int(request.seed)
    category = request.category
    subcategory = request.subcategory
    ethnicity = deps.normalize_ethnicity_filter(request.ethnicity, "any")
    expression_enabled = not deps.is_false(request.expression_enabled)
    minimal_ratio = deps.ratio_or_none(request.minimal_clothing_ratio)
    pose_ratio = deps.ratio_or_none(request.standard_pose_ratio)
    parsed_seed_config = deps.parse_seed_config(request.seed_config)
    parsed_location_config = deps.parse_location_config(request.location_config)
    parsed_composition_config = deps.parse_composition_config(request.composition_config)

    # One RNG per seed axis; each picker below draws only from its own axis.
    clothing_rng = deps.axis_rng(parsed_seed_config, "clothing", seed, row_number)
    pose_rng = deps.axis_rng(parsed_seed_config, "pose", seed, row_number)
    person_rng = deps.axis_rng(parsed_seed_config, "person", seed, row_number)
    expression_rng = deps.axis_rng(parsed_seed_config, "expression", seed, row_number)

    # Unrecognised mode strings fall back to a fixed default before picking.
    clothing = request.clothing if request.clothing in ("full", "minimal", "random") else "full"
    poses = request.poses if request.poses in ("standard", "evocative", "random") else "standard"
    figure = request.figure if request.figure in ("curvy", "balanced", "bombshell", "random") else "curvy"
    clothing = deps.pick_clothing_mode(clothing_rng, clothing, minimal_ratio)
    poses = deps.pick_pose_mode(pose_rng, poses, pose_ratio)
    figure = deps.pick_figure_bias(person_rng, figure)
    expression_intensity, expression_intensity_source = deps.pick_expression_intensity(
        expression_rng,
        request.expression_intensity,
    )

    exact_custom_subcategory = bool(
        subcategory and subcategory != deps.random_subcategory and " / " in subcategory
    )
    if category == "auto_full" and not exact_custom_subcategory:
        category = deps.auto_full_choice(parsed_seed_config, seed, row_number)

    branch, row = _route_row(
        request,
        deps,
        category=category,
        subcategory=subcategory,
        exact_custom_subcategory=exact_custom_subcategory,
        row_number=row_number,
        start_index=start_index,
        seed=seed,
        ethnicity=ethnicity,
        clothing=clothing,
        poses=poses,
        figure=figure,
        parsed_seed_config=parsed_seed_config,
        parsed_location_config=parsed_location_config,
        parsed_composition_config=parsed_composition_config,
        expression_enabled=expression_enabled,
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
    )
    row = _finalize_row(
        row,
        request,
        deps,
        parsed_seed_config=parsed_seed_config,
        parsed_location_config=parsed_location_config,
        parsed_composition_config=parsed_composition_config,
        seed=seed,
        row_number=row_number,
        expression_enabled=expression_enabled,
    )

    # Builders may already have set these; only fill them in when missing.
    row.setdefault("expression_intensity", expression_intensity)
    row.setdefault("expression_intensity_source", expression_intensity_source)
    row["generation_trace"] = _generation_trace(
        row=row,
        request=request,
        row_number=row_number,
        start_index=start_index,
        seed=seed,
        category=category,
        subcategory=subcategory,
        branch=branch,
        parsed_seed_config=parsed_seed_config,
        clothing=clothing,
        poses=poses,
        figure=figure,
        expression_enabled=expression_enabled,
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
        exact_custom_subcategory=exact_custom_subcategory,
    )
    return PromptBuildRoute(
        row=row,
        category=category,
        subcategory=subcategory,
        branch=branch,
        parsed_seed_config=dict(parsed_seed_config),
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
    )
def build_prompt(request: PromptBuildRequest, deps: PromptBuildDependencies) -> dict[str, Any]:
    """Run ``build_prompt_result`` and return only the finished prompt row."""
    route = build_prompt_result(request, deps)
    return route.row