Extract row location policy
This commit is contained in:
+4
-98
@@ -41,6 +41,7 @@ try:
|
||||
from . import pair_options
|
||||
from . import row_normalization as row_policy
|
||||
from . import row_camera as row_camera_policy
|
||||
from . import row_location as row_location_policy
|
||||
from . import seed_config as seed_policy
|
||||
from .hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -82,6 +83,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import pair_options
|
||||
import row_normalization as row_policy
|
||||
import row_camera as row_camera_policy
|
||||
import row_location as row_location_policy
|
||||
import seed_config as seed_policy
|
||||
from hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -3122,102 +3124,6 @@ def _scene_pool(
|
||||
return scene_entries or fallback
|
||||
|
||||
|
||||
def _legacy_scene_entries_for_row(row: dict[str, Any]) -> list[Any]:
|
||||
subject = str(row.get("primary_subject") or "").lower()
|
||||
if "group" in subject or "layout" in subject:
|
||||
return list(g.GROUP_SCENES)
|
||||
return list(g.SCENES)
|
||||
|
||||
|
||||
def _legacy_scene_text_for_slug(slug: str) -> str:
|
||||
for entry in list(g.SCENES) + list(g.GROUP_SCENES):
|
||||
entry_slug, entry_text = _pair_from(entry)
|
||||
if entry_slug == slug:
|
||||
return entry_text
|
||||
return ""
|
||||
|
||||
|
||||
def _apply_location_config_to_legacy_row(
|
||||
row: dict[str, Any],
|
||||
location_config: dict[str, Any],
|
||||
seed_config: dict[str, int],
|
||||
seed: int,
|
||||
row_number: int,
|
||||
) -> dict[str, Any]:
|
||||
if not _location_config_active(location_config):
|
||||
return row
|
||||
location_entries = _list_from(location_config.get("scene_entries"))
|
||||
if location_config.get("apply_mode") == "add":
|
||||
choices = _legacy_scene_entries_for_row(row)
|
||||
_unique_extend(choices, location_entries)
|
||||
else:
|
||||
choices = location_entries
|
||||
scene_rng = _axis_rng(seed_config, "scene", seed, row_number)
|
||||
scene_slug, scene_text = _choose_pair(scene_rng, choices)
|
||||
old_slug = str(row.get("scene") or "")
|
||||
old_text = _legacy_scene_text_for_slug(old_slug)
|
||||
row["source_scene"] = old_slug
|
||||
row["source_scene_text"] = old_text
|
||||
row["scene"] = scene_slug
|
||||
row["scene_text"] = scene_text
|
||||
row["location_config"] = location_config
|
||||
if old_text:
|
||||
row["prompt"] = str(row.get("prompt") or "").replace(f"Scene: {old_text}.", f"Scene: {scene_text}.")
|
||||
row["caption"] = str(row.get("caption") or "").replace(f", {old_text},", f", {scene_text},")
|
||||
else:
|
||||
row["prompt"] = re.sub(
|
||||
r"Scene:\s*.*?\.\s*Pose:",
|
||||
f"Scene: {scene_text}. Pose:",
|
||||
str(row.get("prompt") or ""),
|
||||
count=1,
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def _legacy_composition_entries_for_row(row: dict[str, Any]) -> list[Any]:
|
||||
subject = str(row.get("primary_subject") or "").lower()
|
||||
if "group" in subject or "layout" in subject:
|
||||
return list(g.GROUP_COMPOSITIONS)
|
||||
return list(g.COMPOSITIONS)
|
||||
|
||||
|
||||
def _apply_composition_config_to_legacy_row(
|
||||
row: dict[str, Any],
|
||||
composition_config: dict[str, Any],
|
||||
seed_config: dict[str, int],
|
||||
seed: int,
|
||||
row_number: int,
|
||||
) -> dict[str, Any]:
|
||||
if not _composition_config_active(composition_config):
|
||||
return row
|
||||
composition_entries = _list_from(composition_config.get("composition_entries"))
|
||||
if composition_config.get("apply_mode") == "add":
|
||||
choices = _legacy_composition_entries_for_row(row)
|
||||
_unique_extend(choices, composition_entries)
|
||||
else:
|
||||
choices = composition_entries
|
||||
composition_rng = _axis_rng(seed_config, "composition", seed, row_number)
|
||||
new_composition = _choose_text(composition_rng, choices)
|
||||
old_composition = str(row.get("composition") or "")
|
||||
old_prompt_fragment = f"Composition: vertical {old_composition}."
|
||||
new_prompt_fragment = f"Composition: {_composition_prompt(new_composition)}."
|
||||
row["source_composition"] = old_composition
|
||||
row["composition"] = new_composition
|
||||
row["composition_prompt"] = _composition_prompt(new_composition)
|
||||
row["composition_config"] = composition_config
|
||||
if old_composition:
|
||||
row["prompt"] = str(row.get("prompt") or "").replace(old_prompt_fragment, new_prompt_fragment)
|
||||
row["caption"] = str(row.get("caption") or "").replace(f", {old_composition},", f", {new_composition},")
|
||||
else:
|
||||
row["prompt"] = re.sub(
|
||||
r"Composition:\s*.*?\.\s*Use",
|
||||
f"{new_prompt_fragment} Use",
|
||||
str(row.get("prompt") or ""),
|
||||
count=1,
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def _expression_pool(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> list[Any]:
|
||||
return _configured_pool(
|
||||
category,
|
||||
@@ -3919,14 +3825,14 @@ def build_prompt(
|
||||
)
|
||||
|
||||
if row.get("source") == "built_in_generator":
|
||||
row = _apply_location_config_to_legacy_row(
|
||||
row = row_location_policy.apply_location_config_to_legacy_row(
|
||||
row,
|
||||
parsed_location_config,
|
||||
parsed_seed_config,
|
||||
seed,
|
||||
row_number,
|
||||
)
|
||||
row = _apply_composition_config_to_legacy_row(
|
||||
row = row_location_policy.apply_composition_config_to_legacy_row(
|
||||
row,
|
||||
parsed_composition_config,
|
||||
parsed_seed_config,
|
||||
|
||||
Reference in New Issue
Block a user