Files
ComfyUI-Ethanfel-Prompt-Bui…/row_normalization.py
T

188 lines
7.4 KiB
Python

from __future__ import annotations
from typing import Any
try:
from .prompt_hygiene import sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
from prompt_hygiene import sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
def _trigger_tuple(active_trigger: str) -> tuple[str, ...]:
    """Wrap a trigger word in a tuple, or return an empty tuple.

    Args:
        active_trigger: Trigger text; falsy values are treated as "".

    Returns:
        A one-element tuple holding the whitespace-stripped trigger, or
        ``()`` when nothing is left after stripping.
    """
    cleaned = str(active_trigger or "").strip()
    if not cleaned:
        return ()
    return (cleaned,)
def prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str:
    """Put ``trigger`` in front of ``prompt`` unless it already leads it.

    Args:
        prompt: Prompt text; falsy values are treated as "".
        trigger: Trigger word or phrase; falsy values are treated as "".
        enabled: When false the prompt is returned untouched.

    Returns:
        The prompt unchanged when prepending is disabled, the trigger is
        empty, or the prompt already starts with the trigger
        (case-insensitive, ignoring leading whitespace). The bare trigger
        when the prompt is empty or whitespace-only. Otherwise
        ``"<trigger>, <prompt>"``.
    """
    trigger = str(trigger or "").strip()
    prompt = str(prompt or "")
    if not enabled or not trigger:
        return prompt

    body = prompt.lstrip()
    if not body:
        # Nothing to join with: avoid emitting a dangling "trigger, ".
        return trigger

    def is_word_char(char: str) -> bool:
        return char.isalnum() or char == "_"

    head, rest = body[: len(trigger)], body[len(trigger):]
    if head.lower() == trigger.lower():
        # Only count it as "already present" when the match ends on a token
        # boundary, so trigger "cat" is not mistaken for the start of
        # "category". If the trigger itself ends in punctuation, the
        # boundary is already explicit.
        at_boundary = (
            not rest
            or not is_word_char(trigger[-1])
            or not is_word_char(rest[0])
        )
        if at_boundary:
            return prompt

    return f"{trigger}, {prompt}"
def combined_negative(base: str, extra: str) -> str:
    """Join the base and extra negative prompts with ``", "``.

    Falsy inputs and inputs that are blank after stripping are skipped, so
    the result never has a leading or trailing separator.
    """
    kept = []
    for piece in (base, extra):
        if not piece:
            continue
        text = str(piece).strip()
        if text:
            kept.append(text)
    return ", ".join(kept)
def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: str = "") -> str:
    """Build a caption from loose parts and pass it through the caption sanitizer.

    Args:
        parts: Caption fragments. Each non-``None`` item is converted with
            ``str()`` and stripped; blank results are dropped. ``None`` for
            the whole argument is treated as "no parts".
        active_trigger: Trigger forwarded (as a tuple) to
            ``sanitize_caption_text``.

    Returns:
        Whatever ``sanitize_caption_text`` returns for the ``", "``-joined
        fragments.
    """
    # str(None) is the truthy text "None"; skip missing parts explicitly so
    # they cannot leak into the caption.
    fragments = (str(part).strip() for part in (parts or ()) if part is not None)
    text = ", ".join(fragment for fragment in fragments if fragment)
    return sanitize_caption_text(text, triggers=_trigger_tuple(active_trigger))
def normalize_prompt_row(
    row: dict[str, Any],
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    default_negative: str = "",
) -> dict[str, Any]:
    """Normalize the text fields of one prompt row in place.

    Args:
        row: Row dict; its ``prompt``, ``caption``, ``negative_prompt`` and
            ``trigger`` entries are overwritten.
        active_trigger: Trigger word, stripped before use.
        prepend_trigger_to_prompt: Whether the trigger is prepended to the
            prompt via ``prepend_trigger``.
        extra_positive: Text appended (space-separated) to the prompt.
        extra_negative: Text combined with the row's negative prompt.
        default_negative: Negative prompt used only when the row has no
            ``negative_prompt`` key at all.

    Returns:
        The same ``row`` object that was passed in.
    """
    clean_trigger = str(active_trigger or "").strip()
    addition = str(extra_positive or "").strip()

    text = str(row.get("prompt") or "")
    if addition:
        text = " ".join((text.rstrip(), addition)).strip()
    text = prepend_trigger(text, clean_trigger, bool(prepend_trigger_to_prompt))
    row["prompt"] = sanitize_prompt_text(text, triggers=_trigger_tuple(clean_trigger))

    row["caption"] = sanitize_caption_text(
        row.get("caption", ""),
        triggers=_trigger_tuple(clean_trigger),
    )

    # The default only applies when the key is absent; an explicit empty
    # value in the row is kept as empty.
    base_negative = str(row.get("negative_prompt", default_negative) or "")
    row["negative_prompt"] = sanitize_negative_text(
        combined_negative(base_negative, extra_negative)
    )

    row["trigger"] = clean_trigger
    return row
