Extract row normalization policy
This commit is contained in:
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from .prompt_hygiene import sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
|
||||
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
||||
from prompt_hygiene import sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
|
||||
|
||||
|
||||
def _trigger_tuple(active_trigger: str) -> tuple[str, ...]:
|
||||
trigger = str(active_trigger or "").strip()
|
||||
return (trigger,) if trigger else ()
|
||||
|
||||
|
||||
def prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str:
|
||||
trigger = str(trigger or "").strip()
|
||||
prompt = str(prompt or "")
|
||||
if not enabled or not trigger:
|
||||
return prompt
|
||||
if prompt.lower().startswith(trigger.lower()):
|
||||
return prompt
|
||||
return f"{trigger}, {prompt}"
|
||||
|
||||
|
||||
def combined_negative(base: str, extra: str) -> str:
|
||||
parts = [str(part).strip() for part in (base, extra) if part and str(part).strip()]
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: str = "") -> str:
|
||||
text = ", ".join(str(part).strip() for part in parts if str(part).strip())
|
||||
return sanitize_caption_text(text, triggers=_trigger_tuple(active_trigger))
|
||||
|
||||
|
||||
def normalize_prompt_row(
|
||||
row: dict[str, Any],
|
||||
*,
|
||||
active_trigger: str,
|
||||
prepend_trigger_to_prompt: bool,
|
||||
extra_positive: str = "",
|
||||
extra_negative: str = "",
|
||||
default_negative: str = "",
|
||||
) -> dict[str, Any]:
|
||||
trigger = str(active_trigger or "").strip()
|
||||
positive = str(extra_positive or "").strip()
|
||||
prompt = str(row.get("prompt", "") or "")
|
||||
if positive:
|
||||
prompt = f"{prompt.rstrip()} {positive}".strip()
|
||||
prompt = prepend_trigger(prompt, trigger, bool(prepend_trigger_to_prompt))
|
||||
row["prompt"] = sanitize_prompt_text(prompt, triggers=_trigger_tuple(trigger))
|
||||
row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=_trigger_tuple(trigger))
|
||||
row["negative_prompt"] = sanitize_negative_text(
|
||||
combined_negative(str(row.get("negative_prompt", default_negative) or ""), extra_negative)
|
||||
)
|
||||
row["trigger"] = trigger
|
||||
return row
|
||||
|
||||
|
||||
def normalize_pair_text_outputs(
|
||||
*,
|
||||
active_trigger: str,
|
||||
prepend_trigger_to_prompt: bool,
|
||||
extra_positive: str = "",
|
||||
extra_negative: str = "",
|
||||
soft_prompt: str,
|
||||
hard_prompt: str,
|
||||
soft_negative_base: str,
|
||||
hard_negative_base: str,
|
||||
soft_caption_parts: list[Any] | tuple[Any, ...],
|
||||
hard_caption_parts: list[Any] | tuple[Any, ...],
|
||||
) -> dict[str, str]:
|
||||
trigger = str(active_trigger or "").strip()
|
||||
positive = str(extra_positive or "").strip()
|
||||
if positive:
|
||||
soft_prompt = f"{str(soft_prompt or '').rstrip()} {positive}"
|
||||
hard_prompt = f"{str(hard_prompt or '').rstrip()} {positive}"
|
||||
soft_prompt = prepend_trigger(soft_prompt, trigger, bool(prepend_trigger_to_prompt))
|
||||
hard_prompt = prepend_trigger(hard_prompt, trigger, bool(prepend_trigger_to_prompt))
|
||||
return {
|
||||
"soft_prompt": sanitize_prompt_text(soft_prompt, triggers=_trigger_tuple(trigger)),
|
||||
"hard_prompt": sanitize_prompt_text(hard_prompt, triggers=_trigger_tuple(trigger)),
|
||||
"soft_negative": sanitize_negative_text(combined_negative(soft_negative_base, extra_negative)),
|
||||
"hard_negative": sanitize_negative_text(combined_negative(hard_negative_base, extra_negative)),
|
||||
"soft_caption": caption_from_parts(soft_caption_parts, active_trigger=trigger),
|
||||
"hard_caption": caption_from_parts(hard_caption_parts, active_trigger=trigger),
|
||||
}
|
||||
|
||||
|
||||
def sanitize_metadata_row_text(row: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
|
||||
trigger = str(active_trigger or row.get("trigger") or "").strip()
|
||||
triggers = _trigger_tuple(trigger)
|
||||
if "prompt" in row:
|
||||
row["prompt"] = sanitize_prompt_text(row.get("prompt", ""), triggers=triggers)
|
||||
if "caption" in row:
|
||||
row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=triggers)
|
||||
if "negative_prompt" in row:
|
||||
row["negative_prompt"] = sanitize_negative_text(row.get("negative_prompt", ""))
|
||||
if trigger and not row.get("trigger"):
|
||||
row["trigger"] = trigger
|
||||
return row
|
||||
|
||||
|
||||
def normalize_pair_metadata(pair: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
|
||||
trigger = str(active_trigger or "").strip()
|
||||
triggers = _trigger_tuple(trigger)
|
||||
for key in ("softcore_prompt", "hardcore_prompt"):
|
||||
if key in pair:
|
||||
pair[key] = sanitize_prompt_text(pair.get(key, ""), triggers=triggers)
|
||||
for key in ("softcore_caption", "hardcore_caption"):
|
||||
if key in pair:
|
||||
pair[key] = sanitize_caption_text(pair.get(key, ""), triggers=triggers)
|
||||
for key in ("softcore_negative_prompt", "hardcore_negative_prompt"):
|
||||
if key in pair:
|
||||
pair[key] = sanitize_negative_text(pair.get(key, ""))
|
||||
for key in ("softcore_row", "hardcore_row"):
|
||||
if isinstance(pair.get(key), dict):
|
||||
pair[key] = sanitize_metadata_row_text(pair[key], active_trigger=trigger)
|
||||
return pair
|
||||
Reference in New Issue
Block a user