From 55900e7c43ec85b4649a8fa23832f391eec5ca49 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 3 Apr 2026 11:27:11 +0200 Subject: [PATCH] feat: 8 resolution slots with per-slot seed + node outputs seed - Resolution entries expanded from 6 to 8 fixed slots - Each slot now stores [w, h, seed] (migrates old [w, h] entries to [w, h, 0]) - UI adds seed number input + casino randomize button per row - ProjectResolution node now outputs (width, height, seed) instead of (width, height) Co-Authored-By: Claude Sonnet 4.6 --- project_loader.py | 13 +++++---- tab_batch_ng.py | 56 ++++++++++++++++++++++++------------ tests/test_project_loader.py | 37 ++++++++++++++++-------- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/project_loader.py b/project_loader.py index 9dbaff3..11bfb78 100644 --- a/project_loader.py +++ b/project_loader.py @@ -311,8 +311,8 @@ class ProjectResolution: }, } - RETURN_TYPES = ("INT", "INT") - RETURN_NAMES = ("width", "height") + RETURN_TYPES = ("INT", "INT", "INT") + RETURN_NAMES = ("width", "height", "seed") FUNCTION = "fetch_resolution" CATEGORY = "JSON Manager/project" OUTPUT_NODE = False @@ -332,20 +332,21 @@ class ProjectResolution: data = _fetch_data(manager_url, project_name, file_name, sequence_number) if data.get("error") in ("http_error", "network_error", "parse_error"): logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message")) - return (512, 512) + return (512, 512, 0) series = data.get(key_name) if not isinstance(series, list) or len(series) == 0: logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name) - return (512, 512) + return (512, 512, 0) clamped = max(0, min(index, len(series) - 1)) entry = series[clamped] if not isinstance(entry, (list, tuple)) or len(entry) < 2: logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry) - return (512, 512) + return (512, 512, 0) - return (to_int(entry[0]), to_int(entry[1])) + seed = to_int(entry[2]) if len(entry) >= 3 else 0 + return (to_int(entry[0]), to_int(entry[1]), seed) # --- Mappings --- diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 7787efa..647d470 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -553,34 +553,54 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, dict_textarea('Specific Negative', seq, 'negative').classes( 'w-full q-mt-sm').props('outlined rows=2') - # --- Resolutions (6 fixed slots) --- + # --- Resolutions (8 fixed slots) --- ui.label('Resolutions').classes('text-caption text-weight-bold q-mt-md') - if 'resolutions' not in seq or len(seq.get('resolutions', [])) < 6: - resolutions = seq.setdefault('resolutions', []) - while len(resolutions) < 6: - resolutions.append([512, 512]) + resolutions = seq.setdefault('resolutions', []) + changed = False + while len(resolutions) < 8: + resolutions.append([512, 512, 0]) + changed = True + # Migrate old [w, h] entries to [w, h, seed] + for i, entry in enumerate(resolutions): + if len(entry) < 3: + resolutions[i] = list(entry) + [0] + changed = True + if changed: commit() - resolutions = seq['resolutions'] - for idx in range(6): + for idx in range(8): entry = resolutions[idx] - with ui.row().classes('items-center w-full q-mt-xs'): - ui.label(str(idx)).classes('text-caption').style('min-width:20px') - w_inp = ui.number(value=int(entry[0]), min=1, step=1, label='W').classes( - 'col').props('outlined dense hide-bottom-space') - h_inp = ui.number(value=int(entry[1]), min=1, step=1, label='H').classes( - 'col').props('outlined dense hide-bottom-space') + with ui.row().classes('items-center w-full q-mt-xs no-wrap'): + ui.label(str(idx)).classes('text-caption').style('min-width:16px') + w_inp = ui.number(value=int(entry[0]), min=1, step=1, label='W').style( + 'width:70px').props('outlined dense hide-bottom-space') + h_inp = ui.number(value=int(entry[1]), min=1, step=1, label='H').style( + 'width:70px').props('outlined dense hide-bottom-space') + seed_inp = ui.number(value=int(entry[2]), min=0, step=1, label='Seed').style( + 'flex:1; min-width:60px').props('outlined dense hide-bottom-space') - def _sync_wh(i=idx, wi=w_inp, hi=h_inp): + def _sync_entry(i=idx, wi=w_inp, hi=h_inp, si=seed_inp): seq['resolutions'][i] = [ int(wi.value) if wi.value else 512, int(hi.value) if hi.value else 512, + int(si.value) if si.value else 0, ] commit() - w_inp.on('blur', lambda _, s=_sync_wh: s()) - w_inp.on('update:model-value', lambda _, s=_sync_wh: s()) - h_inp.on('blur', lambda _, s=_sync_wh: s()) - h_inp.on('update:model-value', lambda _, s=_sync_wh: s()) + def _randomize(si=seed_inp, i=idx): + import random + si.value = random.randint(0, 2**32 - 1) + seq['resolutions'][i][2] = int(si.value) + commit() + + ui.button(icon='casino', on_click=_randomize).props( + 'flat dense round').classes('q-ml-xs') + + w_inp.on('blur', lambda _, s=_sync_entry: s()) + w_inp.on('update:model-value', lambda _, s=_sync_entry: s()) + h_inp.on('blur', lambda _, s=_sync_entry: s()) + h_inp.on('update:model-value', lambda _, s=_sync_entry: s()) + seed_inp.on('blur', lambda _, s=_sync_entry: s()) + seed_inp.on('update:model-value', lambda _, s=_sync_entry: s()) with splitter.after: # Mode diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 09897d6..9404d1c 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -353,46 +353,59 @@ class TestProjectResolution: assert "index" in inputs["required"] assert inputs["required"]["index"][0] == "INT" - def test_two_outputs(self): + def test_three_outputs(self): from project_loader import ProjectResolution - assert ProjectResolution.RETURN_TYPES == ("INT", "INT") - assert ProjectResolution.RETURN_NAMES == ("width", "height") + assert ProjectResolution.RETURN_TYPES == ("INT", "INT", "INT") + assert ProjectResolution.RETURN_NAMES == ("width", "height", "seed") def test_fetch_resolution_basic(self): from project_loader import ProjectResolution node = ProjectResolution() - data = {"resolutions": [[512, 512], [768, 1344], [1344, 768]]} + data = {"resolutions": [[512, 512, 0], [768, 1344, 12345], [1344, 768, 99]]} with patch("project_loader._fetch_data", return_value=data): result = node.fetch_resolution( source_label="src", key_name="resolutions", index=1, manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (768, 1344) + assert result == (768, 1344, 12345) def test_fetch_resolution_index_zero(self): from project_loader import ProjectResolution node = ProjectResolution() - data = {"resolutions": [[512, 512], [1024, 1024]]} + data = {"resolutions": [[512, 512, 42], [1024, 1024, 0]]} with patch("project_loader._fetch_data", return_value=data): result = node.fetch_resolution( source_label="src", key_name="resolutions", index=0, manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (512, 512) + assert result == (512, 512, 42) def test_fetch_resolution_clamps_on_out_of_bounds(self): from project_loader import ProjectResolution node = ProjectResolution() - data = {"resolutions": [[512, 512], [1024, 1024]]} + data = {"resolutions": [[512, 512, 0], [1024, 1024, 7]]} with patch("project_loader._fetch_data", return_value=data): result = node.fetch_resolution( source_label="src", key_name="resolutions", index=99, manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (1024, 1024) # last entry + assert result == (1024, 1024, 7) # last entry + + def test_fetch_resolution_old_format_no_seed(self): + """Old [w, h] entries without seed should return seed=0.""" + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[576, 384], [960, 640]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (576, 384, 0) def test_fetch_resolution_missing_key_returns_defaults(self): from project_loader import ProjectResolution @@ -403,7 +416,7 @@ class TestProjectResolution: manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (512, 512) + assert result == (512, 512, 0) def test_fetch_resolution_network_error_returns_defaults(self): from project_loader import ProjectResolution @@ -415,7 +428,7 @@ class TestProjectResolution: manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (512, 512) + assert result == (512, 512, 0) def test_fetch_resolution_malformed_entry_returns_defaults(self): from project_loader import ProjectResolution @@ -427,7 +440,7 @@ class TestProjectResolution: manager_url="http://localhost:8080", project_name="p", file_name="f", sequence_number=1, ) - assert result == (512, 512) + assert result == (512, 512, 0) def test_category(self): from project_loader import ProjectResolution