Add atlas refine cue seed workflow

This commit is contained in:
2026-07-01 14:10:23 +02:00
parent 83dfecc55b
commit 5f602db06b
34 changed files with 12162 additions and 18 deletions
+17
View File
@@ -0,0 +1,17 @@
#!/usr/bin/env python3
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) in sys.path:
sys.path.remove(str(ROOT))
sys.path.insert(0, str(ROOT))
from krea2_atlas_refine_manifest import main # noqa: E402
if __name__ == "__main__":
raise SystemExit(main(sys.argv[1:]))
+3856 -1
View File
File diff suppressed because it is too large Load Diff
+31 -2
View File
@@ -26,6 +26,27 @@ DEFAULT_OUT_CHANNEL = "sxcp_eval_out"
DEFAULT_IN_CHANNEL = "sxcp_eval_in"
NEGATIVE_OUT_CHANNEL = "sxcp_eval_negative_out"
PROMPT_ORDERS = {"subject_first", "geometry_only", "prompt_order_test"}
PROBE_METADATA_FIELDS = (
"variant_key",
"source_entry_id",
"source_stem",
"cue_axes",
"seed_metadata",
"evidence",
"matrix_evidence",
"selection",
"prompt_source",
"reference_images",
"notes",
)
BATCH_METADATA_FIELDS = (
"subject_id",
"variant_key",
"source_entry_id",
"source_stem",
"source_prompt_sha256",
"selection",
)
class BatchError(ValueError):
@@ -80,7 +101,11 @@ def _validate_probe(raw: Any, index: int) -> dict[str, str]:
if not text:
raise BatchError(f"probes[{index}].text is required")
_validate_no_negative_channel(text, field=f"probes[{index}].text")
return {"id": probe_id, "prompt_order": prompt_order, "text": text}
probe: dict[str, Any] = {"id": probe_id, "prompt_order": prompt_order, "text": text}
for field in PROBE_METADATA_FIELDS:
if field in raw:
probe[field] = raw[field]
return probe
def _validate_image_path(value: Any, *, field: str) -> str:
@@ -111,12 +136,16 @@ def load_batch(path: Path) -> dict[str, Any]:
if not isinstance(probes_raw, list) or not probes_raw:
raise BatchError("probes must be a non-empty list")
probes = [_validate_probe(raw, index) for index, raw in enumerate(probes_raw)]
return {
loaded = {
"seed": seed,
"channel_out": channel_out,
"channel_in": channel_in,
"probes": probes,
}
for field in BATCH_METADATA_FIELDS:
if field in batch:
loaded[field] = batch[field]
return loaded
def load_results(path: Path) -> dict[str, Any]: