Files
ComfyUI-Ethanfel-Prompt-Bui…/row_normalization.py
T

266 lines
11 KiB
Python

from __future__ import annotations
from typing import Any
try:
from . import row_location as row_location_policy
from .prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
import row_location as row_location_policy
from prompt_hygiene import combine_negative_text, sanitize_caption_text, sanitize_negative_text, sanitize_prompt_text
def _trigger_tuple(active_trigger: str) -> tuple[str, ...]:
    """Wrap the active trigger in a one-element tuple, or return ``()`` when blank.

    The input is coerced with ``str`` and stripped first; falsy values
    (``None``, ``""``, ``0``) and whitespace-only strings count as blank.
    """
    cleaned = str(active_trigger or "").strip()
    if not cleaned:
        return ()
    return (cleaned,)
def prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str:
    """Return ``prompt`` with ``trigger`` placed in front, unless it is already there.

    Args:
        prompt: Prompt text; falsy values are treated as an empty string.
        trigger: Trigger word or phrase; it is stripped before use.
        enabled: When false, the prompt is returned without modification.

    Returns:
        ``"<trigger>, <prompt>"`` when the trigger has to be added, the bare
        trigger when the prompt is blank, and otherwise the prompt unchanged.
    """
    trigger = str(trigger or "").strip()
    prompt = str(prompt or "")
    if not enabled or not trigger:
        return prompt
    if not prompt.strip():
        # Nothing to join with: avoid emitting a dangling "trigger, ".
        return trigger

    # Leading whitespace must not hide a trigger that is already in place.
    head = prompt.lstrip()
    if head[: len(trigger)].lower() == trigger.lower():
        following = head[len(trigger): len(trigger) + 1]
        # Only a whole-token match counts: "cat" must not be considered
        # present just because the prompt starts with "category".
        if not following or not (following.isalnum() or following == "_"):
            return prompt
    return f"{trigger}, {prompt}"
def combined_negative(base: str, extra: str) -> str:
    """Merge a base negative prompt with extra negative text.

    Thin wrapper: both arguments are handed to ``combine_negative_text``
    unchanged and its result is returned as-is.
    """
    merged = combine_negative_text(base, extra)
    return merged
def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: str = "") -> str:
    """Join caption fragments with ``", "`` and sanitize the result.

    Args:
        parts: Caption fragments. Each one is coerced with ``str`` and
            stripped; ``None`` entries and fragments that end up empty are
            dropped. A falsy ``parts`` is treated as no fragments.
        active_trigger: Trigger forwarded (as a tuple built by
            ``_trigger_tuple``) to ``sanitize_caption_text``.

    Returns:
        Whatever ``sanitize_caption_text`` returns for the joined text.
    """
    fragments = []
    for part in parts or ():
        if part is None:
            # str(None) would leak the literal word "None" into the caption.
            continue
        fragment = str(part).strip()
        if fragment:
            fragments.append(fragment)
    joined = ", ".join(fragments)
    return sanitize_caption_text(joined, triggers=_trigger_tuple(active_trigger))
def _setdefault_nonempty(row: dict[str, Any], key: str, value: Any) -> None:
    """Store ``value`` under ``key`` only when the slot is blank and ``value`` is not.

    "Blank" means falsy, or whitespace-only after ``str`` coercion; the same
    test is applied to the existing entry and to the candidate value. The
    row is modified in place.
    """
    slot_is_blank = not str(row.get(key) or "").strip()
    if slot_is_blank and str(value or "").strip():
        row[key] = value
def _setdefault_count(row: dict[str, Any], key: str, value: int) -> None:
    """Store ``int(value)`` under ``key`` unless the row already holds a count.

    Only ``None`` (including a missing key) and blank/whitespace-only text
    count as "no value yet". An existing integer ``0`` is a real count and
    is kept, exactly like the string ``"0"`` is. The row is modified in place.

    Raises:
        ValueError / TypeError: propagated from ``int(value)`` when the
            replacement value is not convertible.
    """
    existing = row.get(key)
    # Do not use `existing or ""` here: that would treat a stored 0 as missing.
    if existing is not None and str(existing).strip():
        return
    row[key] = int(value)
