Extract formatter input parsing policy
This commit is contained in:
+13
-54
@@ -1,13 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import formatter_input as input_policy
|
||||||
from .hardcore_action_metadata import normalize_hardcore_action_family
|
from .hardcore_action_metadata import normalize_hardcore_action_family
|
||||||
from .prompt_hygiene import sanitize_prose_text
|
from .prompt_hygiene import sanitize_prose_text
|
||||||
except ImportError: # Allows local smoke tests with `python -c`.
|
except ImportError: # Allows local smoke tests with `python -c`.
|
||||||
|
import formatter_input as input_policy
|
||||||
from hardcore_action_metadata import normalize_hardcore_action_family
|
from hardcore_action_metadata import normalize_hardcore_action_family
|
||||||
from prompt_hygiene import sanitize_prose_text
|
from prompt_hygiene import sanitize_prose_text
|
||||||
|
|
||||||
@@ -71,11 +72,7 @@ POSITION_FAMILY_CAPTION_LABELS = {
|
|||||||
|
|
||||||
|
|
||||||
def _clean_text(value: Any) -> str:
|
def _clean_text(value: Any) -> str:
|
||||||
text = "" if value is None else str(value)
|
return input_policy.clean_text(value)
|
||||||
text = text.replace("\n", " ")
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _is_false(value: Any) -> bool:
|
def _is_false(value: Any) -> bool:
|
||||||
@@ -189,18 +186,11 @@ def _strip_style_tail(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _remove_trigger(text: str, trigger: str) -> str:
|
def _remove_trigger(text: str, trigger: str) -> str:
|
||||||
text = _clean_text(text).strip(" ,")
|
return input_policy.strip_trigger_prefix(
|
||||||
for candidate in (trigger, OLD_TRIGGER, DEFAULT_TRIGGER):
|
text,
|
||||||
candidate = candidate.strip()
|
(trigger, OLD_TRIGGER, DEFAULT_TRIGGER),
|
||||||
if not candidate:
|
remove_exact=True,
|
||||||
continue
|
)
|
||||||
if text.lower().startswith(candidate.lower() + ","):
|
|
||||||
return text[len(candidate) + 1 :].strip(" ,")
|
|
||||||
if text.lower().startswith(candidate.lower() + "."):
|
|
||||||
return text[len(candidate) + 1 :].strip(" ,")
|
|
||||||
if text.lower() == candidate.lower():
|
|
||||||
return ""
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
|
def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
|
||||||
@@ -214,55 +204,24 @@ def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||||
text = _clean_text(text)
|
return input_policy.maybe_json(text)
|
||||||
if not text or not text.startswith("{"):
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
value = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
return value if isinstance(value, dict) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||||
candidates: list[tuple[str, str]] = []
|
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||||
if input_hint in ("auto", "metadata_json"):
|
|
||||||
candidates.append((metadata_json, "metadata_json"))
|
|
||||||
candidates.append((source_text, "source_json"))
|
|
||||||
for text, method in candidates:
|
|
||||||
row = _maybe_json(text)
|
|
||||||
if row is not None:
|
|
||||||
return row, method
|
|
||||||
return None, "text"
|
|
||||||
|
|
||||||
|
|
||||||
def _prompt_field(text: str, label: str) -> str:
|
def _prompt_field(text: str, label: str) -> str:
|
||||||
text = _clean_text(text)
|
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if not text:
|
|
||||||
return ""
|
|
||||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
|
||||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
|
||||||
match = re.search(pattern, text)
|
|
||||||
if not match:
|
|
||||||
return ""
|
|
||||||
return _clean_text(match.group(1)).rstrip(".")
|
|
||||||
|
|
||||||
|
|
||||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||||
value = _clean_text(row.get(key, ""))
|
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
prompt = _clean_text(row.get("prompt", ""))
|
|
||||||
for label in labels:
|
|
||||||
value = _prompt_field(prompt, label)
|
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _field_from_any_prompt(text: str, labels: tuple[str, ...]) -> str:
|
def _field_from_any_prompt(text: str, labels: tuple[str, ...]) -> str:
|
||||||
for label in labels:
|
for label in labels:
|
||||||
value = _prompt_field(text, label)
|
value = input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if value:
|
if value:
|
||||||
return value
|
return value
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -62,6 +62,23 @@ route-specific owner. It also preserves ordinary words such as `composition`
|
|||||||
inside normal sentences; empty field-label cleanup is limited to standalone
|
inside normal sentences; empty field-label cleanup is limited to standalone
|
||||||
labels.
|
labels.
|
||||||
|
|
||||||
|
Formatter input/fallback parsing now has one home:
|
||||||
|
|
||||||
|
- `formatter_input.py`
|
||||||
|
|
||||||
|
It owns route-neutral parsing shared by Krea2, SDXL, and natural-caption
|
||||||
|
routes:
|
||||||
|
|
||||||
|
- whitespace and punctuation normalization before formatter parsing;
|
||||||
|
- JSON row detection from `metadata_json` or source text;
|
||||||
|
- trigger-prefix stripping with route-specific trigger candidate lists;
|
||||||
|
- `Avoid:` positive/negative splitting for fallback text;
|
||||||
|
- prompt field extraction such as `Setting:` or `Composition:`;
|
||||||
|
- row-value fallback from metadata fields to labeled prompt text.
|
||||||
|
|
||||||
|
It must not make formatter-style decisions. Krea prose, SDXL tags, and training
|
||||||
|
caption sentence shape stay in their formatter modules.
|
||||||
|
|
||||||
Shared hardcore phrase cleanup now has one home:
|
Shared hardcore phrase cleanup now has one home:
|
||||||
|
|
||||||
- `hardcore_text_cleanup.py`
|
- `hardcore_text_cleanup.py`
|
||||||
@@ -242,6 +259,9 @@ Already isolated:
|
|||||||
- `krea_pov_actions.py` owns POV hardcore action sentence rewriting,
|
- `krea_pov_actions.py` owns POV hardcore action sentence rewriting,
|
||||||
first-person body geometry, and selected-position-axis priority before loose
|
first-person body geometry, and selected-position-axis priority before loose
|
||||||
context fallback.
|
context fallback.
|
||||||
|
- `formatter_input.py` owns shared metadata/source JSON detection, trigger
|
||||||
|
stripping, prompt-field extraction, `Avoid:` splitting, and row-value
|
||||||
|
fallback for Krea, SDXL, and caption routes.
|
||||||
|
|
||||||
Improve later:
|
Improve later:
|
||||||
|
|
||||||
@@ -262,6 +282,7 @@ Keep here:
|
|||||||
- negative-prompt assembly.
|
- negative-prompt assembly.
|
||||||
- metadata-family tag hints from `action_family`, `position_family`, and
|
- metadata-family tag hints from `action_family`, `position_family`, and
|
||||||
`position_keys`.
|
`position_keys`.
|
||||||
|
- shared formatter input parsing from `formatter_input.py`.
|
||||||
|
|
||||||
Improve later:
|
Improve later:
|
||||||
|
|
||||||
@@ -280,6 +301,7 @@ Keep here:
|
|||||||
- training-caption trigger behavior;
|
- training-caption trigger behavior;
|
||||||
- style-tail policy.
|
- style-tail policy.
|
||||||
- metadata-family action labels from `action_family` and `position_family`.
|
- metadata-family action labels from `action_family` and `position_family`.
|
||||||
|
- shared formatter input parsing from `formatter_input.py`.
|
||||||
|
|
||||||
Improve later:
|
Improve later:
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ Core helper ownership:
|
|||||||
| `scene_camera_adapters.py` | Location-aware camera/scene prose such as coworking lounge camera layout. |
|
| `scene_camera_adapters.py` | Location-aware camera/scene prose such as coworking lounge camera layout. |
|
||||||
| `prompt_hygiene.py` | Generic prompt, caption, and negative-prompt cleanup. |
|
| `prompt_hygiene.py` | Generic prompt, caption, and negative-prompt cleanup. |
|
||||||
| `row_normalization.py` | Final prompt-row and pair metadata normalization: trigger prepending, extra-positive append, negative merge/dedupe, caption-part joining, and embedded soft/hard row sanitation. |
|
| `row_normalization.py` | Final prompt-row and pair metadata normalization: trigger prepending, extra-positive append, negative merge/dedupe, caption-part joining, and embedded soft/hard row sanitation. |
|
||||||
|
| `formatter_input.py` | Shared formatter input parsing: text cleanup, metadata/source JSON detection, trigger-prefix stripping, `Avoid:` splitting, prompt-field extraction, and metadata row-value fallback. |
|
||||||
|
|
||||||
## Node IO Map
|
## Node IO Map
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# Default label set for ``prompt_field`` / ``row_value``: a labeled segment ends
# where ". <label>:" for any of these labels begins.
DEFAULT_PROMPT_FIELD_LABELS = (
    # Cast / subject description.
    "Ages", "Body types", "Cast", "Cast descriptors", "Characters",
    # Scene and pose.
    "Scene", "Setting", "Pose", "Sexual pose", "Sexual scene",
    # Expression, wardrobe, and props.
    "Facial expression", "Facial expressions", "Clothing", "Erotic outfit", "Prop/detail",
    # Framing and camera.
    "Composition", "Role graph", "Camera", "Camera control",
    # Trailing instruction sections.
    "Use", "Avoid",
)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(value: Any) -> str:
    """Normalize any value into single-line, single-spaced text.

    ``None`` becomes the empty string; every other value goes through
    ``str()``. Runs of whitespace collapse to one space, both ends are
    trimmed, and whitespace directly before ``,``, ``.``, ``;`` or ``:``
    is removed.
    """
    if value is None:
        raw = ""
    else:
        raw = str(value)
    single_line = raw.replace("\n", " ")
    collapsed = re.sub(r"\s+", " ", single_line).strip()
    # Pull punctuation back against the preceding word ("a , b" -> "a, b").
    return re.sub(r"\s+([,.;:])", r"\1", collapsed)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_json(text: Any) -> dict[str, Any] | None:
    """Return ``text`` parsed as a JSON object, or ``None``.

    The input is first normalized with ``clean_text``. ``None`` is returned
    when the normalized text does not start with ``{``, when it is not valid
    JSON, or when the decoded value is not a ``dict``.
    """
    candidate = clean_text(text)
    if candidate[:1] != "{":
        return None
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, dict):
        return parsed
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = ("auto", "metadata_json"),
) -> tuple[dict[str, Any] | None, str]:
    """Find a JSON row in the formatter inputs.

    When ``input_hint`` is one of ``metadata_methods``, ``metadata_json`` is
    tried first and ``source_text`` second; the first one that
    ``maybe_json`` decodes to a dict is returned with its method name
    (``"metadata_json"`` or ``"source_json"``). In every other case the
    result is ``(None, "text")``.
    """
    if input_hint not in metadata_methods:
        return None, "text"
    attempts = ((metadata_json, "metadata_json"), (source_text, "source_json"))
    for raw, origin in attempts:
        parsed = maybe_json(raw)
        if parsed is not None:
            return parsed, origin
    return None, "text"
