Extract row camera policy
This commit is contained in:
+21
-80
@@ -40,7 +40,7 @@ try:
|
||||
from . import pair_rows
|
||||
from . import pair_options
|
||||
from . import row_normalization as row_policy
|
||||
from . import scene_camera_adapters
|
||||
from . import row_camera as row_camera_policy
|
||||
from . import seed_config as seed_policy
|
||||
from .hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -81,7 +81,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import pair_rows
|
||||
import pair_options
|
||||
import row_normalization as row_policy
|
||||
import scene_camera_adapters
|
||||
import row_camera as row_camera_policy
|
||||
import seed_config as seed_policy
|
||||
from hardcore_text_cleanup import (
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
@@ -1699,42 +1699,19 @@ def _camera_directive(camera_config: str | dict[str, Any] | None) -> tuple[str,
|
||||
|
||||
|
||||
def _insert_positive_directive(prompt: str, directive: str) -> str:
|
||||
marker = " Avoid:"
|
||||
if marker in prompt:
|
||||
before, after = prompt.split(marker, 1)
|
||||
return f"{before.rstrip()} {directive}{marker}{after}"
|
||||
return f"{prompt.rstrip()} {directive}"
|
||||
return row_camera_policy.insert_positive_directive(prompt, directive)
|
||||
|
||||
|
||||
def _camera_caption_text(parsed: dict[str, Any]) -> str:
|
||||
return camera_policy.camera_caption_text(parsed)
|
||||
return row_camera_policy.camera_caption_text(parsed)
|
||||
|
||||
|
||||
def _coworking_composition_prompt(scene_text: Any, composition: Any, subject_kind: str = "subjects") -> str:
|
||||
return scene_camera_adapters.coworking_composition_prompt(scene_text, composition, subject_kind)
|
||||
return row_camera_policy.coworking_composition_prompt(scene_text, composition, subject_kind)
|
||||
|
||||
|
||||
def _apply_coworking_composition(row: dict[str, Any], subject_kind: str) -> dict[str, Any]:
|
||||
scene_text = row.get("scene_text") or row.get("source_scene_text") or row.get("scene")
|
||||
old_composition = str(row.get("composition") or "").strip()
|
||||
new_composition = _coworking_composition_prompt(scene_text, old_composition, subject_kind)
|
||||
if not old_composition or new_composition == old_composition:
|
||||
return row
|
||||
row["source_composition"] = row.get("source_composition") or old_composition
|
||||
row["composition"] = new_composition
|
||||
row["composition_prompt"] = _composition_prompt(new_composition)
|
||||
prompt = str(row.get("prompt") or "")
|
||||
replacements = (
|
||||
(f"Composition: vertical {old_composition}.", f"Composition: {_composition_prompt(new_composition)}."),
|
||||
(f"Composition: {old_composition}.", f"Composition: {_composition_prompt(new_composition)}."),
|
||||
(f"Framed as {old_composition}.", f"Framed as {new_composition}."),
|
||||
)
|
||||
for old_fragment, new_fragment in replacements:
|
||||
if old_fragment in prompt:
|
||||
row["prompt"] = prompt.replace(old_fragment, new_fragment)
|
||||
break
|
||||
row["caption"] = str(row.get("caption") or "").replace(f", {old_composition},", f", {new_composition},")
|
||||
return row
|
||||
return row_camera_policy.apply_contextual_composition(row, subject_kind)
|
||||
|
||||
|
||||
def _camera_scene_directive_for_context(
|
||||
@@ -1744,10 +1721,10 @@ def _camera_scene_directive_for_context(
|
||||
pov_labels: list[str] | None = None,
|
||||
subject_kind: str = "subjects",
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
parsed = _parse_camera_config(camera_config)
|
||||
directive = scene_camera_adapters.camera_scene_directive_for_context(
|
||||
directive, parsed = row_camera_policy.camera_scene_directive_for_context(
|
||||
scene_text,
|
||||
parsed,
|
||||
composition,
|
||||
camera_config,
|
||||
pov_labels,
|
||||
subject_kind,
|
||||
CAMERA_COMPACT_LABELS,
|
||||
@@ -1756,53 +1733,23 @@ def _camera_scene_directive_for_context(
|
||||
|
||||
|
||||
def _row_camera_subject_kind(row: dict[str, Any]) -> str:
|
||||
subject_type = str(row.get("subject_type") or row.get("primary_subject") or "").lower()
|
||||
if subject_type in ("woman", "adult woman") or subject_type == "single_any":
|
||||
return "woman"
|
||||
if subject_type in ("man", "adult man"):
|
||||
return "man"
|
||||
try:
|
||||
women_count = int(row.get("women_count") or 0)
|
||||
men_count = int(row.get("men_count") or 0)
|
||||
except (TypeError, ValueError):
|
||||
women_count = men_count = 0
|
||||
if women_count == 1 and men_count == 0:
|
||||
return "woman"
|
||||
if women_count == 0 and men_count == 1:
|
||||
return "man"
|
||||
if women_count + men_count == 2:
|
||||
return "couple"
|
||||
return "subjects"
|
||||
return row_camera_policy.row_camera_subject_kind(row)
|
||||
|
||||
|
||||
def _apply_camera_config(row: dict[str, Any], camera_config: str | dict[str, Any] | None) -> dict[str, Any]:
|
||||
directive, parsed = _camera_directive(camera_config)
|
||||
pov_labels = _pov_character_labels(
|
||||
def _camera_pov_labels_for_row(row: dict[str, Any]) -> list[str]:
|
||||
return _pov_character_labels(
|
||||
_character_slot_label_map(_parse_character_cast(row.get("character_cast_slots"))),
|
||||
int(row.get("men_count") or 0) if str(row.get("men_count") or "").isdigit() else 0,
|
||||
)
|
||||
if not pov_labels:
|
||||
pov_labels = [str(label) for label in _list_from(row.get("pov_character_labels")) if str(label).strip()]
|
||||
subject_kind = _row_camera_subject_kind(row)
|
||||
row = _apply_coworking_composition(row, subject_kind)
|
||||
scene_directive, parsed = _camera_scene_directive_for_context(
|
||||
row.get("scene_text") or row.get("source_scene_text") or row.get("scene"),
|
||||
row.get("composition") or row.get("source_composition"),
|
||||
parsed,
|
||||
pov_labels,
|
||||
subject_kind,
|
||||
|
||||
|
||||
def _apply_camera_config(row: dict[str, Any], camera_config: str | dict[str, Any] | None) -> dict[str, Any]:
|
||||
return row_camera_policy.apply_camera_config(
|
||||
row,
|
||||
camera_config,
|
||||
pov_label_resolver=_camera_pov_labels_for_row,
|
||||
compact_labels=CAMERA_COMPACT_LABELS,
|
||||
)
|
||||
row["camera_config"] = parsed
|
||||
row["camera_scene_directive"] = scene_directive
|
||||
row["camera_directive"] = "" if pov_labels else directive
|
||||
combined_directive = " ".join(part for part in (scene_directive, row["camera_directive"]) if part)
|
||||
if not combined_directive:
|
||||
return row
|
||||
row["prompt"] = _insert_positive_directive(row["prompt"], combined_directive)
|
||||
camera_caption = _camera_caption_text(parsed)
|
||||
if camera_caption and not pov_labels:
|
||||
row["caption"] = f"{row.get('caption', '').rstrip()}, {camera_caption}"
|
||||
return row
|
||||
|
||||
|
||||
def _row_seed(seed: int, row_number: int, salt: int = 0) -> int:
|
||||
@@ -3168,13 +3115,7 @@ def _apply_character_profile_to_context(
|
||||
|
||||
|
||||
def _composition_prompt(composition: str) -> str:
|
||||
composition = str(composition or "").strip()
|
||||
if not composition:
|
||||
return composition
|
||||
lower = composition.lower()
|
||||
if lower.startswith("vertical ") or " vertical " in lower or lower.endswith(" vertical"):
|
||||
return composition
|
||||
return f"vertical {composition}"
|
||||
return row_camera_policy.composition_prompt(composition)
|
||||
|
||||
|
||||
def _appearance_for_subject(
|
||||
|
||||
Reference in New Issue
Block a user