def _legacy_subject_metadata(row: dict[str, Any]) -> tuple[str, str, int | None, int | None]:
    """Infer subject metadata for a legacy row from its subject text.

    Reads ``primary_subject`` (falling back to ``subject``), strips it, and
    classifies it case-insensitively.

    Returns:
        ``(subject_type, subject_phrase, women_count, men_count)``.
        ``subject_phrase`` is always the stripped subject text. When nothing
        matches, ``subject_type`` is ``""`` and both counts are ``None``;
        the ``"layout"`` type also carries ``None`` counts.
    """
    import re  # Local import keeps this helper self-contained.

    subject = str(row.get("primary_subject") or row.get("subject") or "").strip()
    lower = subject.lower()

    # Every branch below can only match a non-empty subject, so the subject
    # text itself is always the phrase that is returned.
    if lower in ("woman", "adult woman"):
        return "woman", subject, 1, 0
    if lower in ("man", "adult man"):
        return "man", subject, 0, 1
    if "two women" in lower:
        return "couple", subject, 2, 0
    if "two men" in lower:
        return "couple", subject, 0, 2
    # Whole-word matching is required: "woman" contains "man", so a plain
    # substring test would classify any "... woman ..." subject as a couple.
    if re.search(r"\bwoman\b", lower) and re.search(r"\bman\b", lower):
        return "couple", subject, 1, 1
    if "group" in lower:
        return "group", subject, 2, 2
    if "layout" in lower:
        return "layout", subject, None, None
    return "", subject, None, None
def enrich_legacy_row_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Backfill subject and scene metadata on rows from the built-in generator.

    Rows whose ``source`` is not ``"built_in_generator"`` are returned
    untouched. Otherwise the row is mutated in place: subject type/phrase and
    the women/men/person counts are filled from ``_legacy_subject_metadata``
    where still blank, and ``scene_slug`` / ``scene_text`` / ``scene_entry``
    are filled from the row's scene slug. The same row object is returned.
    """
    if row.get("source") != "built_in_generator":
        return row

    def _is_blank(field: str) -> bool:
        # Falsy or whitespace-only values count as "not set yet".
        return not str(row.get(field) or "").strip()

    kind, phrase, women, men = _legacy_subject_metadata(row)
    _setdefault_nonempty(row, "subject_type", kind)
    _setdefault_nonempty(row, "subject_phrase", phrase)
    if women is not None:
        _setdefault_count(row, "women_count", women)
    if men is not None:
        _setdefault_count(row, "men_count", men)
    both_counts_known = women is not None and men is not None
    if both_counts_known and _is_blank("person_count"):
        row["person_count"] = int(women) + int(men)

    slug = str(row.get("scene") or row.get("scene_slug") or "").strip()
    if not slug:
        return row
    if _is_blank("scene_slug"):
        row["scene_slug"] = slug
    if _is_blank("scene_text"):
        text = row_location_policy.legacy_scene_text_for_slug(slug)
        if text:
            row["scene_text"] = text
            # An existing scene_entry is never replaced.
            row.setdefault("scene_entry", {"slug": slug, "prompt": text})
    return row
def normalize_prompt_row(
    row: dict[str, Any],
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    default_negative: str = "",
) -> dict[str, Any]:
    """Normalize the text fields of a single prompt row in place.

    The row is first passed through ``enrich_legacy_row_metadata``. Then
    ``prompt``, ``caption``, ``negative_prompt`` and ``trigger`` are
    rewritten on the row, which is also returned.

    Args:
        row: Row to normalize; it is mutated.
        active_trigger: Trigger text; stripped before use.
        prepend_trigger_to_prompt: Whether ``prepend_trigger`` may add the
            trigger in front of the prompt.
        extra_positive: Text appended (space-separated) to the prompt.
        extra_negative: Text combined with the base negative prompt.
        default_negative: Base negative prompt used when the row has no
            ``negative_prompt`` key or stores ``None`` under it. An explicit
            empty string on the row is respected and is not replaced.

    Returns:
        The same ``row`` object.
    """
    row = enrich_legacy_row_metadata(row)
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)
    positive = str(extra_positive or "").strip()

    prompt = str(row.get("prompt", "") or "")
    if positive:
        prompt = f"{prompt.rstrip()} {positive}".strip()
    prompt = prepend_trigger(prompt, trigger, bool(prepend_trigger_to_prompt))
    row["prompt"] = sanitize_prompt_text(prompt, triggers=triggers)
    row["caption"] = sanitize_caption_text(row.get("caption", ""), triggers=triggers)

    # dict.get's default only covers a missing key; a key that is present
    # but holds None must fall back to the default as well.
    base_negative = row.get("negative_prompt")
    if base_negative is None:
        base_negative = default_negative
    row["negative_prompt"] = sanitize_negative_text(
        combined_negative(str(base_negative or ""), extra_negative)
    )
    row["trigger"] = trigger
    return row
def normalize_pair_text_outputs(
    *,
    active_trigger: str,
    prepend_trigger_to_prompt: bool,
    extra_positive: str = "",
    extra_negative: str = "",
    soft_prompt: str,
    hard_prompt: str,
    soft_negative_base: str,
    hard_negative_base: str,
    soft_caption_parts: list[Any] | tuple[Any, ...],
    hard_caption_parts: list[Any] | tuple[Any, ...],
) -> dict[str, str]:
    """Build the normalized prompt/negative/caption texts for a soft/hard pair.

    Both prompts get ``extra_positive`` appended (space-separated), the
    trigger optionally prepended via ``prepend_trigger``, and are then passed
    through ``sanitize_prompt_text``. Negatives are combined with
    ``extra_negative`` and sanitized; captions are built with
    ``caption_from_parts``.

    Returns:
        A dict with the keys ``soft_prompt``, ``hard_prompt``,
        ``soft_negative``, ``hard_negative``, ``soft_caption`` and
        ``hard_caption``.
    """
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)
    positive = str(extra_positive or "").strip()
    use_trigger = bool(prepend_trigger_to_prompt)

    def _build_prompt(base: str) -> str:
        # Shared pipeline for the soft and hard prompt.
        text = base
        if positive:
            # Strip like normalize_prompt_row does: with an empty base the
            # join would otherwise start with a space, which also hides an
            # already-present trigger from prepend_trigger's prefix check.
            text = f"{str(base or '').rstrip()} {positive}".strip()
        text = prepend_trigger(text, trigger, use_trigger)
        return sanitize_prompt_text(text, triggers=triggers)

    return {
        "soft_prompt": _build_prompt(soft_prompt),
        "hard_prompt": _build_prompt(hard_prompt),
        "soft_negative": sanitize_negative_text(combined_negative(soft_negative_base, extra_negative)),
        "hard_negative": sanitize_negative_text(combined_negative(hard_negative_base, extra_negative)),
        "soft_caption": caption_from_parts(soft_caption_parts, active_trigger=trigger),
        "hard_caption": caption_from_parts(hard_caption_parts, active_trigger=trigger),
    }
def sanitize_metadata_row_text(row: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Sanitize the text fields that are present on a metadata row, in place.

    The trigger is ``active_trigger`` or, failing that, the row's own
    ``trigger`` entry. Only keys already on the row are rewritten:
    ``prompt`` and ``caption`` are sanitized with the trigger tuple,
    ``negative_prompt`` without it. A non-empty resolved trigger is written
    back when the row has no truthy ``trigger``. Returns the same row.
    """
    resolved = str(active_trigger or row.get("trigger") or "").strip()
    trigger_args = _trigger_tuple(resolved)

    trigger_aware = (
        ("prompt", sanitize_prompt_text),
        ("caption", sanitize_caption_text),
    )
    for field, clean in trigger_aware:
        if field in row:
            row[field] = clean(row.get(field, ""), triggers=trigger_args)
    if "negative_prompt" in row:
        row["negative_prompt"] = sanitize_negative_text(row.get("negative_prompt", ""))

    if resolved and not row.get("trigger"):
        row["trigger"] = resolved
    return row
