Extract formatter input parsing policy
This commit is contained in:
+13
-54
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
from .hardcore_action_metadata import normalize_hardcore_action_family
|
||||
from .prompt_hygiene import sanitize_prose_text
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
from hardcore_action_metadata import normalize_hardcore_action_family
|
||||
from prompt_hygiene import sanitize_prose_text
|
||||
|
||||
@@ -71,11 +72,7 @@ POSITION_FAMILY_CAPTION_LABELS = {
|
||||
|
||||
|
||||
def _clean_text(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
||||
return text
|
||||
return input_policy.clean_text(value)
|
||||
|
||||
|
||||
def _is_false(value: Any) -> bool:
|
||||
@@ -189,18 +186,11 @@ def _strip_style_tail(text: str) -> str:
|
||||
|
||||
|
||||
def _remove_trigger(text: str, trigger: str) -> str:
|
||||
text = _clean_text(text).strip(" ,")
|
||||
for candidate in (trigger, OLD_TRIGGER, DEFAULT_TRIGGER):
|
||||
candidate = candidate.strip()
|
||||
if not candidate:
|
||||
continue
|
||||
if text.lower().startswith(candidate.lower() + ","):
|
||||
return text[len(candidate) + 1 :].strip(" ,")
|
||||
if text.lower().startswith(candidate.lower() + "."):
|
||||
return text[len(candidate) + 1 :].strip(" ,")
|
||||
if text.lower() == candidate.lower():
|
||||
return ""
|
||||
return text
|
||||
return input_policy.strip_trigger_prefix(
|
||||
text,
|
||||
(trigger, OLD_TRIGGER, DEFAULT_TRIGGER),
|
||||
remove_exact=True,
|
||||
)
|
||||
|
||||
|
||||
def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
|
||||
@@ -214,55 +204,24 @@ def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
|
||||
|
||||
|
||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||
text = _clean_text(text)
|
||||
if not text or not text.startswith("{"):
|
||||
return None
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return value if isinstance(value, dict) else None
|
||||
return input_policy.maybe_json(text)
|
||||
|
||||
|
||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||
candidates: list[tuple[str, str]] = []
|
||||
if input_hint in ("auto", "metadata_json"):
|
||||
candidates.append((metadata_json, "metadata_json"))
|
||||
candidates.append((source_text, "source_json"))
|
||||
for text, method in candidates:
|
||||
row = _maybe_json(text)
|
||||
if row is not None:
|
||||
return row, method
|
||||
return None, "text"
|
||||
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||
|
||||
|
||||
def _prompt_field(text: str, label: str) -> str:
|
||||
text = _clean_text(text)
|
||||
if not text:
|
||||
return ""
|
||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
return ""
|
||||
return _clean_text(match.group(1)).rstrip(".")
|
||||
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||
value = _clean_text(row.get(key, ""))
|
||||
if value:
|
||||
return value
|
||||
prompt = _clean_text(row.get("prompt", ""))
|
||||
for label in labels:
|
||||
value = _prompt_field(prompt, label)
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _field_from_any_prompt(text: str, labels: tuple[str, ...]) -> str:
|
||||
for label in labels:
|
||||
value = _prompt_field(text, label)
|
||||
value = input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
|
||||
@@ -62,6 +62,23 @@ route-specific owner. It also preserves ordinary words such as `composition`
|
||||
inside normal sentences; empty field-label cleanup is limited to standalone
|
||||
labels.
|
||||
|
||||
Formatter input/fallback parsing now has one home:
|
||||
|
||||
- `formatter_input.py`
|
||||
|
||||
It owns route-neutral parsing shared by Krea2, SDXL, and natural-caption
|
||||
routes:
|
||||
|
||||
- whitespace and punctuation normalization before formatter parsing;
|
||||
- JSON row detection from `metadata_json` or source text;
|
||||
- trigger-prefix stripping with route-specific trigger candidate lists;
|
||||
- `Avoid:` positive/negative splitting for fallback text;
|
||||
- prompt field extraction such as `Setting:` or `Composition:`;
|
||||
- row-value fallback from metadata fields to labeled prompt text.
|
||||
|
||||
It must not make formatter-style decisions. Krea prose, SDXL tags, and training
|
||||
caption sentence shape stay in their formatter modules.
|
||||
|
||||
Shared hardcore phrase cleanup now has one home:
|
||||
|
||||
- `hardcore_text_cleanup.py`
|
||||
@@ -242,6 +259,9 @@ Already isolated:
|
||||
- `krea_pov_actions.py` owns POV hardcore action sentence rewriting,
|
||||
first-person body geometry, and selected-position-axis priority before loose
|
||||
context fallback.
|
||||
- `formatter_input.py` owns shared metadata/source JSON detection, trigger
|
||||
stripping, prompt-field extraction, `Avoid:` splitting, and row-value
|
||||
fallback for Krea, SDXL, and caption routes.
|
||||
|
||||
Improve later:
|
||||
|
||||
@@ -262,6 +282,7 @@ Keep here:
|
||||
- negative-prompt assembly.
|
||||
- metadata-family tag hints from `action_family`, `position_family`, and
|
||||
`position_keys`.
|
||||
- shared formatter input parsing from `formatter_input.py`.
|
||||
|
||||
Improve later:
|
||||
|
||||
@@ -280,6 +301,7 @@ Keep here:
|
||||
- training-caption trigger behavior;
|
||||
- style-tail policy.
|
||||
- metadata-family action labels from `action_family` and `position_family`.
|
||||
- shared formatter input parsing from `formatter_input.py`.
|
||||
|
||||
Improve later:
|
||||
|
||||
|
||||
@@ -94,6 +94,7 @@ Core helper ownership:
|
||||
| `scene_camera_adapters.py` | Location-aware camera/scene prose such as coworking lounge camera layout. |
|
||||
| `prompt_hygiene.py` | Generic prompt, caption, and negative-prompt cleanup. |
|
||||
| `row_normalization.py` | Final prompt-row and pair metadata normalization: trigger prepending, extra-positive append, negative merge/dedupe, caption-part joining, and embedded soft/hard row sanitation. |
|
||||
| `formatter_input.py` | Shared formatter input parsing: text cleanup, metadata/source JSON detection, trigger-prefix stripping, `Avoid:` splitting, prompt-field extraction, and metadata row-value fallback. |
|
||||
|
||||
## Node IO Map
|
||||
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
DEFAULT_PROMPT_FIELD_LABELS = (
|
||||
"Ages",
|
||||
"Body types",
|
||||
"Cast",
|
||||
"Cast descriptors",
|
||||
"Characters",
|
||||
"Scene",
|
||||
"Setting",
|
||||
"Pose",
|
||||
"Sexual pose",
|
||||
"Sexual scene",
|
||||
"Facial expression",
|
||||
"Facial expressions",
|
||||
"Clothing",
|
||||
"Erotic outfit",
|
||||
"Prop/detail",
|
||||
"Composition",
|
||||
"Role graph",
|
||||
"Camera",
|
||||
"Camera control",
|
||||
"Use",
|
||||
"Avoid",
|
||||
)
|
||||
|
||||
|
||||
def clean_text(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
||||
return text
|
||||
|
||||
|
||||
def maybe_json(text: Any) -> dict[str, Any] | None:
|
||||
text = clean_text(text)
|
||||
if not text.startswith("{"):
|
||||
return None
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return value if isinstance(value, dict) else None
|
||||
|
||||
|
||||
def row_from_inputs(
|
||||
source_text: str,
|
||||
metadata_json: str,
|
||||
input_hint: str,
|
||||
*,
|
||||
metadata_methods: tuple[str, ...] = ("auto", "metadata_json"),
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
if input_hint in metadata_methods:
|
||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
||||
row = maybe_json(text)
|
||||
if row is not None:
|
||||
return row, method
|
||||
return None, "text"
|
||||
|
||||
|
||||
def strip_trigger_prefix(
|
||||
text: Any,
|
||||
trigger_candidates: tuple[str, ...] | list[str],
|
||||
*,
|
||||
preserve_trigger: bool = False,
|
||||
remove_exact: bool = False,
|
||||
) -> str:
|
||||
text = clean_text(text)
|
||||
if remove_exact:
|
||||
text = text.strip(" ,")
|
||||
if preserve_trigger:
|
||||
return text
|
||||
for trigger in trigger_candidates:
|
||||
trigger = clean_text(trigger)
|
||||
if not trigger:
|
||||
continue
|
||||
if text.lower().startswith(trigger.lower() + ","):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
if text.lower().startswith(trigger.lower() + "."):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
if remove_exact and text.lower() == trigger.lower():
|
||||
return ""
|
||||
return text
|
||||
|
||||
|
||||
def split_avoid(text: Any) -> tuple[str, str]:
|
||||
text = clean_text(text)
|
||||
match = re.search(r"\bAvoid:\s*(.*)$", text)
|
||||
if not match:
|
||||
return text, ""
|
||||
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
|
||||
|
||||
|
||||
def prompt_field(
|
||||
text: Any,
|
||||
label: str,
|
||||
*,
|
||||
field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
|
||||
) -> str:
|
||||
text = clean_text(text)
|
||||
if not text:
|
||||
return ""
|
||||
labels = "|".join(re.escape(name) for name in field_labels)
|
||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
return ""
|
||||
return clean_text(match.group(1)).rstrip(".")
|
||||
|
||||
|
||||
def row_value(
|
||||
row: dict[str, Any],
|
||||
key: str,
|
||||
labels: tuple[str, ...] = (),
|
||||
*,
|
||||
field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
|
||||
) -> str:
|
||||
value = clean_text(row.get(key, ""))
|
||||
if value:
|
||||
return value
|
||||
prompt = clean_text(row.get("prompt", ""))
|
||||
for label in labels:
|
||||
value = prompt_field(prompt, label, field_labels=field_labels)
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
+9
-54
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
from .krea_action_context import (
|
||||
is_close_foreplay_text as _is_close_foreplay_text,
|
||||
is_outercourse_text as _is_outercourse_text,
|
||||
@@ -34,6 +34,7 @@ try:
|
||||
from .krea_pov_actions import pov_action_phrase as _pov_action_phrase
|
||||
from .prompt_hygiene import sanitize_negative_text, sanitize_prose_text
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
from krea_action_context import (
|
||||
is_close_foreplay_text as _is_close_foreplay_text,
|
||||
is_outercourse_text as _is_outercourse_text,
|
||||
@@ -91,11 +92,7 @@ PROMPT_FIELD_LABELS = (
|
||||
|
||||
|
||||
def _clean(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
||||
return text
|
||||
return input_policy.clean_text(value)
|
||||
|
||||
|
||||
def _is_false(value: Any) -> bool:
|
||||
@@ -133,69 +130,27 @@ def _with_indefinite_article(text: str) -> str:
|
||||
|
||||
|
||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||
text = _clean(text)
|
||||
if not text.startswith("{"):
|
||||
return None
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return value if isinstance(value, dict) else None
|
||||
return input_policy.maybe_json(text)
|
||||
|
||||
|
||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||
candidates: list[tuple[str, str]] = []
|
||||
if input_hint in ("auto", "metadata_json"):
|
||||
candidates.append((metadata_json, "metadata_json"))
|
||||
candidates.append((source_text, "source_json"))
|
||||
for text, method in candidates:
|
||||
row = _maybe_json(text)
|
||||
if row is not None:
|
||||
return row, method
|
||||
return None, "text"
|
||||
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||
|
||||
|
||||
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
||||
text = _clean(text)
|
||||
if preserve_trigger:
|
||||
return text
|
||||
for trigger in TRIGGER_CANDIDATES:
|
||||
if text.lower().startswith(trigger.lower() + ","):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
if text.lower().startswith(trigger.lower() + "."):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
return text
|
||||
return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger)
|
||||
|
||||
|
||||
def _split_avoid(text: str) -> tuple[str, str]:
|
||||
match = re.search(r"\bAvoid:\s*(.*)$", text)
|
||||
if not match:
|
||||
return text, ""
|
||||
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
|
||||
return input_policy.split_avoid(text)
|
||||
|
||||
|
||||
def _prompt_field(text: str, label: str) -> str:
|
||||
text = _clean(text)
|
||||
if not text:
|
||||
return ""
|
||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
return ""
|
||||
return _clean(match.group(1)).rstrip(".")
|
||||
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||
value = _clean(row.get(key, ""))
|
||||
if value:
|
||||
return value
|
||||
prompt = _clean(row.get("prompt", ""))
|
||||
for label in labels:
|
||||
value = _prompt_field(prompt, label)
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _body_phrase(body: Any, figure_note: Any = "") -> str:
|
||||
|
||||
+9
-51
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
from .hardcore_action_metadata import normalize_hardcore_action_family
|
||||
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
from hardcore_action_metadata import normalize_hardcore_action_family
|
||||
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||
|
||||
@@ -95,74 +96,31 @@ def sdxl_quality_preset_choices() -> list[str]:
|
||||
|
||||
|
||||
def _clean(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
||||
return text
|
||||
return input_policy.clean_text(value)
|
||||
|
||||
|
||||
def _maybe_json(text: str) -> dict[str, Any] | None:
|
||||
text = _clean(text)
|
||||
if not text.startswith("{"):
|
||||
return None
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return value if isinstance(value, dict) else None
|
||||
return input_policy.maybe_json(text)
|
||||
|
||||
|
||||
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
|
||||
if input_hint in ("auto", "metadata_json"):
|
||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
||||
row = _maybe_json(text)
|
||||
if row is not None:
|
||||
return row, method
|
||||
return None, "text"
|
||||
return input_policy.row_from_inputs(source_text, metadata_json, input_hint)
|
||||
|
||||
|
||||
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
|
||||
text = _clean(text)
|
||||
if preserve_trigger:
|
||||
return text
|
||||
for trigger in TRIGGER_CANDIDATES:
|
||||
if text.lower().startswith(trigger.lower() + ","):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
if text.lower().startswith(trigger.lower() + "."):
|
||||
return text[len(trigger) + 1 :].strip(" ,")
|
||||
return text
|
||||
return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger)
|
||||
|
||||
|
||||
def _split_avoid(text: str) -> tuple[str, str]:
|
||||
match = re.search(r"\bAvoid:\s*(.*)$", text)
|
||||
if not match:
|
||||
return text, ""
|
||||
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
|
||||
return input_policy.split_avoid(text)
|
||||
|
||||
|
||||
def _prompt_field(text: str, label: str) -> str:
|
||||
text = _clean(text)
|
||||
if not text:
|
||||
return ""
|
||||
labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS)
|
||||
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
return ""
|
||||
return _clean(match.group(1)).rstrip(".")
|
||||
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
|
||||
value = _clean(row.get(key, ""))
|
||||
if value:
|
||||
return value
|
||||
prompt = _clean(row.get("prompt", ""))
|
||||
for label in labels:
|
||||
value = _prompt_field(prompt, label)
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
|
||||
|
||||
|
||||
def _split_tag_text(text: Any) -> list[str]:
|
||||
|
||||
@@ -29,6 +29,7 @@ import character_profile # noqa: E402
|
||||
import category_cast_config # noqa: E402
|
||||
import category_library # noqa: E402
|
||||
import filter_config # noqa: E402
|
||||
import formatter_input # noqa: E402
|
||||
import hardcore_position_config # noqa: E402
|
||||
import __init__ as sxcp_nodes # noqa: E402
|
||||
import generation_profile_config # noqa: E402
|
||||
@@ -847,6 +848,57 @@ def smoke_row_normalization_policy() -> None:
|
||||
_expect_no_duplicate_comma_items("row_normalization.pair.hard_row_negative", pair["hardcore_row"].get("negative_prompt"))
|
||||
|
||||
|
||||
def smoke_formatter_input_policy() -> None:
|
||||
source_row = {
|
||||
"prompt": "A simple adult portrait. Setting: quiet studio. Pose: standing calmly. Avoid: low quality.",
|
||||
"caption": "adult portrait, quiet studio",
|
||||
"negative_prompt": "low quality",
|
||||
"subject_type": "woman",
|
||||
"primary_subject": "woman",
|
||||
"age": "25-year-old adult",
|
||||
"body_phrase": "average figure",
|
||||
"skin": "warm skin",
|
||||
"hair": "dark hair",
|
||||
"eyes": "brown eyes",
|
||||
"item": "black dress",
|
||||
"scene_text": "quiet studio",
|
||||
"pose": "standing calmly",
|
||||
"composition": "centered portrait",
|
||||
"trigger": Trigger,
|
||||
}
|
||||
source_json = _json(source_row)
|
||||
|
||||
row, method = formatter_input.row_from_inputs(source_json, "", "auto")
|
||||
_expect(method == "source_json", "Formatter input parser should read source JSON when metadata is empty")
|
||||
_expect(row == source_row, "Formatter input parser changed parsed JSON row")
|
||||
_expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed")
|
||||
_expect(
|
||||
formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio",
|
||||
"Prompt field extraction changed",
|
||||
)
|
||||
_expect(
|
||||
formatter_input.row_value({"prompt": source_row["prompt"]}, "scene_text", ("Setting",)) == "quiet studio",
|
||||
"Row value prompt fallback changed",
|
||||
)
|
||||
|
||||
_expect(krea_formatter._clean("a b , c") == formatter_input.clean_text("a b , c"), "Krea clean helper is not delegated")
|
||||
_expect(sdxl_formatter._clean("a b , c") == formatter_input.clean_text("a b , c"), "SDXL clean helper is not delegated")
|
||||
_expect(caption_naturalizer._clean_text("a b , c") == formatter_input.clean_text("a b , c"), "Caption clean helper is not delegated")
|
||||
_expect(krea_formatter._strip_trigger(f"{Trigger}, prompt text", False) == "prompt text", "Krea trigger stripping changed")
|
||||
_expect(sdxl_formatter._strip_trigger(f"{SdxlTrigger}, prompt text", False) == "prompt text", "SDXL trigger stripping changed")
|
||||
_expect(caption_naturalizer._remove_trigger(Trigger, Trigger) == "", "Caption exact-trigger removal changed")
|
||||
|
||||
krea = krea_formatter.format_krea2_prompt(source_json, input_hint="auto")
|
||||
sdxl = sdxl_formatter.format_sdxl_prompt(source_json, input_hint="auto", trigger=SdxlTrigger, prepend_trigger=True)
|
||||
caption, caption_method = caption_naturalizer.naturalize_caption(source_json, input_hint="auto", trigger=Trigger)
|
||||
_expect(krea.get("method", "").startswith("source_json:krea2("), "Krea formatter did not use shared source JSON parsing")
|
||||
_expect(sdxl.get("method", "").startswith("source_json:sdxl("), "SDXL formatter did not use shared source JSON parsing")
|
||||
_expect(caption_method.startswith("source_json:metadata("), "Caption naturalizer did not use shared source JSON parsing")
|
||||
_expect_text("formatter_input.krea_prompt", krea.get("krea_prompt"), 20)
|
||||
_expect_text("formatter_input.sdxl_prompt", sdxl.get("sdxl_prompt"), 20)
|
||||
_expect_text("formatter_input.caption", caption, 20)
|
||||
|
||||
|
||||
def smoke_hardcore_position_config_policy() -> None:
|
||||
_expect(
|
||||
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
||||
@@ -2818,6 +2870,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("character_config_policy", smoke_character_config_policy),
|
||||
("character_profile_policy", smoke_character_profile_policy),
|
||||
("row_normalization_policy", smoke_row_normalization_policy),
|
||||
("formatter_input_policy", smoke_formatter_input_policy),
|
||||
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
||||
("category_library_route", smoke_category_library_route),
|
||||
("hardcore_category_routes", smoke_hardcore_category_routes),
|
||||
|
||||
Reference in New Issue
Block a user