Extract row location policy

This commit is contained in:
2026-06-27 03:09:17 +02:00
parent d4d3be5789
commit 7f808be997
5 changed files with 256 additions and 102 deletions
+4 -98
View File
@@ -41,6 +41,7 @@ try:
from . import pair_options
from . import row_normalization as row_policy
from . import row_camera as row_camera_policy
from . import row_location as row_location_policy
from . import seed_config as seed_policy
from .hardcore_text_cleanup import (
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
@@ -82,6 +83,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
import pair_options
import row_normalization as row_policy
import row_camera as row_camera_policy
import row_location as row_location_policy
import seed_config as seed_policy
from hardcore_text_cleanup import (
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
@@ -3122,102 +3124,6 @@ def _scene_pool(
return scene_entries or fallback
def _legacy_scene_entries_for_row(row: dict[str, Any]) -> list[Any]:
subject = str(row.get("primary_subject") or "").lower()
if "group" in subject or "layout" in subject:
return list(g.GROUP_SCENES)
return list(g.SCENES)
def _legacy_scene_text_for_slug(slug: str) -> str:
for entry in list(g.SCENES) + list(g.GROUP_SCENES):
entry_slug, entry_text = _pair_from(entry)
if entry_slug == slug:
return entry_text
return ""
def _apply_location_config_to_legacy_row(
row: dict[str, Any],
location_config: dict[str, Any],
seed_config: dict[str, int],
seed: int,
row_number: int,
) -> dict[str, Any]:
if not _location_config_active(location_config):
return row
location_entries = _list_from(location_config.get("scene_entries"))
if location_config.get("apply_mode") == "add":
choices = _legacy_scene_entries_for_row(row)
_unique_extend(choices, location_entries)
else:
choices = location_entries
scene_rng = _axis_rng(seed_config, "scene", seed, row_number)
scene_slug, scene_text = _choose_pair(scene_rng, choices)
old_slug = str(row.get("scene") or "")
old_text = _legacy_scene_text_for_slug(old_slug)
row["source_scene"] = old_slug
row["source_scene_text"] = old_text
row["scene"] = scene_slug
row["scene_text"] = scene_text
row["location_config"] = location_config
if old_text:
row["prompt"] = str(row.get("prompt") or "").replace(f"Scene: {old_text}.", f"Scene: {scene_text}.")
row["caption"] = str(row.get("caption") or "").replace(f", {old_text},", f", {scene_text},")
else:
row["prompt"] = re.sub(
r"Scene:\s*.*?\.\s*Pose:",
f"Scene: {scene_text}. Pose:",
str(row.get("prompt") or ""),
count=1,
)
return row
def _legacy_composition_entries_for_row(row: dict[str, Any]) -> list[Any]:
subject = str(row.get("primary_subject") or "").lower()
if "group" in subject or "layout" in subject:
return list(g.GROUP_COMPOSITIONS)
return list(g.COMPOSITIONS)
def _apply_composition_config_to_legacy_row(
row: dict[str, Any],
composition_config: dict[str, Any],
seed_config: dict[str, int],
seed: int,
row_number: int,
) -> dict[str, Any]:
if not _composition_config_active(composition_config):
return row
composition_entries = _list_from(composition_config.get("composition_entries"))
if composition_config.get("apply_mode") == "add":
choices = _legacy_composition_entries_for_row(row)
_unique_extend(choices, composition_entries)
else:
choices = composition_entries
composition_rng = _axis_rng(seed_config, "composition", seed, row_number)
new_composition = _choose_text(composition_rng, choices)
old_composition = str(row.get("composition") or "")
old_prompt_fragment = f"Composition: vertical {old_composition}."
new_prompt_fragment = f"Composition: {_composition_prompt(new_composition)}."
row["source_composition"] = old_composition
row["composition"] = new_composition
row["composition_prompt"] = _composition_prompt(new_composition)
row["composition_config"] = composition_config
if old_composition:
row["prompt"] = str(row.get("prompt") or "").replace(old_prompt_fragment, new_prompt_fragment)
row["caption"] = str(row.get("caption") or "").replace(f", {old_composition},", f", {new_composition},")
else:
row["prompt"] = re.sub(
r"Composition:\s*.*?\.\s*Use",
f"{new_prompt_fragment} Use",
str(row.get("prompt") or ""),
count=1,
)
return row
def _expression_pool(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> list[Any]:
return _configured_pool(
category,
@@ -3919,14 +3825,14 @@ def build_prompt(
)
if row.get("source") == "built_in_generator":
row = _apply_location_config_to_legacy_row(
row = row_location_policy.apply_location_config_to_legacy_row(
row,
parsed_location_config,
parsed_seed_config,
seed,
row_number,
)
row = _apply_composition_config_to_legacy_row(
row = row_location_policy.apply_composition_config_to_legacy_row(
row,
parsed_composition_config,
parsed_seed_config,