from __future__ import annotations import json import random import re from typing import Any try: from . import generate_prompt_batches as g from . import location_config as location_policy from . import row_camera from . import seed_config as seed_policy except ImportError: # Allows local smoke tests with top-level imports. import generate_prompt_batches as g import location_config as location_policy import row_camera import seed_config as seed_policy def _list_from(value: Any) -> list[Any]: if value is None: return [] if isinstance(value, list): return value return [value] def _unique_extend(target: list[Any], additions: list[Any]) -> None: seen = set() for item in target: try: seen.add(json.dumps(item, sort_keys=True)) except TypeError: seen.add(repr(item)) for item in additions: try: marker = json.dumps(item, sort_keys=True) except TypeError: marker = repr(item) if marker not in seen: target.append(item) seen.add(marker) def _pair_from(value: Any) -> tuple[str, str]: if isinstance(value, dict): text = str( value.get("prompt") or value.get("description") or value.get("text") or value.get("name") or "" ).strip() slug = str(value.get("slug") or g.slugify(str(value.get("name") or text)) or "custom").strip() if not text: raise ValueError(f"Pair extension is missing prompt text: {value!r}") return slug, text if isinstance(value, (list, tuple)) and len(value) == 2: return str(value[0]), str(value[1]) text = str(value).strip() if not text: raise ValueError("Pair extension cannot be empty") return g.slugify(text) or "custom", text def _weighted_choice(rng: random.Random, items: list[Any]) -> Any: if not items: raise ValueError("Cannot choose from an empty list") weights: list[float] = [] for item in items: weight = item.get("weight", 1.0) if isinstance(item, dict) else 1.0 try: weights.append(max(0.0, float(weight))) except (TypeError, ValueError): weights.append(1.0) total = sum(weights) if total <= 0: return items[rng.randrange(len(items))] pick = rng.random() * total running = 0.0 for item, weight in zip(items, weights): running += weight if pick <= running: return item return items[-1] def _choose_pair(rng: random.Random, items: list[Any]) -> tuple[str, str]: return _pair_from(_weighted_choice(rng, items)) def _metadata_entry(value: Any, *, slug: str = "", text: str = "") -> dict[str, Any]: if isinstance(value, dict): entry = dict(value) elif isinstance(value, (list, tuple)) and len(value) == 2: entry = {"slug": str(value[0]), "prompt": str(value[1])} else: entry = {"prompt": str(value or "")} if slug: entry["slug"] = slug if text: if "prompt" in entry: entry["prompt"] = text elif "text" in entry: entry["text"] = text else: entry["prompt"] = text return entry def _choose_text(rng: random.Random, items: list[Any]) -> str: item = _weighted_choice(rng, items) return _text_from_entry(item) def _text_from_entry(item: Any) -> str: if isinstance(item, dict): return str( item.get("template") or item.get("prompt") or item.get("text") or item.get("description") or item.get("name") or "" ).strip() return str(item).strip() def legacy_scene_entries_for_row(row: dict[str, Any]) -> list[Any]: subject = str(row.get("primary_subject") or "").lower() if "group" in subject or "layout" in subject: return list(g.GROUP_SCENES) return list(g.SCENES) def legacy_scene_text_for_slug(slug: str) -> str: for entry in list(g.SCENES) + list(g.GROUP_SCENES): entry_slug, entry_text = _pair_from(entry) if entry_slug == slug: return entry_text return "" def apply_location_config_to_legacy_row( row: dict[str, Any], location_config: dict[str, Any], seed_config: dict[str, int], seed: int, row_number: int, ) -> dict[str, Any]: if not location_policy.location_config_active(location_config): return row location_entries = _list_from(location_config.get("scene_entries")) if location_config.get("apply_mode") == "add": choices = legacy_scene_entries_for_row(row) _unique_extend(choices, location_entries) else: choices = location_entries scene_rng = seed_policy.axis_rng(seed_config, "scene", seed, row_number) scene_choice = _weighted_choice(scene_rng, choices) scene_slug, scene_text = _pair_from(scene_choice) scene_entry = _metadata_entry(scene_choice, slug=scene_slug, text=scene_text) old_slug = str(row.get("scene") or "") old_text = legacy_scene_text_for_slug(old_slug) row["source_scene"] = old_slug row["source_scene_text"] = old_text row["scene"] = scene_slug row["scene_text"] = scene_text row["scene_entry"] = scene_entry row["location_theme"] = str(location_config.get("theme") or "") row["scene_theme"] = scene_entry.get("theme", "") or ( str(location_config.get("theme") or "") if location_config.get("apply_mode") == "replace" else "" ) row["location_config"] = location_config if old_text: row["prompt"] = str(row.get("prompt") or "").replace(f"Scene: {old_text}.", f"Scene: {scene_text}.") row["caption"] = str(row.get("caption") or "").replace(f", {old_text},", f", {scene_text},") else: row["prompt"] = re.sub( r"Scene:\s*.*?\.\s*Pose:", f"Scene: {scene_text}. Pose:", str(row.get("prompt") or ""), count=1, ) return row def legacy_composition_entries_for_row(row: dict[str, Any]) -> list[Any]: subject = str(row.get("primary_subject") or "").lower() if "group" in subject or "layout" in subject: return list(g.GROUP_COMPOSITIONS) return list(g.COMPOSITIONS) def apply_composition_config_to_legacy_row( row: dict[str, Any], composition_config: dict[str, Any], seed_config: dict[str, int], seed: int, row_number: int, ) -> dict[str, Any]: if not location_policy.composition_config_active(composition_config): return row composition_entries = _list_from(composition_config.get("composition_entries")) if composition_config.get("apply_mode") == "add": choices = legacy_composition_entries_for_row(row) _unique_extend(choices, composition_entries) else: choices = composition_entries composition_rng = seed_policy.axis_rng(seed_config, "composition", seed, row_number) composition_choice = _weighted_choice(composition_rng, choices) new_composition = _text_from_entry(composition_choice) composition_entry = _metadata_entry(composition_choice, text=new_composition) old_composition = str(row.get("composition") or "") old_prompt_fragment = f"Composition: vertical {old_composition}." new_prompt_fragment = f"Composition: {row_camera.composition_prompt(new_composition)}." row["source_composition"] = old_composition row["composition"] = new_composition row["composition_entry"] = composition_entry row["composition_theme"] = str(composition_config.get("theme") or "") row["composition_prompt"] = row_camera.composition_prompt(new_composition) row["composition_config"] = composition_config if old_composition: row["prompt"] = str(row.get("prompt") or "").replace(old_prompt_fragment, new_prompt_fragment) row["caption"] = str(row.get("caption") or "").replace(f", {old_composition},", f", {new_composition},") else: row["prompt"] = re.sub( r"Composition:\s*.*?\.\s*Use", f"{new_prompt_fragment} Use", str(row.get("prompt") or ""), count=1, ) return row