207 lines
8.2 KiB
Python
207 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
try:
|
|
from .prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
|
|
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
|
from prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
|
|
|
|
|
|
def _trigger_tuple(active_trigger: str) -> tuple[str, ...]:
    """Wrap a trigger word in a one-element tuple, or return ``()`` when blank.

    ``None`` and whitespace-only values count as blank; any other value is
    converted with ``str()`` and stripped before being wrapped.
    """
    cleaned = str(active_trigger or "").strip()
    if not cleaned:
        return ()
    return (cleaned,)
|
|
|
|
|
|
def prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str:
    """Return *prompt* with *trigger* prepended as a leading ``"trigger, "`` token.

    The prompt is returned unchanged when prepending is disabled, when the
    trigger is blank, or when the prompt (ignoring leading whitespace) already
    starts with the trigger, compared case-insensitively. When the trigger ends
    in a word character, the match must also end on a word boundary.

    Args:
        prompt: Prompt text; ``None`` is treated as ``""``.
        trigger: Trigger word; surrounding whitespace is ignored.
        enabled: When false, no trigger is added.

    Returns:
        The prompt, possibly prefixed with ``"<trigger>, "``. An empty prompt
        yields just the trigger, without a dangling separator.
    """
    trigger = str(trigger or "").strip()
    prompt = str(prompt or "")
    if not enabled or not trigger:
        return prompt

    body = prompt.lstrip()
    head, rest = body[: len(trigger)], body[len(trigger):]
    if head.lower() == trigger.lower():
        # Require a word boundary after a word-like trigger so that e.g. "cat"
        # is not mistaken as already present in "caterpillar on a leaf".
        needs_boundary = trigger[-1].isalnum() or trigger[-1] == "_"
        next_char = rest[:1]
        if not (needs_boundary and (next_char.isalnum() or next_char == "_")):
            return prompt

    if not body:
        return trigger
    return f"{trigger}, {body}"
|
|
|
|
|
|
def combined_negative(base: str, extra: str) -> str:
    """Merge a base negative prompt with extra negative text.

    Forwards both arguments unchanged to ``combine_negative_text`` and returns
    whatever it produces.
    """
    merged = combine_negative_text(base, extra)
    return merged
|
|
|
|
|
|
def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: str = "") -> str:
    """Join caption fragments with ``", "`` and sanitize the result.

    Args:
        parts: Caption fragments. ``None`` entries, and entries that are blank
            after ``str(...).strip()``, are skipped. ``None`` for the whole
            argument is treated as no parts.
        active_trigger: Optional trigger word; it is stripped and forwarded to
            ``sanitize_caption_text`` as its ``triggers`` tuple.

    Returns:
        The sanitized caption text.
    """
    # Skip None explicitly: str(None) would otherwise leak the literal "None"
    # into the caption. Each part is stringified/stripped only once.
    fragments = (str(part).strip() for part in (parts or ()) if part is not None)
    text = ", ".join(fragment for fragment in fragments if fragment)
    return sanitize_caption_text(text, triggers=_trigger_tuple(active_trigger))
|
|
|
|
|
|
def normalize_prompt_row(
    row: dict[str, Any],
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    default_negative: str = "",
) -> dict[str, Any]:
    """Normalize the text fields of a prompt row in place and return it.

    Fields written:

    * ``prompt``: *extra_positive* (stripped) is appended, the trigger is
      optionally prepended via ``prepend_trigger``, then the text goes through
      ``sanitize_prompt_text``.
    * ``caption``: passed through ``sanitize_caption_text``.
    * ``negative_prompt``: the row's own value, or *default_negative* when the
      row has none (key missing or ``None``), combined with *extra_negative*
      via ``combined_negative`` and passed through ``sanitize_negative_text``.
    * ``trigger``: the stripped *active_trigger*.

    Args:
        row: Row dict; mutated in place.
        active_trigger: Trigger word used for prepending and sanitization.
        prepend_trigger_to_prompt: Whether to prepend the trigger to the prompt.
        extra_positive: Text appended to the prompt.
        extra_negative: Text merged into the negative prompt.
        default_negative: Negative prompt used when the row has none.

    Returns:
        The same *row* object.
    """
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)
    positive = str(extra_positive or "").strip()

    prompt = str(row.get("prompt", "") or "")
    if positive:
        prompt = f"{prompt.rstrip()} {positive}".strip()
    prompt = prepend_trigger(prompt, trigger, bool(prepend_trigger_to_prompt))
    row["prompt"] = sanitize_prompt_text(prompt, triggers=triggers)

    row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=triggers)

    # dict.get's default only covers a missing key; an explicit None should
    # also fall back to default_negative instead of silently becoming "".
    base_negative = row.get("negative_prompt")
    if base_negative is None:
        base_negative = default_negative
    row["negative_prompt"] = sanitize_negative_text(
        combined_negative(str(base_negative or ""), extra_negative)
    )

    row["trigger"] = trigger
    return row
