236 lines
8.0 KiB
Python
236 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import random
|
|
import re
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import generate_prompt_batches as g
|
|
from . import location_config as location_policy
|
|
from . import row_camera
|
|
from . import seed_config as seed_policy
|
|
except ImportError: # Allows local smoke tests with top-level imports.
|
|
import generate_prompt_batches as g
|
|
import location_config as location_policy
|
|
import row_camera
|
|
import seed_config as seed_policy
|
|
|
|
|
|
def _list_from(value: Any) -> list[Any]:
    """Coerce an optional config value into a list.

    ``None`` becomes ``[]``; a list is returned as-is (the same object, not a
    copy); any other value -- including a tuple -- is wrapped in a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
|
|
|
|
|
|
def _unique_extend(target: list[Any], additions: list[Any]) -> None:
    """Append each item of *additions* to *target* unless an equal item is already present.

    Equality is judged on a canonical marker: ``json.dumps(item, sort_keys=True)``
    when the item is JSON-serialisable (so dicts match regardless of key order and
    a tuple matches a list with the same elements), otherwise ``repr(item)``.
    Duplicates within *additions* are skipped too. *target* is modified in place
    and appended items keep their order from *additions*.
    """

    def marker(item: Any) -> str:
        try:
            return json.dumps(item, sort_keys=True)
        except (TypeError, ValueError):
            # TypeError: unserialisable values or unsortable mixed-type dict keys.
            # ValueError: self-referential containers ("Circular reference detected").
            return repr(item)

    seen = {marker(item) for item in target}
    for item in additions:
        key = marker(item)
        if key not in seen:
            target.append(item)
            seen.add(key)
|
|
|
|
|
|
def _pair_from(value: Any) -> tuple[str, str]:
    """Normalise a scene-style entry into a ``(slug, text)`` pair.

    Accepted shapes:

    * ``dict`` -- text is the first truthy value among ``prompt``,
      ``description``, ``text`` and ``name`` (then stripped); slug is ``slug``
      if truthy, else ``g.slugify`` of ``name`` (or of the text), else
      ``"custom"`` (then stripped).
    * 2-element ``list``/``tuple`` -- taken as ``(slug, text)`` and passed
      through ``str`` without stripping or validation.
    * anything else -- its stripped ``str()`` is the text, and ``g.slugify`` of
      that text (or ``"custom"``) is the slug.

    Raises:
        ValueError: If a dict or scalar entry yields empty text.
    """
    if isinstance(value, dict):
        candidates = (value.get(key) for key in ("prompt", "description", "text", "name"))
        text = str(next(filter(None, candidates), "")).strip()
        slug_source = value.get("slug")
        if not slug_source:
            slug_source = g.slugify(str(value.get("name") or text)) or "custom"
        slug = str(slug_source).strip()
        if not text:
            raise ValueError(f"Pair extension is missing prompt text: {value!r}")
        return slug, text

    if isinstance(value, (list, tuple)) and len(value) == 2:
        return str(value[0]), str(value[1])

    stripped = str(value).strip()
    if not stripped:
        raise ValueError("Pair extension cannot be empty")
    return g.slugify(stripped) or "custom", stripped
|
|
|
|
|
|
def _weighted_choice(rng: random.Random, items: list[Any]) -> Any:
    """Pick one element of *items* at random, honouring optional per-item weights.

    Dict items may carry a ``"weight"`` key (default ``1.0``); every other item
    weighs ``1.0``. Negative weights are clamped to ``0.0`` and weights that
    ``float()`` cannot convert fall back to ``1.0``. Items whose effective weight
    is ``0.0`` are never returned unless every weight is ``0.0``, in which case
    the pick is uniform via ``rng.randrange``.

    Args:
        rng: Source of randomness; only ``random()`` and ``randrange()`` are used.
        items: Candidates to choose from.

    Returns:
        The chosen element itself (not a copy).

    Raises:
        ValueError: If *items* is empty.
    """
    if not items:
        raise ValueError("Cannot choose from an empty list")

    weights: list[float] = []
    for item in items:
        weight = item.get("weight", 1.0) if isinstance(item, dict) else 1.0
        try:
            weights.append(max(0.0, float(weight)))
        except (TypeError, ValueError):
            weights.append(1.0)

    total = sum(weights)
    if total <= 0:
        return items[rng.randrange(len(items))]

    pick = rng.random() * total
    running = 0.0
    for item, weight in zip(items, weights):
        running += weight
        # Zero-weight items are skipped: with ``<=`` a pick landing exactly on a
        # cumulative boundary (e.g. ``rng.random() == 0.0``) would otherwise
        # return an item configured never to be chosen. Positive-weight
        # boundaries behave exactly as before, keeping seeded runs reproducible.
        if weight > 0 and pick <= running:
            return item

    # Rounding can leave ``pick`` just above the final running total (Python
    # 3.12+ ``sum`` uses compensated summation, so ``total`` may differ from the
    # sequential sum). Fall back to the last item that can actually be chosen;
    # one exists because ``total > 0``.
    return next(
        item for item, weight in zip(reversed(items), reversed(weights)) if weight > 0
    )
|
|
|
|
|
|
def _choose_pair(rng: random.Random, items: list[Any]) -> tuple[str, str]:
    """Weighted-pick one entry of *items* and normalise it to ``(slug, text)``.

    Errors from ``_weighted_choice`` (empty *items*) and ``_pair_from``
    (entry without text) propagate unchanged.
    """
    picked = _weighted_choice(rng, items)
    return _pair_from(picked)
|
|
|
|
|
|
def _metadata_entry(value: Any, *, slug: str = "", text: str = "") -> dict[str, Any]:
    """Build a metadata dict describing a chosen entry.

    Dicts are shallow-copied (the input is not mutated); 2-element lists/tuples
    become ``{"slug": ..., "prompt": ...}``; anything else becomes
    ``{"prompt": str(value or "")}``. A non-empty *slug* overrides ``"slug"``.
    A non-empty *text* overwrites ``"prompt"`` if that key exists, otherwise
    ``"text"`` if that key exists, otherwise it is stored under ``"prompt"``.
    """
    if isinstance(value, dict):
        entry = {**value}
    elif isinstance(value, (list, tuple)) and len(value) == 2:
        entry = {"slug": str(value[0]), "prompt": str(value[1])}
    else:
        entry = {"prompt": str(value or "")}

    if slug:
        entry["slug"] = slug
    if text:
        use_text_key = "prompt" not in entry and "text" in entry
        entry["text" if use_text_key else "prompt"] = text
    return entry
|
|
|
|
|
|
def _choose_text(rng: random.Random, items: list[Any]) -> str:
    """Weighted-pick one entry of *items* and return its display text.

    See ``_text_from_entry`` for which field supplies the text; an empty
    *items* list raises ``ValueError`` from ``_weighted_choice``.
    """
    return _text_from_entry(_weighted_choice(rng, items))
|
|
|
|
|
|
def _text_from_entry(item: Any) -> str:
    """Return the display text of an entry, stripped of surrounding whitespace.

    For dicts the first truthy value among ``template``, ``prompt``, ``text``,
    ``description`` and ``name`` is used (``""`` when none is truthy); any other
    item is passed through ``str()``.
    """
    if not isinstance(item, dict):
        return str(item).strip()
    candidates = (item.get(key) for key in ("template", "prompt", "text", "description", "name"))
    return str(next(filter(None, candidates), "")).strip()
|
|
|
|
|
|
def legacy_scene_entries_for_row(row: dict[str, Any]) -> list[Any]:
    """Return a fresh copy of the legacy scene table suited to *row*.

    Rows whose lower-cased ``primary_subject`` contains "group" or "layout"
    use ``g.GROUP_SCENES``; every other row uses ``g.SCENES``. A new list is
    returned so callers may extend it without touching the module tables.
    """
    subject = str(row.get("primary_subject") or "").lower()
    wants_group_table = any(marker in subject for marker in ("group", "layout"))
    return list(g.GROUP_SCENES if wants_group_table else g.SCENES)
|
|
|
|
|
|
def legacy_scene_text_for_slug(slug: str) -> str:
    """Look up the prompt text of a legacy scene by its slug.

    Searches ``g.SCENES`` first, then ``g.GROUP_SCENES``, normalising each entry
    with ``_pair_from`` (whose ``ValueError`` on malformed entries propagates).
    Returns ``""`` when no entry has a matching slug.
    """
    legacy_entries = (*g.SCENES, *g.GROUP_SCENES)
    matches = (text for entry_slug, text in map(_pair_from, legacy_entries) if entry_slug == slug)
    return next(matches, "")
|
|
|
|
|
|
def apply_location_config_to_legacy_row(
    row: dict[str, Any],
    location_config: dict[str, Any],
    seed_config: dict[str, int],
    seed: int,
    row_number: int,
) -> dict[str, Any]:
    """Swap a legacy row's scene for one drawn from *location_config*.

    Returns *row* untouched when ``location_policy.location_config_active``
    reports the config as inactive. Otherwise a scene is drawn with the
    per-row ``"scene"`` axis RNG from ``seed_policy.axis_rng``:

    * ``apply_mode == "add"``: from the row's legacy scene table, extended
      (without duplicates) by the config's ``scene_entries``;
    * any other mode: from ``scene_entries`` alone.

    The row is mutated in place and also returned. It gains ``source_scene`` /
    ``source_scene_text`` (the previous scene), ``scene`` / ``scene_text`` /
    ``scene_entry`` (the new one), ``location_theme``, ``scene_theme`` (the
    entry's own theme, else the config theme only in ``"replace"`` mode) and a
    reference to ``location_config``. The ``Scene: ...`` sentence in ``prompt``
    is rewritten; ``caption`` is only rewritten when the old scene text is known.

    Raises:
        ValueError: Propagated from this module's helpers, e.g. when there are
            no scene entries to choose from or the chosen entry has no text.
    """
    if not location_policy.location_config_active(location_config):
        return row

    apply_mode = location_config.get("apply_mode")
    theme = str(location_config.get("theme") or "")

    location_entries = _list_from(location_config.get("scene_entries"))
    if apply_mode == "add":
        # legacy_scene_entries_for_row builds a new list, so extending it does
        # not modify the module-level legacy tables.
        choices = legacy_scene_entries_for_row(row)
        _unique_extend(choices, location_entries)
    else:
        choices = location_entries

    scene_rng = seed_policy.axis_rng(seed_config, "scene", seed, row_number)
    scene_choice = _weighted_choice(scene_rng, choices)
    scene_slug, scene_text = _pair_from(scene_choice)
    scene_entry = _metadata_entry(scene_choice, slug=scene_slug, text=scene_text)

    old_slug = str(row.get("scene") or "")
    old_text = legacy_scene_text_for_slug(old_slug)

    row["source_scene"] = old_slug
    row["source_scene_text"] = old_text
    row["scene"] = scene_slug
    row["scene_text"] = scene_text
    row["scene_entry"] = scene_entry
    row["location_theme"] = theme
    row["scene_theme"] = scene_entry.get("theme", "") or (theme if apply_mode == "replace" else "")
    row["location_config"] = location_config

    prompt = str(row.get("prompt") or "")
    if old_text:
        row["prompt"] = prompt.replace(f"Scene: {old_text}.", f"Scene: {scene_text}.")
        row["caption"] = str(row.get("caption") or "").replace(f", {old_text},", f", {scene_text},")
    else:
        # A callable replacement inserts the config-provided scene text
        # literally; as a template string, backslashes or "\1" in it would be
        # parsed as escapes/group references (or raise re.error).
        replacement = f"Scene: {scene_text}. Pose:"
        row["prompt"] = re.sub(
            r"Scene:\s*.*?\.\s*Pose:",
            lambda _match: replacement,
            prompt,
            count=1,
        )
    return row
|
|
|
|
|
|
def legacy_composition_entries_for_row(row: dict[str, Any]) -> list[Any]:
    """Return a fresh copy of the legacy composition table suited to *row*.

    Rows whose lower-cased ``primary_subject`` contains "group" or "layout"
    use ``g.GROUP_COMPOSITIONS``; every other row uses ``g.COMPOSITIONS``.
    A new list is returned so callers may extend it safely.
    """
    subject = str(row.get("primary_subject") or "").lower()
    wants_group_table = any(marker in subject for marker in ("group", "layout"))
    return list(g.GROUP_COMPOSITIONS if wants_group_table else g.COMPOSITIONS)
|
|
|
|
|
|
def apply_composition_config_to_legacy_row(
    row: dict[str, Any],
    composition_config: dict[str, Any],
    seed_config: dict[str, int],
    seed: int,
    row_number: int,
) -> dict[str, Any]:
    """Swap a legacy row's composition for one drawn from *composition_config*.

    Returns *row* untouched when ``location_policy.composition_config_active``
    reports the config as inactive. Otherwise a composition is drawn with the
    per-row ``"composition"`` axis RNG from ``seed_policy.axis_rng``:

    * ``apply_mode == "add"``: from the row's legacy composition table,
      extended (without duplicates) by the config's ``composition_entries``;
    * any other mode: from ``composition_entries`` alone.

    The row is mutated in place and also returned. It gains
    ``source_composition`` (the previous value), ``composition``,
    ``composition_entry``, ``composition_theme``, ``composition_prompt`` and a
    reference to ``composition_config``. The ``Composition: ...`` sentence in
    ``prompt`` is rewritten; ``caption`` is only rewritten when the row already
    had a composition.

    Raises:
        ValueError: Propagated from ``_weighted_choice`` when there are no
            composition entries to choose from.
    """
    if not location_policy.composition_config_active(composition_config):
        return row

    composition_entries = _list_from(composition_config.get("composition_entries"))
    if composition_config.get("apply_mode") == "add":
        # legacy_composition_entries_for_row builds a new list, so extending it
        # does not modify the module-level legacy tables.
        choices = legacy_composition_entries_for_row(row)
        _unique_extend(choices, composition_entries)
    else:
        choices = composition_entries

    composition_rng = seed_policy.axis_rng(seed_config, "composition", seed, row_number)
    composition_choice = _weighted_choice(composition_rng, choices)
    new_composition = _text_from_entry(composition_choice)
    composition_entry = _metadata_entry(composition_choice, text=new_composition)
    # Computed once so the stored field and the prompt sentence always agree.
    composition_prompt = row_camera.composition_prompt(new_composition)

    old_composition = str(row.get("composition") or "")
    # Legacy prompts are expected to spell the composition as
    # "Composition: vertical <composition>." -- the exact form replaced below.
    old_prompt_fragment = f"Composition: vertical {old_composition}."
    new_prompt_fragment = f"Composition: {composition_prompt}."

    row["source_composition"] = old_composition
    row["composition"] = new_composition
    row["composition_entry"] = composition_entry
    row["composition_theme"] = str(composition_config.get("theme") or "")
    row["composition_prompt"] = composition_prompt
    row["composition_config"] = composition_config

    prompt = str(row.get("prompt") or "")
    if old_composition:
        row["prompt"] = prompt.replace(old_prompt_fragment, new_prompt_fragment)
        row["caption"] = str(row.get("caption") or "").replace(
            f", {old_composition},", f", {new_composition},"
        )
    else:
        # A callable replacement inserts the generated fragment literally; as a
        # template string, backslashes or "\1" in config-provided composition
        # text would be parsed as escapes/group references (or raise re.error).
        replacement = f"{new_prompt_fragment} Use"
        row["prompt"] = re.sub(
            r"Composition:\s*.*?\.\s*Use",
            lambda _match: replacement,
            prompt,
            count=1,
        )
    return row
|