Support item template route metadata
This commit is contained in:
+79
-16
@@ -45,7 +45,7 @@ try:
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors,
|
||||
)
|
||||
from .hardcore_action_metadata import source_hardcore_action_family
|
||||
from .hardcore_action_metadata import normalize_hardcore_action_family, source_hardcore_action_family
|
||||
from .hardcore_role_graphs import build_hardcore_role_graph
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
from category_library import (
|
||||
@@ -85,7 +85,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors,
|
||||
)
|
||||
from hardcore_action_metadata import source_hardcore_action_family
|
||||
from hardcore_action_metadata import normalize_hardcore_action_family, source_hardcore_action_family
|
||||
from hardcore_role_graphs import build_hardcore_role_graph
|
||||
|
||||
|
||||
@@ -300,6 +300,53 @@ def _item_name(item: Any) -> str:
|
||||
return _item_text(item)
|
||||
|
||||
|
||||
def _template_metadata(item: Any) -> dict[str, Any]:
|
||||
if not isinstance(item, dict):
|
||||
return {}
|
||||
metadata: dict[str, Any] = {}
|
||||
for key in (
|
||||
"action_family",
|
||||
"action_type",
|
||||
"family",
|
||||
"position_family",
|
||||
"position_key",
|
||||
"position_keys",
|
||||
"formatter_hint",
|
||||
):
|
||||
if key in item:
|
||||
metadata[key] = item[key]
|
||||
return metadata
|
||||
|
||||
|
||||
def _template_position_family(metadata: dict[str, Any]) -> str:
|
||||
return _normalize_hardcore_position_family(
|
||||
metadata.get("position_family") or metadata.get("family"),
|
||||
"",
|
||||
)
|
||||
|
||||
|
||||
def _template_position_keys(metadata: dict[str, Any]) -> list[str]:
|
||||
keys: list[Any] = []
|
||||
if metadata.get("position_keys") is not None:
|
||||
raw_keys = metadata.get("position_keys")
|
||||
keys.extend(raw_keys if isinstance(raw_keys, list) else [raw_keys])
|
||||
if metadata.get("position_key") is not None:
|
||||
keys.append(metadata.get("position_key"))
|
||||
return _normalize_hardcore_position_values(keys)
|
||||
|
||||
|
||||
def _template_action_family(metadata: dict[str, Any]) -> str:
|
||||
return normalize_hardcore_action_family(metadata.get("action_family") or metadata.get("action_type"), "")
|
||||
|
||||
|
||||
def _merge_position_keys(primary: list[str], fallback: list[str]) -> list[str]:
|
||||
merged: list[str] = []
|
||||
for key in [*primary, *fallback]:
|
||||
if key and key not in merged:
|
||||
merged.append(key)
|
||||
return merged
|
||||
|
||||
|
||||
def _oral_acts_for_position(values: list[Any], position: str) -> list[Any]:
|
||||
position_text = str(position or "").lower()
|
||||
if not position_text:
|
||||
@@ -503,11 +550,12 @@ def _compose_item(
|
||||
item: Any,
|
||||
women_count: int = 1,
|
||||
men_count: int = 1,
|
||||
) -> tuple[str, str, dict[str, str]]:
|
||||
) -> tuple[str, str, dict[str, str], dict[str, Any]]:
|
||||
templates = _template_list(category, subcategory, item, "item_templates")
|
||||
axes = _merged_axes(category, subcategory, item)
|
||||
if templates and axes:
|
||||
template = _entry_text(_weighted_choice(rng, _compatible_entries(templates, women_count, men_count)))
|
||||
template_entry = _weighted_choice(rng, _compatible_entries(templates, women_count, men_count))
|
||||
template = _entry_text(template_entry)
|
||||
fields = [key for _, key, _, _ in Formatter().parse(template) if key]
|
||||
unique_fields = list(dict.fromkeys(fields))
|
||||
axis_values: dict[str, str] = {}
|
||||
@@ -535,8 +583,8 @@ def _compose_item(
|
||||
axis_values[name] = _entry_text(_weighted_choice(rng, values))
|
||||
item_text = _format(template, axis_values).strip()
|
||||
item_name = _item_name(item) or subcategory["name"]
|
||||
return item_text, item_name, axis_values
|
||||
return _item_text(item), _item_name(item), {}
|
||||
return item_text, item_name, axis_values, _template_metadata(template_entry)
|
||||
return _item_text(item), _item_name(item), {}, _template_metadata(item)
|
||||
|
||||
|
||||
def _choose_text(rng: random.Random, items: list[Any]) -> str:
|
||||
@@ -3722,7 +3770,14 @@ def _build_custom_row(
|
||||
content_rng = _axis_rng(seed_config, content_axis, seed, row_number)
|
||||
items = _list_from(subcategory.get("items", [subcategory["name"]]))
|
||||
item = _weighted_choice(content_rng, items)
|
||||
item_text, item_name, item_axis_values = _compose_item(content_rng, category, subcategory, item, women_count, men_count)
|
||||
item_text, item_name, item_axis_values, item_template_metadata = _compose_item(
|
||||
content_rng,
|
||||
category,
|
||||
subcategory,
|
||||
item,
|
||||
women_count,
|
||||
men_count,
|
||||
)
|
||||
is_pose_category = _is_pose_content_category(category, subcategory)
|
||||
if is_pose_category:
|
||||
item_text = _sanitize_hardcore_environment_anchors(item_text)
|
||||
@@ -3868,22 +3923,29 @@ def _build_custom_row(
|
||||
position_key = ""
|
||||
action_family = ""
|
||||
if is_pose_category:
|
||||
position_family = _hardcore_source_position_family(subcategory, parsed_hardcore_position_config)
|
||||
position_keys = _hardcore_position_keys(
|
||||
template_position_family = _template_position_family(item_template_metadata)
|
||||
position_family = template_position_family or _hardcore_source_position_family(
|
||||
subcategory,
|
||||
parsed_hardcore_position_config,
|
||||
)
|
||||
inferred_position_keys = _hardcore_position_keys(
|
||||
item_text,
|
||||
source_role_graph,
|
||||
source_composition,
|
||||
pose,
|
||||
axis_values=item_axis_values,
|
||||
)
|
||||
position_keys = _merge_position_keys(_template_position_keys(item_template_metadata), inferred_position_keys)
|
||||
position_key = position_keys[0] if position_keys else ""
|
||||
action_family = source_hardcore_action_family(
|
||||
position_family,
|
||||
source_role_graph,
|
||||
item_text,
|
||||
source_composition,
|
||||
item_axis_values,
|
||||
)
|
||||
action_family = _template_action_family(item_template_metadata)
|
||||
if not action_family:
|
||||
action_family = source_hardcore_action_family(
|
||||
position_family,
|
||||
source_role_graph,
|
||||
item_text,
|
||||
source_composition,
|
||||
item_axis_values,
|
||||
)
|
||||
|
||||
negative_prompt = str(_merged_field(category, subcategory, item, "negative_prompt", g.NEGATIVE_PROMPT))
|
||||
positive_suffix = str(_merged_field(category, subcategory, item, "positive_suffix", GENERIC_POSITIVE_SUFFIX))
|
||||
@@ -3991,6 +4053,7 @@ def _build_custom_row(
|
||||
"positive_suffix": positive_suffix,
|
||||
"custom_item": item_name,
|
||||
"item_axis_values": item_axis_values,
|
||||
"item_template_metadata": item_template_metadata,
|
||||
"scene_text": scene,
|
||||
"location_config": parsed_location_config if _location_config_active(parsed_location_config) else {},
|
||||
"pose": pose,
|
||||
|
||||
Reference in New Issue
Block a user