Add separate style pool config

This commit is contained in:
2026-06-28 00:24:40 +02:00
parent 4c8edc0d3e
commit 78e39734b5
18 changed files with 378 additions and 27 deletions
+45 -2
View File
@@ -47,6 +47,7 @@ try:
from . import row_route_metadata as row_route_policy
from . import row_subject_route as row_subject_route_policy
from . import seed_config as seed_policy
from . import style_config as style_policy
from . import subject_context as subject_context_policy
from .hardcore_text_cleanup import (
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
@@ -95,6 +96,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
import row_route_metadata as row_route_policy
import row_subject_route as row_subject_route_policy
import seed_config as seed_policy
import style_config as style_policy
import subject_context as subject_context_policy
from hardcore_text_cleanup import (
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
@@ -376,6 +378,7 @@ CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS
CAST_PRESETS = category_cast_policy.CAST_PRESETS
GENERATION_PROFILE_PRESETS = generation_profile_policy.GENERATION_PROFILE_PRESETS
STYLE_PRESETS = style_policy.STYLE_PRESETS
def category_preset_choices() -> list[str]:
@@ -390,6 +393,14 @@ def generation_profile_choices() -> list[str]:
return generation_profile_policy.generation_profile_choices()
def style_pool_preset_choices() -> list[str]:
return style_policy.style_pool_preset_choices()
def style_combine_mode_choices() -> list[str]:
return style_policy.style_combine_mode_choices()
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
return category_cast_policy.build_category_config_json(preset=preset, subcategory=subcategory)
@@ -436,6 +447,30 @@ def _parse_generation_profile(profile_config: str | dict[str, Any] | None) -> di
return generation_profile_policy.parse_generation_profile(profile_config)
def build_style_config_json(
enabled: bool = True,
combine_mode: str = "replace",
preset: str = "category_default",
custom_style: str = "",
custom_positive_suffix: str = "",
custom_negative: str = "",
style_config: str | dict[str, Any] | None = "",
) -> str:
return style_policy.build_style_config_json(
enabled=enabled,
combine_mode=combine_mode,
preset=preset,
custom_style=custom_style,
custom_positive_suffix=custom_positive_suffix,
custom_negative=custom_negative,
style_config=style_config,
)
def _parse_style_config(style_config: str | dict[str, Any] | None) -> dict[str, Any]:
return style_policy.parse_style_config(style_config)
def build_filter_config_json(
ethnicity: str = "any",
figure: str = "curvy",
@@ -880,8 +915,9 @@ def _row_text_fields(
category: dict[str, Any],
subcategory: dict[str, Any],
item: Any,
style_config: str | dict[str, Any] | None = None,
) -> row_rendering_policy.RowTextFields:
return row_rendering_policy.resolve_row_text_fields(category, subcategory, item)
return row_rendering_policy.resolve_row_text_fields(category, subcategory, item, style_config)
def _clean_prompt_punctuation(text: str) -> str:
@@ -2284,6 +2320,7 @@ def _build_custom_row(
hardcore_position_config: str | dict[str, Any] | None = None,
location_config: str | dict[str, Any] | None = None,
composition_config: str | dict[str, Any] | None = None,
style_config: str | dict[str, Any] | None = None,
) -> dict[str, Any]:
scene_rng = _axis_rng(seed_config, "scene", seed, row_number)
pose_rng = _axis_rng(seed_config, "pose", seed, row_number)
@@ -2421,7 +2458,7 @@ def _build_custom_row(
position_key = action_route.position_key
action_family = action_route.action_family
text_fields = _row_text_fields(category, subcategory, item)
text_fields = _row_text_fields(category, subcategory, item, style_config)
assembly_request = row_assembly_policy.CustomRowAssemblyRequest(
row_number=row_number,
@@ -2542,6 +2579,7 @@ def build_prompt(
hardcore_position_config: str | dict[str, Any] | None = None,
location_config: str | dict[str, Any] | None = None,
composition_config: str | dict[str, Any] | None = None,
style_config: str | dict[str, Any] | None = None,
) -> dict[str, Any]:
return builder_prompt_route_policy.build_prompt(
builder_prompt_route_policy.PromptBuildRequest(
@@ -2575,6 +2613,7 @@ def build_prompt(
hardcore_position_config=hardcore_position_config,
location_config=location_config,
composition_config=composition_config,
style_config=style_config,
),
_prompt_build_dependencies(),
)
@@ -2605,6 +2644,7 @@ def build_prompt_from_configs(
hardcore_position_config: str | dict[str, Any] | None = "",
location_config: str | dict[str, Any] | None = "",
composition_config: str | dict[str, Any] | None = "",
style_config: str | dict[str, Any] | None = "",
extra_positive: str = "",
extra_negative: str = "",
) -> dict[str, Any]:
@@ -2624,6 +2664,7 @@ def build_prompt_from_configs(
hardcore_position_config=hardcore_position_config,
location_config=location_config,
composition_config=composition_config,
style_config=style_config,
extra_positive=extra_positive,
extra_negative=extra_negative,
),
@@ -2801,6 +2842,7 @@ def build_insta_of_pair(
hardcore_position_config: str | dict[str, Any] | None = "",
location_config: str | dict[str, Any] | None = "",
composition_config: str | dict[str, Any] | None = "",
style_config: str | dict[str, Any] | None = "",
extra_positive: str = "",
extra_negative: str = "",
) -> dict[str, Any]:
@@ -2825,6 +2867,7 @@ def build_insta_of_pair(
hardcore_position_config=hardcore_position_config,
location_config=location_config,
composition_config=composition_config,
style_config=style_config,
extra_positive=extra_positive,
extra_negative=extra_negative,
)