|
|
|
|
|
|
def normalize_pair_text_outputs(
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    soft_prompt: str,
    hard_prompt: str,
    soft_negative_base: str,
    hard_negative_base: str,
    soft_caption_parts: list[Any] | tuple[Any, ...],
    hard_caption_parts: list[Any] | tuple[Any, ...],
) -> dict[str, str]:
    """Build sanitized prompt, negative and caption text for a soft/hard pair.

    Both prompts get the same treatment as ``normalize_prompt_row``:

    * append *extra_positive*;
    * optionally prepend the trigger;
    * sanitize.

    Both negatives are the given base merged with *extra_negative*, then
    sanitized. Captions are built from their parts via ``caption_from_parts``.

    Returns:
        A dict with keys ``soft_prompt``, ``hard_prompt``, ``soft_negative``,
        ``hard_negative``, ``soft_caption`` and ``hard_caption``.
    """
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)
    positive = str(extra_positive or "").strip()
    prepend = bool(prepend_trigger_to_prompt)

    def _finish_prompt(raw: Any) -> str:
        # Shared prompt pipeline for both sides.
        text = str(raw or "")
        if positive:
            # Strip like normalize_prompt_row does, so an empty base prompt does
            # not leave a leading space (which would become "trigger,  extra").
            text = f"{text.rstrip()} {positive}".strip()
        return sanitize_prompt_text(prepend_trigger(text, trigger, prepend), triggers=triggers)

    def _finish_negative(base: Any) -> str:
        # Coerce None to "" the same way normalize_prompt_row does.
        return sanitize_negative_text(combined_negative(str(base or ""), extra_negative))

    return {
        "soft_prompt": _finish_prompt(soft_prompt),
        "hard_prompt": _finish_prompt(hard_prompt),
        "soft_negative": _finish_negative(soft_negative_base),
        "hard_negative": _finish_negative(hard_negative_base),
        "soft_caption": caption_from_parts(soft_caption_parts, active_trigger=trigger),
        "hard_caption": caption_from_parts(hard_caption_parts, active_trigger=trigger),
    }