|
||||||
|
|
||||||
|
|
||||||
|
def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading trigger word from ``text``.

    ``text`` is normalized with ``clean_text``; with ``remove_exact`` it is
    additionally stripped of surrounding spaces and commas. If
    ``preserve_trigger`` is set, that normalized text is returned untouched.

    Otherwise each candidate (normalized, blanks skipped) is compared
    case-insensitively, in order. A candidate followed directly by ``,`` or
    ``.`` at the start of the text is cut off together with that character,
    and the remainder is stripped of spaces and commas. With
    ``remove_exact``, text that consists only of the candidate yields ``""``.
    Text matching no candidate is returned as normalized.
    """
    cleaned = clean_text(text)
    if remove_exact:
        cleaned = cleaned.strip(" ,")
    if preserve_trigger:
        return cleaned

    lowered = cleaned.lower()
    for raw_trigger in trigger_candidates:
        candidate = clean_text(raw_trigger)
        if not candidate:
            continue
        needle = candidate.lower()
        if lowered.startswith((needle + ",", needle + ".")):
            # +1 drops the separator that follows the trigger.
            return cleaned[len(candidate) + 1 :].strip(" ,")
        if remove_exact and lowered == needle:
            return ""
    return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def split_avoid(text: Any) -> tuple[str, str]:
    """Split normalized ``text`` at its first ``Avoid:`` marker.

    Returns ``(positive, negative)``: the text before the marker and the
    text after it, each stripped of surrounding spaces and periods. When
    there is no ``Avoid:`` marker the result is ``(normalized_text, "")``.
    """
    cleaned = clean_text(text)
    found = re.search(r"\bAvoid:\s*(.*)$", cleaned)
    if found is None:
        return cleaned, ""
    positive = cleaned[: found.start()].strip(" .")
    negative = found.group(1).strip(" .")
    return positive, negative
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Extract the value of a ``<label>:`` segment from labeled prompt text.

    The value runs from just after ``label:`` up to the next
    ``". <field label>:"``, ``". Use"``, ``". Avoid"``, or the end of the
    text. The result is normalized with ``clean_text`` and trailing periods
    are dropped. Returns ``""`` for empty text or when the label is absent.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return ""
    labels = "|".join(map(re.escape, field_labels))
    pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
    found = re.search(pattern, cleaned)
    return clean_text(found.group(1)).rstrip(".") if found else ""
|
||||||
|
|
||||||
|
|
||||||
|
def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Read ``row[key]``, falling back to labeled text in ``row["prompt"]``.

    The normalized ``row[key]`` wins when it is non-empty. Otherwise each
    entry of ``labels`` is looked up in the row's prompt with
    ``prompt_field`` and the first non-empty hit is returned; ``""`` when
    nothing is found.
    """
    direct = clean_text(row.get(key, ""))
    if direct:
        return direct
    prompt_text = clean_text(row.get("prompt", ""))
    # Lazy generator: labels are tried in order and lookup stops at the first hit.
    extracted = (
        prompt_field(prompt_text, label, field_labels=field_labels)
        for label in labels
    )
    return next((value for value in extracted if value), "")