def _sync_pair_root_row_field(pair: dict[str, Any], row_key: str, root_key: str, row_field: str) -> None:
    """Keep one root-level pair value and one nested row field in agreement.

    Nothing happens unless ``pair[row_key]`` is a dict. When ``root_key``
    exists on the pair, its value is copied into the row's ``row_field``;
    otherwise, if the row has ``row_field``, that value is copied up to
    ``pair[root_key]``. Both dicts are modified in place.
    """
    nested = pair.get(row_key)
    if not isinstance(nested, dict):
        return
    if root_key not in pair:
        # No root value yet: promote the row's value if it has one.
        if row_field in nested:
            pair[root_key] = nested[row_field]
        return
    # The root value takes precedence over whatever the row holds.
    nested[row_field] = pair[root_key]
def synchronize_pair_row_outputs(pair: dict[str, Any]) -> dict[str, Any]:
    """Sync prompt, caption and negative prompt between the pair root and its rows.

    For ``softcore_row`` and ``hardcore_row`` each root-level output key is
    reconciled with the matching row field through
    ``_sync_pair_root_row_field``. The pair is modified in place and returned.
    """
    root_keys_by_row = {
        "softcore_row": {
            "prompt": "softcore_prompt",
            "caption": "softcore_caption",
            "negative_prompt": "softcore_negative_prompt",
        },
        "hardcore_row": {
            "prompt": "hardcore_prompt",
            "caption": "hardcore_caption",
            "negative_prompt": "hardcore_negative_prompt",
        },
    }
    for row_key, root_keys in root_keys_by_row.items():
        for row_field, root_key in root_keys.items():
            _sync_pair_root_row_field(pair, row_key, root_key, row_field)
    return pair
