Normalize built-in row subject metadata
This commit is contained in:
@@ -32,6 +32,56 @@ def caption_from_parts(parts: list[Any] | tuple[Any, ...], *, active_trigger: st
|
||||
return sanitize_caption_text(text, triggers=_trigger_tuple(active_trigger))
|
||||
|
||||
|
||||
def _setdefault_nonempty(row: dict[str, Any], key: str, value: Any) -> None:
|
||||
if str(row.get(key) or "").strip():
|
||||
return
|
||||
if str(value or "").strip():
|
||||
row[key] = value
|
||||
|
||||
|
||||
def _setdefault_count(row: dict[str, Any], key: str, value: int) -> None:
|
||||
if str(row.get(key) or "").strip():
|
||||
return
|
||||
row[key] = int(value)
|
||||
|
||||
|
||||
def _legacy_subject_metadata(row: dict[str, Any]) -> tuple[str, str, int | None, int | None]:
|
||||
subject = str(row.get("primary_subject") or row.get("subject") or "").strip()
|
||||
lower = subject.lower()
|
||||
if lower in ("woman", "adult woman"):
|
||||
return "woman", subject or "woman", 1, 0
|
||||
if lower in ("man", "adult man"):
|
||||
return "man", subject or "man", 0, 1
|
||||
if "two women" in lower:
|
||||
return "couple", subject or "two women", 2, 0
|
||||
if "two men" in lower:
|
||||
return "couple", subject or "two men", 0, 2
|
||||
if "woman" in lower and "man" in lower:
|
||||
return "couple", subject or "a woman and a man", 1, 1
|
||||
if "group" in lower:
|
||||
return "group", subject or "mixed adult group", 2, 2
|
||||
if "layout" in lower:
|
||||
return "layout", subject or "adult layout scene", None, None
|
||||
return "", subject, None, None
|
||||
|
||||
|
||||
def enrich_legacy_row_metadata(row: dict[str, Any]) -> dict[str, Any]:
|
||||
if row.get("source") != "built_in_generator":
|
||||
return row
|
||||
subject_type, subject_phrase, women_count, men_count = _legacy_subject_metadata(row)
|
||||
_setdefault_nonempty(row, "subject_type", subject_type)
|
||||
_setdefault_nonempty(row, "subject_phrase", subject_phrase)
|
||||
if women_count is not None:
|
||||
_setdefault_count(row, "women_count", women_count)
|
||||
if men_count is not None:
|
||||
_setdefault_count(row, "men_count", men_count)
|
||||
if women_count is not None and men_count is not None and not str(row.get("person_count") or "").strip():
|
||||
row["person_count"] = int(women_count) + int(men_count)
|
||||
if str(row.get("scene") or "").strip() and not str(row.get("scene_slug") or "").strip():
|
||||
row["scene_slug"] = row.get("scene")
|
||||
return row
|
||||
|
||||
|
||||
def normalize_prompt_row(
|
||||
row: dict[str, Any],
|
||||
*,
|
||||
@@ -41,6 +91,7 @@ def normalize_prompt_row(
|
||||
extra_negative: str = "",
|
||||
default_negative: str = "",
|
||||
) -> dict[str, Any]:
|
||||
row = enrich_legacy_row_metadata(row)
|
||||
trigger = str(active_trigger or "").strip()
|
||||
positive = str(extra_positive or "").strip()
|
||||
prompt = str(row.get("prompt", "") or "")
|
||||
|
||||
Reference in New Issue
Block a user