from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

try:
    from . import seed_config as seed_policy
except ImportError:  # pragma: no cover - plain-script smoke tests
    import seed_config as seed_policy


@dataclass(frozen=True)
class PromptBuildRequest:
    """Immutable bundle of inputs for building a single prompt row.

    Fields without defaults are required.  The ``*_config``, ``character_*``
    and ``style_config`` fields accept raw strings, dicts (or, for
    ``character_cast``, lists) or ``None``; their interpretation is delegated
    to the callables in :class:`PromptBuildDependencies`.
    """

    category: str
    subcategory: str
    row_number: int  # clamped to >= 1 by build_prompt_result
    start_index: int  # clamped to >= 1 by build_prompt_result
    seed: int
    clothing: str  # "full" | "minimal" | "random"; anything else -> "full"
    ethnicity: str
    poses: str  # "standard" | "evocative" | "random"; anything else -> "standard"
    backside_bias: float
    figure: str  # "curvy" | "balanced" | "bombshell" | "random"; else -> "curvy"
    no_plus_women: bool
    no_black: bool
    minimal_clothing_ratio: float
    standard_pose_ratio: float
    trigger: str  # blank / whitespace-only -> deps.default_trigger
    prepend_trigger_to_prompt: bool
    extra_positive: str
    extra_negative: str
    seed_config: str | dict[str, Any] | None = None
    women_count: int = 1
    men_count: int = 1
    camera_config: str | dict[str, Any] | None = None
    expression_intensity: float = 0.5
    character_profile: str | dict[str, Any] | None = None
    character_cast: str | dict[str, Any] | list[Any] | None = None
    expression_enabled: bool = True
    expression_phase: str = ""
    hardcore_position_config: str | dict[str, Any] | None = None
    location_config: str | dict[str, Any] | None = None
    composition_config: str | dict[str, Any] | None = None
    style_config: str | dict[str, Any] | None = None


@dataclass(frozen=True)
class PromptBuildRoute:
    """Outcome of :func:`build_prompt_result`: the finished row plus routing info."""

    row: dict[str, Any]  # normalized row; carries a "generation_trace" entry
    category: str  # category after any "auto_full" resolution
    subcategory: str  # subcategory exactly as requested
    branch: str  # "auto_weighted", "built_in" or "custom"
    parsed_seed_config: dict[str, Any]  # shallow copy of the parsed seed config
    expression_intensity: float
    expression_intensity_source: str


@dataclass(frozen=True)
class PromptBuildDependencies:
    """Injected constants and callables used by :func:`build_prompt_result`.

    Keeping these injectable lets the routing logic below stay independent of
    the concrete pools/builders implementation.
    """

    default_trigger: str  # used when request.trigger is blank
    default_negative: str  # forwarded to normalize_prompt_row
    # Subcategory value that never counts as an "exact custom" subcategory.
    random_subcategory: str
    apply_pool_extensions: Callable[[], Any]  # called first on every build
    normalize_ethnicity_filter: Callable[[Any, str], str]  # called as (ethnicity, "any")
    is_false: Callable[[Any], bool]  # interprets request.expression_enabled
    ratio_or_none: Callable[[Any], float | None]  # applied to clothing/pose ratios
    parse_seed_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    parse_location_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    parse_composition_config: Callable[[str | dict[str, Any] | None], dict[str, Any]]
    # (seed_config, axis_name, seed, row_number) -> RNG-like object for that axis.
    axis_rng: Callable[[dict[str, Any], str, int, int], Any]
    pick_clothing_mode: Callable[[Any, str, float | None], str]
    pick_pose_mode: Callable[[Any, str, float | None], str]
    pick_figure_bias: Callable[[Any, str], str]
    # (rng, requested_intensity) -> (intensity, source label)
    pick_expression_intensity: Callable[[Any, Any], tuple[float, str]]
    # (seed_config, seed, row_number) -> concrete category replacing "auto_full"
    auto_full_choice: Callable[[dict[str, Any], int, int], str]
    build_auto_weighted_row: Callable[..., dict[str, Any]]
    build_direct_builtin_row: Callable[..., dict[str, Any]]
    build_custom_row: Callable[..., dict[str, Any]]
    apply_location_config_to_legacy_row: Callable[..., dict[str, Any]]
    apply_composition_config_to_legacy_row: Callable[..., dict[str, Any]]
    disable_row_expression: Callable[[dict[str, Any], str], dict[str, Any]]
    apply_camera_config: Callable[[dict[str, Any], str | dict[str, Any] | None], dict[str, Any]]
    normalize_prompt_row: Callable[..., dict[str, Any]]


