from __future__ import annotations from typing import Any try: from .prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`. from prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text def _trigger_tuple(active_trigger: str) -> tuple[str, ...]: trigger = str(active_trigger or "").strip() return (trigger,) if trigger else () def prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str: trigger = str(trigger or "").strip() prompt = str(prompt or "") if not enabled or not trigger: return prompt if prompt.lower().startswith(trigger.lower()): return prompt return f"{trigger}, {prompt}" def combined_negative(base: str, extra: str) -> str: return combine_negative_text(base, extra) def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: str = "") -> str: text = ", ".join(str(part).strip() for part in parts if str(part).strip()) return sanitize_caption_text(text, triggers=_trigger_tuple(active_trigger)) def _setdefault_nonempty(row: dict[str, Any], key: str, value: Any) -> None: if str(row.get(key) or "").strip(): return if str(value or "").strip(): row[key] = value def _setdefault_count(row: dict[str, Any], key: str, value: int) -> None: if str(row.get(key) or "").strip(): return row[key] = int(value) def _legacy_subject_metadata(row: dict[str, Any]) -> tuple[str, str, int | None, int | None]: subject = str(row.get("primary_subject") or row.get("subject") or "").strip() lower = subject.lower() if lower in ("woman", "adult woman"): return "woman", subject or "woman", 1, 0 if lower in ("man", "adult man"): return "man", subject or "man", 0, 1 if "two women" in lower: return "couple", subject or "two women", 2, 0 if "two men" in lower: return "couple", subject or "two men", 0, 2 if "woman" in lower and "man" in lower: return "couple", subject or "a woman and a man", 1, 1 if "group" in lower: return "group", subject or "mixed adult group", 2, 2 if "layout" in lower: return "layout", subject or "adult layout scene", None, None return "", subject, None, None def enrich_legacy_row_metadata(row: dict[str, Any]) -> dict[str, Any]: if row.get("source") != "built_in_generator": return row subject_type, subject_phrase, women_count, men_count = _legacy_subject_metadata(row) _setdefault_nonempty(row, "subject_type", subject_type) _setdefault_nonempty(row, "subject_phrase", subject_phrase) if women_count is not None: _setdefault_count(row, "women_count", women_count) if men_count is not None: _setdefault_count(row, "men_count", men_count) if women_count is not None and men_count is not None and not str(row.get("person_count") or "").strip(): row["person_count"] = int(women_count) + int(men_count) if str(row.get("scene") or "").strip() and not str(row.get("scene_slug") or "").strip(): row["scene_slug"] = row.get("scene") return row def normalize_prompt_row( row: dict[str, Any], *, active_trigger: str, prepend_trigger_to_prompt: bool, extra_positive: str = "", extra_negative: str = "", default_negative: str = "", ) -> dict[str, Any]: row = enrich_legacy_row_metadata(row) trigger = str(active_trigger or "").strip() positive = str(extra_positive or "").strip() prompt = str(row.get("prompt", "") or "") if positive: prompt = f"{prompt.rstrip()} {positive}".strip() prompt = prepend_trigger(prompt, trigger, bool(prepend_trigger_to_prompt)) row["prompt"] = sanitize_prompt_text(prompt, triggers=_trigger_tuple(trigger)) row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=_trigger_tuple(trigger)) row["negative_prompt"] = sanitize_negative_text( combined_negative(str(row.get("negative_prompt", default_negative) or ""), extra_negative) ) row["trigger"] = trigger return row def normalize_pair_text_outputs( *, active_trigger: str, prepend_trigger_to_prompt: bool, extra_positive: str = "", extra_negative: str = "", soft_prompt: str, hard_prompt: str, soft_negative_base: str, hard_negative_base: str, soft_caption_parts: list[Any] | tuple[Any, ...], hard_caption_parts: list[Any] | tuple[Any, ...], ) -> dict[str, str]: trigger = str(active_trigger or "").strip() positive = str(extra_positive or "").strip() if positive: soft_prompt = f"{str(soft_prompt or '').rstrip()} {positive}" hard_prompt = f"{str(hard_prompt or '').rstrip()} {positive}" soft_prompt = prepend_trigger(soft_prompt, trigger, bool(prepend_trigger_to_prompt)) hard_prompt = prepend_trigger(hard_prompt, trigger, bool(prepend_trigger_to_prompt)) return { "soft_prompt": sanitize_prompt_text(soft_prompt, triggers=_trigger_tuple(trigger)), "hard_prompt": sanitize_prompt_text(hard_prompt, triggers=_trigger_tuple(trigger)), "soft_negative": sanitize_negative_text(combined_negative(soft_negative_base, extra_negative)), "hard_negative": sanitize_negative_text(combined_negative(hard_negative_base, extra_negative)), "soft_caption": caption_from_parts(soft_caption_parts, active_trigger=trigger), "hard_caption": caption_from_parts(hard_caption_parts, active_trigger=trigger), } def sanitize_metadata_row_text(row: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]: trigger = str(active_trigger or row.get("trigger") or "").strip() triggers = _trigger_tuple(trigger) if "prompt" in row: row["prompt"] = sanitize_prompt_text(row.get("prompt", ""), triggers=triggers) if "caption" in row: row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=triggers) if "negative_prompt" in row: row["negative_prompt"] = sanitize_negative_text(row.get("negative_prompt", "")) if trigger and not row.get("trigger"): row["trigger"] = trigger return row def _sync_pair_root_row_field(pair: dict[str, Any], row_key: str, root_key: str, row_field: str) -> None: row = pair.get(row_key) if not isinstance(row, dict): return if root_key in pair: row[row_field] = pair.get(root_key) elif row_field in row: pair[root_key] = row.get(row_field) def synchronize_pair_row_outputs(pair: dict[str, Any]) -> dict[str, Any]: mapping = ( ("softcore_row", "softcore_prompt", "softcore_caption", "softcore_negative_prompt"), ("hardcore_row", "hardcore_prompt", "hardcore_caption", "hardcore_negative_prompt"), ) for row_key, prompt_key, caption_key, negative_key in mapping: _sync_pair_root_row_field(pair, row_key, prompt_key, "prompt") _sync_pair_root_row_field(pair, row_key, caption_key, "caption") _sync_pair_root_row_field(pair, row_key, negative_key, "negative_prompt") return pair def synchronize_pair_side_metadata(pair: dict[str, Any]) -> dict[str, Any]: side_keys = { "softcore_row": ( "softcore_partner_styling", ), "hardcore_row": ( "hardcore_clothing_state", "character_hardcore_clothing", "default_man_hardcore_clothing", "hardcore_detail_density", "hardcore_position_config", ), } for row_key, keys in side_keys.items(): for key in keys: _sync_pair_root_row_field(pair, row_key, key, key) return pair def synchronize_pair_camera_metadata(pair: dict[str, Any]) -> dict[str, Any]: mapping = { "softcore_row": ( ("softcore_camera_config", "camera_config"), ("softcore_camera_directive", "camera_directive"), ("softcore_camera_scene_directive", "camera_scene_directive"), ), "hardcore_row": ( ("hardcore_camera_config", "camera_config"), ("hardcore_camera_directive", "camera_directive"), ("hardcore_camera_scene_directive", "camera_scene_directive"), ), } for row_key, keys in mapping.items(): for source_key, target_key in keys: _sync_pair_root_row_field(pair, row_key, source_key, target_key) return pair def synchronize_pair_cast_metadata(pair: dict[str, Any]) -> dict[str, Any]: descriptors = pair.get("shared_cast_descriptors") if isinstance(descriptors, list): descriptor_list = [str(item).strip() for item in descriptors if str(item or "").strip()] descriptor_text = "; ".join(descriptor_list) else: descriptor_text = str(descriptors or "").strip() descriptor_list = [descriptor_text] if descriptor_text else [] if not descriptor_text: return pair options = pair.get("options") if isinstance(pair.get("options"), dict) else {} row_keys = ["hardcore_row"] if options.get("softcore_cast") == "same_as_hardcore": row_keys.append("softcore_row") for row_key in row_keys: row = pair.get(row_key) if not isinstance(row, dict): continue row["cast_descriptor_text"] = descriptor_text row["cast_descriptors"] = list(descriptor_list) return pair def normalize_pair_metadata(pair: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]: trigger = str(active_trigger or "").strip() triggers = _trigger_tuple(trigger) synchronize_pair_row_outputs(pair) synchronize_pair_side_metadata(pair) synchronize_pair_camera_metadata(pair) synchronize_pair_cast_metadata(pair) for key in ("softcore_prompt", "hardcore_prompt"): if key in pair: pair[key] = sanitize_prompt_text(pair.get(key, ""), triggers=triggers) for key in ("softcore_caption", "hardcore_caption"): if key in pair: pair[key] = sanitize_caption_text(pair.get(key, ""), triggers=triggers) for key in ("softcore_negative_prompt", "hardcore_negative_prompt"): if key in pair: pair[key] = sanitize_negative_text(pair.get(key, "")) for key in ("softcore_row", "hardcore_row"): if isinstance(pair.get(key), dict): pair[key] = sanitize_metadata_row_text(pair[key], active_trigger=trigger) return pair