|
||||||
+9
-54
@@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import formatter_input as input_policy
|
||||||
from .krea_action_context import (
|
from .krea_action_context import (
|
||||||
is_close_foreplay_text as _is_close_foreplay_text,
|
is_close_foreplay_text as _is_close_foreplay_text,
|
||||||
is_outercourse_text as _is_outercourse_text,
|
is_outercourse_text as _is_outercourse_text,
|
||||||
@@ -34,6 +34,7 @@ try:
|
|||||||
from .krea_pov_actions import pov_action_phrase as _pov_action_phrase
|
from .krea_pov_actions import pov_action_phrase as _pov_action_phrase
|
||||||
from .prompt_hygiene import sanitize_negative_text, sanitize_prose_text
|
from .prompt_hygiene import sanitize_negative_text, sanitize_prose_text
|
||||||
except ImportError: # Allows local smoke tests with `python -c`.
|
except ImportError: # Allows local smoke tests with `python -c`.
|
||||||
|
import formatter_input as input_policy
|
||||||
from krea_action_context import (
|
from krea_action_context import (
|
||||||
is_close_foreplay_text as _is_close_foreplay_text,
|
is_close_foreplay_text as _is_close_foreplay_text,
|
||||||
is_outercourse_text as _is_outercourse_text,
|
is_outercourse_text as _is_outercourse_text,
|
||||||
@@ -91,11 +92,7 @@ PROMPT_FIELD_LABELS = (
|
|||||||
|
|
||||||
|
|
||||||
def _clean(value: Any) -> str:
|
def _clean(value: Any) -> str:
|
||||||
text = "" if value is None else str(value)
|
return input_policy.clean_text(value)
|
||||||
text = text.replace("\n", " ")
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _is_false(value: Any) -> bool:
|
def _is_false(value: Any) -> bool:
|
||||||
@@ -133,69 +130,27 @@ def _with_indefinite_article(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||||
text = _clean(text)
|
return input_policy.maybe_json(text)
|
||||||
if not text.startswith("{"):
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
value = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
return value if isinstance(value, dict) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||||
candidates: list[tuple[str, str]] = []
|
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||||
if input_hint in ("auto", "metadata_json"):
|
|
||||||
candidates.append((metadata_json, "metadata_json"))
|
|
||||||
candidates.append((source_text, "source_json"))
|
|
||||||
for text, method in candidates:
|
|
||||||
row = _maybe_json(text)
|
|
||||||
if row is not None:
|
|
||||||
return row, method
|
|
||||||
return None, "text"
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
||||||
text = _clean(text)
|
return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger)
|
||||||
if preserve_trigger:
|
|
||||||
return text
|
|
||||||
for trigger in TRIGGER_CANDIDATES:
|
|
||||||
if text.lower().startswith(trigger.lower() + ","):
|
|
||||||
return text[len(trigger) + 1 :].strip(" ,")
|
|
||||||
if text.lower().startswith(trigger.lower() + "."):
|
|
||||||
return text[len(trigger) + 1 :].strip(" ,")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _split_avoid(text: str) -> tuple[str, str]:
|
def _split_avoid(text: str) -> tuple[str, str]:
|
||||||
match = re.search(r"\bAvoid:\s*(.*)$", text)
|
return input_policy.split_avoid(text)
|
||||||
if not match:
|
|
||||||
return text, ""
|
|
||||||
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
|
|
||||||
|
|
||||||
|
|
||||||
def _prompt_field(text: str, label: str) -> str:
|
def _prompt_field(text: str, label: str) -> str:
|
||||||
text = _clean(text)
|
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if not text:
|
|
||||||
return ""
|
|
||||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
|
||||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
|
||||||
match = re.search(pattern, text)
|
|
||||||
if not match:
|
|
||||||
return ""
|
|
||||||
return _clean(match.group(1)).rstrip(".")
|
|
||||||
|
|
||||||
|
|
||||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||||
value = _clean(row.get(key, ""))
|
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
prompt = _clean(row.get("prompt", ""))
|
|
||||||
for label in labels:
|
|
||||||
value = _prompt_field(prompt, label)
|
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _body_phrase(body: Any, figure_note: Any = "") -> str:
|
def _body_phrase(body: Any, figure_note: Any = "") -> str:
|
||||||
|
|||||||
+9
-51
@@ -1,13 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import formatter_input as input_policy
|
||||||
from .hardcore_action_metadata import normalize_hardcore_action_family
|
from .hardcore_action_metadata import normalize_hardcore_action_family
|
||||||
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||||
except ImportError: # Allows local smoke tests with `python -c`.
|
except ImportError: # Allows local smoke tests with `python -c`.
|
||||||
|
import formatter_input as input_policy
|
||||||
from hardcore_action_metadata import normalize_hardcore_action_family
|
from hardcore_action_metadata import normalize_hardcore_action_family
|
||||||
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||||
|
|
||||||
@@ -95,74 +96,31 @@ def sdxl_quality_preset_choices() -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _clean(value: Any) -> str:
|
def _clean(value: Any) -> str:
|
||||||
text = "" if value is None else str(value)
|
return input_policy.clean_text(value)
|
||||||
text = text.replace("\n", " ")
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||||
text = _clean(text)
|
return input_policy.maybe_json(text)
|
||||||
if not text.startswith("{"):
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
value = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
return value if isinstance(value, dict) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||||
if input_hint in ("auto", "metadata_json"):
|
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
|
||||||
row = _maybe_json(text)
|
|
||||||
if row is not None:
|
|
||||||
return row, method
|
|
||||||
return None, "text"
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
||||||
text = _clean(text)
|
return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger)
|
||||||
if preserve_trigger:
|
|
||||||
return text
|
|
||||||
for trigger in TRIGGER_CANDIDATES:
|
|
||||||
if text.lower().startswith(trigger.lower() + ","):
|
|
||||||
return text[len(trigger) + 1 :].strip(" ,")
|
|
||||||
if text.lower().startswith(trigger.lower() + "."):
|
|
||||||
return text[len(trigger) + 1 :].strip(" ,")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _split_avoid(text: str) -> tuple[str, str]:
|
def _split_avoid(text: str) -> tuple[str, str]:
|
||||||
match = re.search(r"\bAvoid:\s*(.*)$", text)
|
return input_policy.split_avoid(text)
|
||||||
if not match:
|
|
||||||
return text, ""
|
|
||||||
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
|
|
||||||
|
|
||||||
|
|
||||||
def _prompt_field(text: str, label: str) -> str:
|
def _prompt_field(text: str, label: str) -> str:
|
||||||
text = _clean(text)
|
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if not text:
|
|
||||||
return ""
|
|
||||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
|
||||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
|
||||||
match = re.search(pattern, text)
|
|
||||||
if not match:
|
|
||||||
return ""
|
|
||||||
return _clean(match.group(1)).rstrip(".")
|
|
||||||
|
|
||||||
|
|
||||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||||
value = _clean(row.get(key, ""))
|
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
prompt = _clean(row.get("prompt", ""))
|
|
||||||
for label in labels:
|
|
||||||
value = _prompt_field(prompt, label)
|
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _split_tag_text(text: Any) -> list[str]:
|
def _split_tag_text(text: Any) -> list[str]:
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import character_profile # noqa: E402
|
|||||||
import category_cast_config # noqa: E402
|
import category_cast_config # noqa: E402
|
||||||
import category_library # noqa: E402
|
import category_library # noqa: E402
|
||||||
import filter_config # noqa: E402
|
import filter_config # noqa: E402
|
||||||
|
import formatter_input # noqa: E402
|
||||||
import hardcore_position_config # noqa: E402
|
import hardcore_position_config # noqa: E402
|
||||||
import __init__ as sxcp_nodes # noqa: E402
|
import __init__ as sxcp_nodes # noqa: E402
|
||||||
import generation_profile_config # noqa: E402
|
import generation_profile_config # noqa: E402
|
||||||
@@ -847,6 +848,57 @@ def smoke_row_normalization_policy() -> None:
|
|||||||
_expect_no_duplicate_comma_items("row_normalization.pair.hard_row_negative", pair["hardcore_row"].get("negative_prompt"))
|
_expect_no_duplicate_comma_items("row_normalization.pair.hard_row_negative", pair["hardcore_row"].get("negative_prompt"))
|
||||||
|
|
||||||
|
|
||||||
|
def smoke_formatter_input_policy() -> None:
    """Smoke-check ``formatter_input`` and its use by the three formatters.

    Covers JSON row detection, ``Avoid:`` splitting, prompt-field extraction
    and row-value fallback directly, then checks that the Krea, SDXL and
    caption helpers agree with the shared module and that each formatter
    reports a ``source_json`` method for a JSON source row.
    """
    prompt_text = "A simple adult portrait. Setting: quiet studio. Pose: standing calmly. Avoid: low quality."
    source_row = {
        "prompt": prompt_text,
        "caption": "adult portrait, quiet studio",
        "negative_prompt": "low quality",
        "subject_type": "woman",
        "primary_subject": "woman",
        "age": "25-year-old adult",
        "body_phrase": "average figure",
        "skin": "warm skin",
        "hair": "dark hair",
        "eyes": "brown eyes",
        "item": "black dress",
        "scene_text": "quiet studio",
        "pose": "standing calmly",
        "composition": "centered portrait",
        "trigger": Trigger,
    }
    source_json = _json(source_row)

    # Direct checks against the shared parsing module.
    parsed_row, parse_method = formatter_input.row_from_inputs(source_json, "", "auto")
    _expect(parse_method == "source_json", "Formatter input parser should read source JSON when metadata is empty")
    _expect(parsed_row == source_row, "Formatter input parser changed parsed JSON row")

    avoid_parts = formatter_input.split_avoid("Prompt body. Avoid: blur, watermark")
    _expect(avoid_parts == ("Prompt body", "blur, watermark"), "Avoid split changed")

    setting = formatter_input.prompt_field(prompt_text, "Setting")
    _expect(setting == "quiet studio", "Prompt field extraction changed")

    fallback = formatter_input.row_value({"prompt": prompt_text}, "scene_text", ("Setting",))
    _expect(fallback == "quiet studio", "Row value prompt fallback changed")

    # Each formatter's private helper must agree with the shared module.
    messy = "a b , c"
    _expect(krea_formatter._clean(messy) == formatter_input.clean_text(messy), "Krea clean helper is not delegated")
    _expect(sdxl_formatter._clean(messy) == formatter_input.clean_text(messy), "SDXL clean helper is not delegated")
    _expect(caption_naturalizer._clean_text(messy) == formatter_input.clean_text(messy), "Caption clean helper is not delegated")

    krea_stripped = krea_formatter._strip_trigger(f"{Trigger}, prompt text", False)
    _expect(krea_stripped == "prompt text", "Krea trigger stripping changed")
    sdxl_stripped = sdxl_formatter._strip_trigger(f"{SdxlTrigger}, prompt text", False)
    _expect(sdxl_stripped == "prompt text", "SDXL trigger stripping changed")
    caption_stripped = caption_naturalizer._remove_trigger(Trigger, Trigger)
    _expect(caption_stripped == "", "Caption exact-trigger removal changed")

    # End-to-end: every formatter should pick up the row from source JSON.
    krea = krea_formatter.format_krea2_prompt(source_json, input_hint="auto")
    sdxl = sdxl_formatter.format_sdxl_prompt(source_json, input_hint="auto", trigger=SdxlTrigger, prepend_trigger=True)
    caption, caption_method = caption_naturalizer.naturalize_caption(source_json, input_hint="auto", trigger=Trigger)

    krea_method = krea.get("method", "")
    _expect(krea_method.startswith("source_json:krea2("), "Krea formatter did not use shared source JSON parsing")
    sdxl_method = sdxl.get("method", "")
    _expect(sdxl_method.startswith("source_json:sdxl("), "SDXL formatter did not use shared source JSON parsing")
    _expect(caption_method.startswith("source_json:metadata("), "Caption naturalizer did not use shared source JSON parsing")

    _expect_text("formatter_input.krea_prompt", krea.get("krea_prompt"), 20)
    _expect_text("formatter_input.sdxl_prompt", sdxl.get("sdxl_prompt"), 20)
    _expect_text("formatter_input.caption", caption, 20)
|
||||||
|
|
||||||
|
|
||||||
def smoke_hardcore_position_config_policy() -> None:
|
def smoke_hardcore_position_config_policy() -> None:
|
||||||
_expect(
|
_expect(
|
||||||
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
||||||
@@ -2818,6 +2870,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
|||||||
("character_config_policy", smoke_character_config_policy),
|
("character_config_policy", smoke_character_config_policy),
|
||||||
("character_profile_policy", smoke_character_profile_policy),
|
("character_profile_policy", smoke_character_profile_policy),
|
||||||
("row_normalization_policy", smoke_row_normalization_policy),
|
("row_normalization_policy", smoke_row_normalization_policy),
|
||||||
|
("formatter_input_policy", smoke_formatter_input_policy),
|
||||||
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
||||||
("category_library_route", smoke_category_library_route),
|
("category_library_route", smoke_category_library_route),
|
||||||
("hardcore_category_routes", smoke_hardcore_category_routes),
|
("hardcore_category_routes", smoke_hardcore_category_routes),
|
||||||
|
|||||||
Reference in New Issue
Block a user