def _generation_trace(
    *,
    row: dict[str, Any],
    request: PromptBuildRequest,
    row_number: int,
    start_index: int,
    seed: int,
    category: str,
    subcategory: str,
    branch: str,
    parsed_seed_config: dict[str, Any],
    clothing: str,
    poses: str,
    figure: str,
    expression_enabled: bool,
    expression_intensity: float,
    expression_intensity_source: str,
    exact_custom_subcategory: bool,
) -> dict[str, Any]:
    """Assemble the diagnostic ``generation_trace`` dict stored on a finished row.

    Records both the raw request values (``category_input`` /
    ``subcategory_input``) and the resolved values actually used, plus the
    per-axis seed breakdown from ``seed_policy.axis_seed_trace``.
    ``cast_count_adjustment`` is only included when the row has a truthy value
    for it.
    """
    trace = {
        "builder": "prompt_builder",
        "branch": branch,
        "source": row.get("source", ""),
        "category_input": request.category,
        "subcategory_input": request.subcategory,
        "category": category,
        # Prefer the subcategory the builder actually chose over the request.
        "subcategory": row.get("subcategory") or subcategory,
        "category_slug": row.get("category_slug", ""),
        "subcategory_slug": row.get("subcategory_slug", ""),
        "exact_custom_subcategory": bool(exact_custom_subcategory),
        "row_number": row_number,
        "start_index": start_index,
        "seed": seed,
        "seed_axes": seed_policy.axis_seed_trace(parsed_seed_config, seed, row_number),
        # Rows with a position family are attributed to the "pose" axis unless
        # the builder named the axis explicitly.
        "content_seed_axis": row.get("content_seed_axis")
        or ("pose" if row.get("position_family") else "content"),
        "clothing": clothing,
        "poses": poses,
        "figure": figure,
        "expression_enabled": bool(expression_enabled),
        "expression_intensity": expression_intensity,
        "expression_intensity_source": expression_intensity_source,
        "trigger": row.get("trigger", ""),
    }
    if row.get("cast_count_adjustment"):
        trace["cast_count_adjustment"] = row.get("cast_count_adjustment")
    return trace