def normalize_pair_text_outputs(
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    soft_prompt: str,
    hard_prompt: str,
    soft_negative_base: str,
    hard_negative_base: str,
    soft_caption_parts: list[Any] | tuple[Any, ...],
    hard_caption_parts: list[Any] | tuple[Any, ...],
) -> dict[str, str]:
    """Produce the normalized prompt, negative and caption text for a soft/hard pair.

    Args:
        active_trigger: Trigger word, stripped before use.
        prepend_trigger_to_prompt: Whether the trigger is prepended to both
            prompts via ``prepend_trigger``.
        extra_positive: Text appended (space-separated) to both prompts.
        extra_negative: Text combined with both negative bases.
        soft_prompt: Base prompt for the soft side.
        hard_prompt: Base prompt for the hard side.
        soft_negative_base: Base negative prompt for the soft side.
        hard_negative_base: Base negative prompt for the hard side.
        soft_caption_parts: Fragments passed to ``caption_from_parts``.
        hard_caption_parts: Fragments passed to ``caption_from_parts``.

    Returns:
        A dict with the keys ``soft_prompt``, ``hard_prompt``,
        ``soft_negative``, ``hard_negative``, ``soft_caption`` and
        ``hard_caption``.
    """
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)
    positive = str(extra_positive or "").strip()
    use_trigger = bool(prepend_trigger_to_prompt)

    def build_prompt(base: Any) -> str:
        text = str(base or "")
        if positive:
            # Strip after joining, as normalize_prompt_row does, so an empty
            # base prompt does not leave a leading space before the extra text.
            text = f"{text.rstrip()} {positive}".strip()
        text = prepend_trigger(text, trigger, use_trigger)
        return sanitize_prompt_text(text, triggers=triggers)

    def build_negative(base: Any) -> str:
        return sanitize_negative_text(combined_negative(base, extra_negative))

    return {
        "soft_prompt": build_prompt(soft_prompt),
        "hard_prompt": build_prompt(hard_prompt),
        "soft_negative": build_negative(soft_negative_base),
        "hard_negative": build_negative(hard_negative_base),
        "soft_caption": caption_from_parts(soft_caption_parts, active_trigger=trigger),
        "hard_caption": caption_from_parts(hard_caption_parts, active_trigger=trigger),
    }