|
|
|
|
|
|
def sanitize_metadata_row_text(row: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Sanitize whichever text fields a metadata row already has, in place.

    The trigger used is *active_trigger* when non-empty; otherwise it is the
    row's own ``trigger`` value. Only keys already present among ``prompt``,
    ``caption`` and ``negative_prompt`` are rewritten. If a non-empty trigger
    was resolved and the row's ``trigger`` is missing or falsy, it is filled in.

    Returns:
        The same *row* object.
    """
    resolved = str(active_trigger or row.get("trigger") or "").strip()
    trigger_words = _trigger_tuple(resolved)

    text_sanitizers = (
        ("prompt", lambda value: sanitize_prompt_text(value, triggers=trigger_words)),
        ("caption", lambda value: sanitize_caption_text(value, triggers=trigger_words)),
        ("negative_prompt", sanitize_negative_text),
    )
    for field, clean in text_sanitizers:
        if field in row:
            row[field] = clean(row.get(field, ""))

    if resolved and not row.get("trigger"):
        row["trigger"] = resolved
    return row
|
|
|
|
|
|
def _sync_pair_root_row_field(pair: dict[str, Any], row_key: str, root_key: str, row_field: str) -> None:
    """Mirror one value between ``pair[root_key]`` and ``pair[row_key][row_field]``.

    Does nothing when ``pair[row_key]`` is missing or not a dict. If the pair
    has *root_key*, its value is copied into the row, so the root wins.
    Otherwise, if the row has *row_field*, that value is copied up to the
    pair. When neither side has the key, nothing changes.
    """
    nested = pair.get(row_key)
    if isinstance(nested, dict):
        if root_key in pair:
            nested[row_field] = pair[root_key]
        elif row_field in nested:
            pair[root_key] = nested[row_field]
|
|
|
|
|
|
def synchronize_pair_row_outputs(pair: dict[str, Any]) -> dict[str, Any]:
    """Keep root-level prompt/caption/negative fields and per-side rows in sync.

    For each side (``softcore`` first, then ``hardcore``), the pair keys
    ``<side>_prompt``, ``<side>_caption`` and ``<side>_negative_prompt`` are
    mirrored with ``prompt``, ``caption`` and ``negative_prompt`` on
    ``pair["<side>_row"]`` via ``_sync_pair_root_row_field``. The root value
    wins when present. Mutates and returns *pair*.
    """
    for side in ("softcore", "hardcore"):
        row_key = f"{side}_row"
        for field in ("prompt", "caption", "negative_prompt"):
            _sync_pair_root_row_field(pair, row_key, f"{side}_{field}", field)
    return pair
|
|
|
|
|
|
def synchronize_pair_side_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Mirror side-specific metadata keys between the pair and its side rows.

    Each listed key uses the same name at the root and inside the row, and is
    synced with ``_sync_pair_root_row_field``. The root value wins when
    present. Mutates and returns *pair*.
    """
    row_and_field = (
        ("softcore_row", "softcore_partner_styling"),
        ("hardcore_row", "hardcore_clothing_state"),
        ("hardcore_row", "character_hardcore_clothing"),
        ("hardcore_row", "default_man_hardcore_clothing"),
        ("hardcore_row", "hardcore_detail_density"),
        ("hardcore_row", "hardcore_position_config"),
    )
    for row_key, field in row_and_field:
        _sync_pair_root_row_field(pair, row_key, field, field)
    return pair
|
|
|
|
|
|
def synchronize_pair_camera_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Mirror per-side camera metadata between the pair and its side rows.

    For each side (``softcore`` first, then ``hardcore``), the pair keys
    ``<side>_camera_config``, ``<side>_camera_directive`` and
    ``<side>_camera_scene_directive`` are synced with the unprefixed keys on
    ``pair["<side>_row"]`` via ``_sync_pair_root_row_field``. The root value
    wins when present. Mutates and returns *pair*.
    """
    camera_fields = ("camera_config", "camera_directive", "camera_scene_directive")
    for side in ("softcore", "hardcore"):
        row_key = f"{side}_row"
        for field in camera_fields:
            _sync_pair_root_row_field(pair, row_key, f"{side}_{field}", field)
    return pair
|
|
|
|
|
|
def synchronize_pair_cast_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Copy the pair's shared cast descriptors onto its side rows.

    ``pair["shared_cast_descriptors"]`` may be a list/tuple of descriptors or
    a single value. Entries are converted with ``str(item or "")`` and
    stripped; empty results are dropped. When at least one descriptor remains,
    ``hardcore_row`` receives the two fields below. ``softcore_row`` also
    receives them when ``pair["options"]["softcore_cast"] == "same_as_hardcore"``.

    * ``cast_descriptor_text``: the descriptors joined with ``"; "``;
    * ``cast_descriptors``: a fresh list of the descriptors.

    Rows that are missing or not dicts are skipped. Mutates and returns *pair*.
    """
    raw = pair.get("shared_cast_descriptors")
    if isinstance(raw, (list, tuple)):
        # Tuples are sequences too; previously they fell through to str() and
        # produced text like "('a', 'b')".
        cleaned = (str(item or "").strip() for item in raw)
        descriptor_list = [text for text in cleaned if text]
        descriptor_text = "; ".join(descriptor_list)
    else:
        descriptor_text = str(raw or "").strip()
        descriptor_list = [descriptor_text] if descriptor_text else []
    if not descriptor_text:
        return pair

    options = pair.get("options")
    if not isinstance(options, dict):
        options = {}
    row_keys = ["hardcore_row"]
    if options.get("softcore_cast") == "same_as_hardcore":
        row_keys.append("softcore_row")

    for row_key in row_keys:
        row = pair.get(row_key)
        if isinstance(row, dict):
            row["cast_descriptor_text"] = descriptor_text
            # Give each row its own list so later mutation of one row does not
            # affect the other.
            row["cast_descriptors"] = list(descriptor_list)
    return pair
|
|
|
|
|
|
def normalize_pair_metadata(pair: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Synchronize and sanitize all text metadata of a soft/hard pair in place.

    The steps run in this order:

    1. Mirror root-level and per-row fields: outputs, side metadata, camera
       metadata and cast descriptors.
    2. Sanitize whichever root-level ``*_prompt``, ``*_caption`` and
       ``*_negative_prompt`` keys exist.
    3. Run ``sanitize_metadata_row_text`` on each side row that is a dict.

    Returns:
        The same *pair* object.
    """
    trigger = str(active_trigger or "").strip()
    trigger_words = _trigger_tuple(trigger)

    # Sync first; the sanitization below then operates on the synced values.
    for sync in (
        synchronize_pair_row_outputs,
        synchronize_pair_side_metadata,
        synchronize_pair_camera_metadata,
        synchronize_pair_cast_metadata,
    ):
        sync(pair)

    root_cleaners = (
        (("softcore_prompt", "hardcore_prompt"),
         lambda value: sanitize_prompt_text(value, triggers=trigger_words)),
        (("softcore_caption", "hardcore_caption"),
         lambda value: sanitize_caption_text(value, triggers=trigger_words)),
        (("softcore_negative_prompt", "hardcore_negative_prompt"),
         sanitize_negative_text),
    )
    for keys, clean in root_cleaners:
        for key in keys:
            if key in pair:
                pair[key] = clean(pair.get(key, ""))

    # NOTE(review): sanitize_metadata_row_text falls back to a row's own
    # "trigger" when active_trigger is empty, while the root-level fields above
    # do not, so root and row text could diverge in that case — confirm intent.
    for row_key in ("softcore_row", "hardcore_row"):
        row = pair.get(row_key)
        if isinstance(row, dict):
            pair[row_key] = sanitize_metadata_row_text(row, active_trigger=trigger)
    return pair
|