def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependencies) -> PromptBuildRoute:
    """Build one prompt row for *request* and report how it was routed.

    Steps:

    1. Normalize scalar inputs and parse the seed/location/composition configs.
    2. Resolve the clothing, pose, figure and expression-intensity modes, each
       from its own seed axis.
    3. Route to one of three builders:

       * ``"auto_weighted"`` when ``category == "auto_weighted"``;
       * ``"built_in"`` for ``woman``/``man``/``couple``/``group_or_layout``;
       * ``"custom"`` otherwise.

       A subcategory containing ``" / "`` (and not equal to
       ``deps.random_subcategory``) always forces the custom builder.
       ``"auto_full"`` is first replaced by ``deps.auto_full_choice``.
    4. Apply post-processing in this order: legacy location/composition
       configs (built-in generator rows only), expression disabling, camera
       config, normalization.
    5. Attach a ``generation_trace``.

    Args:
        request: The user-facing build inputs.
        deps: Injected constants and callables that do the actual work.

    Returns:
        A :class:`PromptBuildRoute` holding the finished row, the resolved
        category, the requested subcategory, the branch name, a copy of the
        parsed seed config and the chosen expression intensity/source.
    """
    deps.apply_pool_extensions()

    # --- 1. Normalize scalar inputs -------------------------------------
    row_number = max(1, int(request.row_number))
    start_index = max(1, int(request.start_index))
    seed = int(request.seed)
    category = request.category
    subcategory = request.subcategory
    ethnicity = deps.normalize_ethnicity_filter(request.ethnicity, "any")
    expression_enabled = not deps.is_false(request.expression_enabled)
    minimal_ratio = deps.ratio_or_none(request.minimal_clothing_ratio)
    pose_ratio = deps.ratio_or_none(request.standard_pose_ratio)

    parsed_seed_config = deps.parse_seed_config(request.seed_config)
    parsed_location_config = deps.parse_location_config(request.location_config)
    parsed_composition_config = deps.parse_composition_config(request.composition_config)

    # --- 2. Resolve modes, one seed axis per decision -------------------
    clothing_rng = deps.axis_rng(parsed_seed_config, "clothing", seed, row_number)
    pose_axis_rng = deps.axis_rng(parsed_seed_config, "pose", seed, row_number)
    person_rng = deps.axis_rng(parsed_seed_config, "person", seed, row_number)
    expression_rng = deps.axis_rng(parsed_seed_config, "expression", seed, row_number)

    clothing = request.clothing if request.clothing in ("full", "minimal", "random") else "full"
    poses = request.poses if request.poses in ("standard", "evocative", "random") else "standard"
    figure = request.figure if request.figure in ("curvy", "balanced", "bombshell", "random") else "curvy"

    clothing = deps.pick_clothing_mode(clothing_rng, clothing, minimal_ratio)
    poses = deps.pick_pose_mode(pose_axis_rng, poses, pose_ratio)
    figure = deps.pick_figure_bias(person_rng, figure)
    expression_intensity, expression_intensity_source = deps.pick_expression_intensity(
        expression_rng,
        request.expression_intensity,
    )

    # --- 3. Route to a builder ------------------------------------------
    exact_custom_subcategory = bool(
        subcategory
        and subcategory != deps.random_subcategory
        and " / " in subcategory
    )
    if category == "auto_full" and not exact_custom_subcategory:
        category = deps.auto_full_choice(parsed_seed_config, seed, row_number)

    # NOTE: the minimal-clothing and standard-pose ratios were already consumed
    # by pick_clothing_mode / pick_pose_mode above, so the builders below are
    # deliberately handed None for both ratio arguments.
    branch = "custom"
    if category == "auto_weighted" and not exact_custom_subcategory:
        branch = "auto_weighted"
        row = deps.build_auto_weighted_row(
            row_number,
            start_index,
            clothing,
            ethnicity,
            poses,
            float(request.backside_bias),
            figure,
            bool(request.no_plus_women),
            bool(request.no_black),
            None,  # minimal ratio (already applied)
            None,  # pose ratio (already applied)
            seed,
            seed_config=parsed_seed_config,
        )
    elif category in ("woman", "man", "couple", "group_or_layout") and not exact_custom_subcategory:
        branch = "built_in"
        row = deps.build_direct_builtin_row(
            category,
            row_number,
            start_index,
            clothing,
            ethnicity,
            poses,
            float(request.backside_bias),
            figure,
            bool(request.no_plus_women),
            bool(request.no_black),
            None,  # minimal ratio (already applied)
            None,  # pose ratio (already applied)
            seed,
            seed_config=parsed_seed_config,
        )
    else:
        row = deps.build_custom_row(
            category,
            subcategory,
            row_number,
            start_index,
            ethnicity,
            poses,
            figure,
            bool(request.no_plus_women),
            bool(request.no_black),
            int(request.women_count),
            int(request.men_count),
            seed,
            parsed_seed_config,
            expression_enabled,
            expression_intensity,
            expression_intensity_source,
            request.character_profile,
            request.character_cast,
            request.expression_phase,
            request.hardcore_position_config,
            parsed_location_config,
            parsed_composition_config,
            request.style_config,
        )

    # --- 4. Post-process ------------------------------------------------
    # Location/composition configs are applied here only for rows produced by
    # the built-in generator; the custom builder receives them directly above.
    if row.get("source") == "built_in_generator":
        row = deps.apply_location_config_to_legacy_row(
            row,
            parsed_location_config,
            parsed_seed_config,
            seed,
            row_number,
        )
        row = deps.apply_composition_config_to_legacy_row(
            row,
            parsed_composition_config,
            parsed_seed_config,
            seed,
            row_number,
        )
    if not expression_enabled:
        row = deps.disable_row_expression(row, "disabled")
    row = deps.apply_camera_config(row, request.camera_config)

    # Tolerate a missing (None) trigger instead of raising AttributeError;
    # blank or whitespace-only triggers fall back to the default as before.
    active_trigger = (request.trigger or "").strip() or deps.default_trigger
    row = deps.normalize_prompt_row(
        row,
        active_trigger=active_trigger,
        prepend_trigger_to_prompt=bool(request.prepend_trigger_to_prompt),
        extra_positive=request.extra_positive,
        extra_negative=request.extra_negative,
        default_negative=deps.default_negative,
    )
    # Builders may already have recorded an intensity; keep theirs if so.
    row.setdefault("expression_intensity", expression_intensity)
    row.setdefault("expression_intensity_source", expression_intensity_source)

    # --- 5. Trace -------------------------------------------------------
    row["generation_trace"] = _generation_trace(
        row=row,
        request=request,
        row_number=row_number,
        start_index=start_index,
        seed=seed,
        category=category,
        subcategory=subcategory,
        branch=branch,
        parsed_seed_config=parsed_seed_config,
        clothing=clothing,
        poses=poses,
        figure=figure,
        expression_enabled=expression_enabled,
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
        exact_custom_subcategory=exact_custom_subcategory,
    )
    return PromptBuildRoute(
        row=row,
        category=category,
        subcategory=subcategory,
        branch=branch,
        parsed_seed_config=dict(parsed_seed_config),
        expression_intensity=expression_intensity,
        expression_intensity_source=expression_intensity_source,
    )


def build_prompt(request: PromptBuildRequest, deps: PromptBuildDependencies) -> dict[str, Any]:
    """Build one prompt row for *request*, discarding the routing metadata."""
    return build_prompt_result(request, deps).row