def synchronize_pair_side_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Sync side-specific metadata keys between the pair root and its rows.

    Each listed key uses the same name on the pair root and on its row, and
    is reconciled through ``_sync_pair_root_row_field``. The pair is modified
    in place and returned.
    """
    same_named_fields = (
        ("softcore_row", "softcore_partner_styling"),
        ("hardcore_row", "hardcore_clothing_state"),
        ("hardcore_row", "character_hardcore_clothing"),
        ("hardcore_row", "default_man_hardcore_clothing"),
        ("hardcore_row", "hardcore_detail_density"),
        ("hardcore_row", "hardcore_position_config"),
    )
    for row_key, field in same_named_fields:
        _sync_pair_root_row_field(pair, row_key, field, field)
    return pair
def synchronize_pair_camera_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Sync camera metadata between side-prefixed root keys and row fields.

    Root keys carry a ``softcore_`` / ``hardcore_`` prefix while the rows
    use the unprefixed field names; each pairing is reconciled through
    ``_sync_pair_root_row_field``. The pair is modified in place and returned.
    """
    camera_links = (
        ("softcore_row", "softcore_camera_config", "camera_config"),
        ("softcore_row", "softcore_camera_directive", "camera_directive"),
        ("softcore_row", "softcore_camera_scene_directive", "camera_scene_directive"),
        ("hardcore_row", "hardcore_camera_config", "camera_config"),
        ("hardcore_row", "hardcore_camera_directive", "camera_directive"),
        ("hardcore_row", "hardcore_camera_scene_directive", "camera_scene_directive"),
    )
    for row_key, root_key, row_field in camera_links:
        _sync_pair_root_row_field(pair, row_key, root_key, row_field)
    return pair
def synchronize_pair_cast_metadata(pair: dict[str, Any]) -> dict[str, Any]:
    """Copy the pair's shared cast descriptors onto its row dicts.

    ``shared_cast_descriptors`` may be a list (blank entries are dropped,
    the rest are stripped and joined with ``"; "``) or any other value (used
    as a single stripped string). When a non-empty descriptor text results,
    ``cast_descriptor_text`` and ``cast_descriptors`` are written to
    ``hardcore_row``, and also to ``softcore_row`` when
    ``options["softcore_cast"]`` equals ``"same_as_hardcore"``. Rows that are
    not dicts are skipped. The pair is modified in place and returned.
    """
    raw = pair.get("shared_cast_descriptors")
    if isinstance(raw, list):
        cleaned = [str(entry).strip() for entry in raw if str(entry or "").strip()]
        joined = "; ".join(cleaned)
    else:
        joined = str(raw or "").strip()
        cleaned = [joined] if joined else []
    if not joined:
        return pair

    options = pair.get("options")
    mirror_to_softcore = (
        isinstance(options, dict) and options.get("softcore_cast") == "same_as_hardcore"
    )
    targets = ("hardcore_row", "softcore_row") if mirror_to_softcore else ("hardcore_row",)
    for row_key in targets:
        row = pair.get(row_key)
        if isinstance(row, dict):
            # Each row gets its own list copy so later edits stay independent.
            row.update(cast_descriptor_text=joined, cast_descriptors=list(cleaned))
    return pair
def normalize_pair_metadata(pair: dict[str, Any], *, active_trigger: str = "") -> dict[str, Any]:
    """Synchronize a pair with its rows, then sanitize every text field present.

    The four ``synchronize_pair_*`` helpers run first. After that the
    root-level prompts, captions and negative prompts that exist on the pair
    are passed through their sanitizers, and each row dict is passed through
    ``sanitize_metadata_row_text``. The pair is modified in place and returned.
    """
    trigger = str(active_trigger or "").strip()
    triggers = _trigger_tuple(trigger)

    for synchronize in (
        synchronize_pair_row_outputs,
        synchronize_pair_side_metadata,
        synchronize_pair_camera_metadata,
        synchronize_pair_cast_metadata,
    ):
        synchronize(pair)

    def _clean_present(keys, cleaner):
        # Rewrite only the keys that already exist on the pair.
        for key in keys:
            if key in pair:
                pair[key] = cleaner(pair.get(key, ""))

    _clean_present(
        ("softcore_prompt", "hardcore_prompt"),
        lambda text: sanitize_prompt_text(text, triggers=triggers),
    )
    _clean_present(
        ("softcore_caption", "hardcore_caption"),
        lambda text: sanitize_caption_text(text, triggers=triggers),
    )
    _clean_present(
        ("softcore_negative_prompt", "hardcore_negative_prompt"),
        sanitize_negative_text,
    )

    for row_key in ("softcore_row", "hardcore_row"):
        row = pair.get(row_key)
        if isinstance(row, dict):
            pair[row_key] = sanitize_metadata_row_text(row, active_trigger=trigger)
    return pair