Add separate style pool config
This commit is contained in:
+45
-2
@@ -47,6 +47,7 @@ try:
|
||||
from . import row_route_metadata as row_route_policy
|
||||
from . import row_subject_route as row_subject_route_policy
|
||||
from . import seed_config as seed_policy
|
||||
from . import style_config as style_policy
|
||||
from . import subject_context as subject_context_policy
|
||||
from .hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -95,6 +96,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import row_route_metadata as row_route_policy
|
||||
import row_subject_route as row_subject_route_policy
|
||||
import seed_config as seed_policy
|
||||
import style_config as style_policy
|
||||
import subject_context as subject_context_policy
|
||||
from hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -376,6 +378,7 @@ CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS
|
||||
CAST_PRESETS = category_cast_policy.CAST_PRESETS
|
||||
|
||||
GENERATION_PROFILE_PRESETS = generation_profile_policy.GENERATION_PROFILE_PRESETS
|
||||
STYLE_PRESETS = style_policy.STYLE_PRESETS
|
||||
|
||||
|
||||
def category_preset_choices() -> list[str]:
|
||||
@@ -390,6 +393,14 @@ def generation_profile_choices() -> list[str]:
|
||||
return generation_profile_policy.generation_profile_choices()
|
||||
|
||||
|
||||
def style_pool_preset_choices() -> list[str]:
|
||||
return style_policy.style_pool_preset_choices()
|
||||
|
||||
|
||||
def style_combine_mode_choices() -> list[str]:
|
||||
return style_policy.style_combine_mode_choices()
|
||||
|
||||
|
||||
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
|
||||
return category_cast_policy.build_category_config_json(preset=preset, subcategory=subcategory)
|
||||
|
||||
@@ -436,6 +447,30 @@ def _parse_generation_profile(profile_config: str | dict[str, Any] | None) -> di
|
||||
return generation_profile_policy.parse_generation_profile(profile_config)
|
||||
|
||||
|
||||
def build_style_config_json(
|
||||
enabled: bool = True,
|
||||
combine_mode: str = "replace",
|
||||
preset: str = "category_default",
|
||||
custom_style: str = "",
|
||||
custom_positive_suffix: str = "",
|
||||
custom_negative: str = "",
|
||||
style_config: str | dict[str, Any] | None = "",
|
||||
) -> str:
|
||||
return style_policy.build_style_config_json(
|
||||
enabled=enabled,
|
||||
combine_mode=combine_mode,
|
||||
preset=preset,
|
||||
custom_style=custom_style,
|
||||
custom_positive_suffix=custom_positive_suffix,
|
||||
custom_negative=custom_negative,
|
||||
style_config=style_config,
|
||||
)
|
||||
|
||||
|
||||
def _parse_style_config(style_config: str | dict[str, Any] | None) -> dict[str, Any]:
|
||||
return style_policy.parse_style_config(style_config)
|
||||
|
||||
|
||||
def build_filter_config_json(
|
||||
ethnicity: str = "any",
|
||||
figure: str = "curvy",
|
||||
@@ -880,8 +915,9 @@ def _row_text_fields(
|
||||
category: dict[str, Any],
|
||||
subcategory: dict[str, Any],
|
||||
item: Any,
|
||||
style_config: str | dict[str, Any] | None = None,
|
||||
) -> row_rendering_policy.RowTextFields:
|
||||
return row_rendering_policy.resolve_row_text_fields(category, subcategory, item)
|
||||
return row_rendering_policy.resolve_row_text_fields(category, subcategory, item, style_config)
|
||||
|
||||
|
||||
def _clean_prompt_punctuation(text: str) -> str:
|
||||
@@ -2284,6 +2320,7 @@ def _build_custom_row(
|
||||
hardcore_position_config: str | dict[str, Any] | None = None,
|
||||
location_config: str | dict[str, Any] | None = None,
|
||||
composition_config: str | dict[str, Any] | None = None,
|
||||
style_config: str | dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
scene_rng = _axis_rng(seed_config, "scene", seed, row_number)
|
||||
pose_rng = _axis_rng(seed_config, "pose", seed, row_number)
|
||||
@@ -2421,7 +2458,7 @@ def _build_custom_row(
|
||||
position_key = action_route.position_key
|
||||
action_family = action_route.action_family
|
||||
|
||||
text_fields = _row_text_fields(category, subcategory, item)
|
||||
text_fields = _row_text_fields(category, subcategory, item, style_config)
|
||||
|
||||
assembly_request = row_assembly_policy.CustomRowAssemblyRequest(
|
||||
row_number=row_number,
|
||||
@@ -2542,6 +2579,7 @@ def build_prompt(
|
||||
hardcore_position_config: str | dict[str, Any] | None = None,
|
||||
location_config: str | dict[str, Any] | None = None,
|
||||
composition_config: str | dict[str, Any] | None = None,
|
||||
style_config: str | dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return builder_prompt_route_policy.build_prompt(
|
||||
builder_prompt_route_policy.PromptBuildRequest(
|
||||
@@ -2575,6 +2613,7 @@ def build_prompt(
|
||||
hardcore_position_config=hardcore_position_config,
|
||||
location_config=location_config,
|
||||
composition_config=composition_config,
|
||||
style_config=style_config,
|
||||
),
|
||||
_prompt_build_dependencies(),
|
||||
)
|
||||
@@ -2605,6 +2644,7 @@ def build_prompt_from_configs(
|
||||
hardcore_position_config: str | dict[str, Any] | None = "",
|
||||
location_config: str | dict[str, Any] | None = "",
|
||||
composition_config: str | dict[str, Any] | None = "",
|
||||
style_config: str | dict[str, Any] | None = "",
|
||||
extra_positive: str = "",
|
||||
extra_negative: str = "",
|
||||
) -> dict[str, Any]:
|
||||
@@ -2624,6 +2664,7 @@ def build_prompt_from_configs(
|
||||
hardcore_position_config=hardcore_position_config,
|
||||
location_config=location_config,
|
||||
composition_config=composition_config,
|
||||
style_config=style_config,
|
||||
extra_positive=extra_positive,
|
||||
extra_negative=extra_negative,
|
||||
),
|
||||
@@ -2801,6 +2842,7 @@ def build_insta_of_pair(
|
||||
hardcore_position_config: str | dict[str, Any] | None = "",
|
||||
location_config: str | dict[str, Any] | None = "",
|
||||
composition_config: str | dict[str, Any] | None = "",
|
||||
style_config: str | dict[str, Any] | None = "",
|
||||
extra_positive: str = "",
|
||||
extra_negative: str = "",
|
||||
) -> dict[str, Any]:
|
||||
@@ -2825,6 +2867,7 @@ def build_insta_of_pair(
|
||||
hardcore_position_config=hardcore_position_config,
|
||||
location_config=location_config,
|
||||
composition_config=composition_config,
|
||||
style_config=style_config,
|
||||
extra_positive=extra_positive,
|
||||
extra_negative=extra_negative,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user