def sanitize_metadata_row_text(row: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Sanitize the text fields already present on a metadata row, in place.

    Args:
        row: Row dict. Only the keys ``prompt``, ``caption`` and
            ``negative_prompt`` that exist are rewritten; missing keys are
            not created.
        active_trigger: Trigger to use; when falsy, the row's own
            ``trigger`` entry is used instead.

    Returns:
        The same ``row`` object. Its ``trigger`` entry is filled in only
        when a trigger was resolved and the row had none (or a falsy one).
    """
    resolved = str(active_trigger or row.get("trigger") or "").strip()
    trigger_set = _trigger_tuple(resolved)

    cleaners = (
        ("prompt", lambda text: sanitize_prompt_text(text, triggers=trigger_set)),
        ("caption", lambda text: sanitize_caption_text(text, triggers=trigger_set)),
        ("negative_prompt", lambda text: sanitize_negative_text(text)),
    )
    for field, clean in cleaners:
        if field in row:
            row[field] = clean(row[field])

    if resolved and not row.get("trigger"):
        row["trigger"] = resolved
    return row
def synchronize_pair_row_outputs(pair: dict[str, Any]) -> dict[str, Any]:
    """Copy pair-level prompt, caption and negative text into the per-side rows.

    For ``softcore_row`` and ``hardcore_row`` (when they are dicts), each
    pair-level ``<side>_prompt`` / ``<side>_caption`` /
    ``<side>_negative_prompt`` key that exists on ``pair`` overwrites the
    row's ``prompt`` / ``caption`` / ``negative_prompt``. Keys absent from
    ``pair`` leave the row untouched.

    Returns:
        The same ``pair`` object.
    """
    field_sources = {
        "softcore_row": {
            "prompt": "softcore_prompt",
            "caption": "softcore_caption",
            "negative_prompt": "softcore_negative_prompt",
        },
        "hardcore_row": {
            "prompt": "hardcore_prompt",
            "caption": "hardcore_caption",
            "negative_prompt": "hardcore_negative_prompt",
        },
    }
    for row_name, sources in field_sources.items():
        target = pair.get(row_name)
        if isinstance(target, dict):
            target.update(
                {field: pair[source] for field, source in sources.items() if source in pair}
            )
    return pair
def synchronize_pair_side_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Mirror side-specific pair-level metadata keys onto the matching row.

    Each listed key that exists on ``pair`` is copied, under the same name,
    into ``softcore_row`` or ``hardcore_row`` (when that row is a dict).

    Returns:
        The same ``pair`` object.
    """
    softcore_fields = ("softcore_partner_styling",)
    hardcore_fields = (
        "hardcore_clothing_state",
        "character_hardcore_clothing",
        "default_man_hardcore_clothing",
        "hardcore_detail_density",
        "hardcore_position_config",
    )
    for row_name, fields in (
        ("softcore_row", softcore_fields),
        ("hardcore_row", hardcore_fields),
    ):
        target = pair.get(row_name)
        if isinstance(target, dict):
            target.update({field: pair[field] for field in fields if field in pair})
    return pair
def synchronize_pair_cast_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Copy the pair's shared cast descriptors onto its rows.

    ``pair["shared_cast_descriptors"]`` may be a list or tuple of items
    (each converted with ``str()`` and stripped, falsy or blank items
    dropped) or any other value, which is treated as a single descriptor
    string. If nothing remains, the pair is returned unchanged.

    The descriptors are always written to ``hardcore_row``. They are also
    written to ``softcore_row`` when ``pair["options"]["softcore_cast"]``
    equals ``"same_as_hardcore"``. Rows that are not dicts are skipped.

    Returns:
        The same ``pair`` object. Affected rows gain
        ``cast_descriptor_text`` (descriptors joined with ``"; "``) and
        ``cast_descriptors`` (a fresh list per row).
    """
    descriptors = pair.get("shared_cast_descriptors")
    # Tuples are handled like lists; otherwise str(tuple) would leak its repr
    # (e.g. "('a', 'b')") into the descriptor text.
    if isinstance(descriptors, (list, tuple)):
        descriptor_list = [str(item).strip() for item in descriptors if str(item or "").strip()]
        descriptor_text = "; ".join(descriptor_list)
    else:
        descriptor_text = str(descriptors or "").strip()
        descriptor_list = [descriptor_text] if descriptor_text else []
    if not descriptor_text:
        return pair

    options = pair.get("options")
    if not isinstance(options, dict):
        options = {}

    row_keys = ["hardcore_row"]
    if options.get("softcore_cast") == "same_as_hardcore":
        row_keys.append("softcore_row")

    for row_key in row_keys:
        row = pair.get(row_key)
        if not isinstance(row, dict):
            continue
        row["cast_descriptor_text"] = descriptor_text
        # Copy so the rows never share one mutable list.
        row["cast_descriptors"] = list(descriptor_list)
    return pair
def normalize_pair_metadata(pair: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Synchronize and sanitize all text metadata of a soft/hard pair in place.

    Steps, in order:
      1. Copy pair-level outputs, side metadata and cast descriptors onto
         the rows via the three ``synchronize_pair_*`` helpers.
      2. Sanitize the pair-level prompt, caption and negative-prompt keys
         that exist.
      3. Run ``sanitize_metadata_row_text`` on ``softcore_row`` and
         ``hardcore_row`` when they are dicts.

    Args:
        pair: Pair dict, mutated in place.
        active_trigger: Trigger word, stripped before use.

    Returns:
        The same ``pair`` object.
    """
    clean_trigger = str(active_trigger or "").strip()
    trigger_set = _trigger_tuple(clean_trigger)

    synchronize_pair_row_outputs(pair)
    synchronize_pair_side_metadata(pair)
    synchronize_pair_cast_metadata(pair)

    passes = (
        (
            ("softcore_prompt", "hardcore_prompt"),
            lambda text: sanitize_prompt_text(text, triggers=trigger_set),
        ),
        (
            ("softcore_caption", "hardcore_caption"),
            lambda text: sanitize_caption_text(text, triggers=trigger_set),
        ),
        (
            ("softcore_negative_prompt", "hardcore_negative_prompt"),
            lambda text: sanitize_negative_text(text),
        ),
    )
    for field_names, clean in passes:
        for field_name in field_names:
            if field_name in pair:
                pair[field_name] = clean(pair[field_name])

    for row_name in ("softcore_row", "hardcore_row"):
        row = pair.get(row_name)
        if isinstance(row, dict):
            pair[row_name] = sanitize_metadata_row_text(row, active_trigger=clean_trigger)
    return pair