Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1bbff3b1df |
@@ -1,3 +0,0 @@
|
|||||||
__pycache__/
|
|
||||||
.pytest_cache/
|
|
||||||
.worktrees/
|
|
||||||
@@ -1,336 +1,121 @@
|
|||||||
<p align="center">
|
# 🎛️ AI Settings Manager for ComfyUI
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="480" height="100" viewBox="0 0 480 100">
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="bg" x1="0%" y1="0%" x2="100%" y2="100%">
|
|
||||||
<stop offset="0%" style="stop-color:#1a1a2e;stop-opacity:1" />
|
|
||||||
<stop offset="100%" style="stop-color:#16213e;stop-opacity:1" />
|
|
||||||
</linearGradient>
|
|
||||||
<linearGradient id="accent" x1="0%" y1="0%" x2="100%" y2="0%">
|
|
||||||
<stop offset="0%" style="stop-color:#e94560" />
|
|
||||||
<stop offset="100%" style="stop-color:#0f3460" />
|
|
||||||
</linearGradient>
|
|
||||||
</defs>
|
|
||||||
<rect width="480" height="100" rx="16" fill="url(#bg)" />
|
|
||||||
<rect x="20" y="72" width="440" height="3" rx="1.5" fill="url(#accent)" opacity="0.6" />
|
|
||||||
<text x="240" y="36" text-anchor="middle" fill="#e94560" font-family="monospace" font-size="13" font-weight="bold">{ JSON }</text>
|
|
||||||
<text x="240" y="60" text-anchor="middle" fill="#eee" font-family="sans-serif" font-size="22" font-weight="bold">ComfyUI JSON Manager</text>
|
|
||||||
<text x="240" y="90" text-anchor="middle" fill="#888" font-family="sans-serif" font-size="11">Visual dashboard & dynamic nodes for AI video workflows</text>
|
|
||||||
</svg>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
<p align="center">
|
A 100% vibecoded, visual dashboard for managing, versioning, and batch-processing JSON configuration files used in AI video generation workflows (I2V, VACE).
|
||||||
<img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License" />
|
|
||||||
<img src="https://img.shields.io/badge/Python-3.10%2B-green" alt="Python" />
|
|
||||||
<img src="https://img.shields.io/badge/Built%20with-NiceGUI-FF4B4B" alt="NiceGUI" />
|
|
||||||
<img src="https://img.shields.io/badge/ComfyUI-Custom%20Nodes-purple" alt="ComfyUI" />
|
|
||||||
</p>
|
|
||||||
|
|
||||||
A visual dashboard for managing, versioning, and batch-processing JSON configuration files used in AI video generation workflows (I2V, VACE). Two parts:
|
This tool consists of two parts:
|
||||||
|
1. **Streamlit Web Interface:** A Dockerized editor to manage prompts, LoRAs, settings, and **branching history**.
|
||||||
|
2. **ComfyUI Custom Nodes:** A set of nodes to read these JSON files (including custom keys) directly into your workflows.
|
||||||
|
|
||||||
1. **NiceGUI Web Interface** — Dockerized editor for prompts, LoRAs, settings, and branching history
|
  
|
||||||
2. **ComfyUI Custom Nodes** — Read JSON files directly into workflows, including a dynamic node that auto-discovers keys
|
---
|
||||||
|
|
||||||
|
## ✨ Features
|
||||||
|
|
||||||
|
### 📝 Single File Editor
|
||||||
|
* **Visual Interface:** Edit Prompts, Negative Prompts, Seeds, LoRAs, and advanced settings (Camera, FLF, VACE params) without touching raw JSON.
|
||||||
|
* **🔧 Custom Parameters:** Add arbitrary key-value pairs (e.g., `controlnet_strength`, `my_custom_value`) that persist and can be read by ComfyUI.
|
||||||
|
* **Conflict Protection:** Prevents accidental overwrites if the file is modified by another tab or process.
|
||||||
|
* **Snippet Library:** Save reusable prompt fragments (e.g., "Cinematic Lighting", "Anime Style") and append them with one click.
|
||||||
|
|
||||||
|
### 🚀 Batch Processor
|
||||||
|
* **Sequence Management:** Create unlimited sequences within a single JSON file.
|
||||||
|
* **Smart Import:** Copy settings from **any other file** or **history entry** into your current batch sequence.
|
||||||
|
* **Custom Keys per Shot:** Define unique parameters for specific shots in a batch (e.g., Shot 1 has `fog: 0.5`, Shot 2 has `fog: 0.0`).
|
||||||
|
* **Promote to Single:** One-click convert a specific batch sequence back into a standalone Single File.
|
||||||
|
|
||||||
|
### 🕒 Visual Timeline (New!)
|
||||||
|
* **Git-Style Branching:** A dedicated tab visualizes your edit history as a **horizontal node graph**.
|
||||||
|
* **Non-Destructive:** If you jump back to an old version and make changes, the system automatically **forks a new branch** so you never lose history.
|
||||||
|
* **Visual Diff:** Inspect any past version and see a "Delta View" highlighting exactly what changed (e.g., `Seed: 100 -> 555`) compared to your current state.
|
||||||
|
* **Interactive Mode (WIP):** A zoomed-out, interactive canvas to explore complex history trees.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Features
|
## 🛠️ Installation
|
||||||
|
|
||||||
<table>
|
### 1. Unraid / Docker Setup (The Manager)
|
||||||
<tr>
|
This tool is designed to run as a lightweight container on Unraid.
|
||||||
<td width="50%">
|
|
||||||
|
|
||||||
<h3>
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 20 20"><rect width="20" height="20" rx="4" fill="#e94560"/><text x="10" y="14" text-anchor="middle" fill="#fff" font-size="11">B</text></svg>
|
|
||||||
Batch Processor
|
|
||||||
</h3>
|
|
||||||
|
|
||||||
- Unlimited sequences within a single JSON file
|
|
||||||
- Import settings from any file or history entry
|
|
||||||
- Per-shot custom keys (e.g. Shot 1: `fog: 0.5`, Shot 2: `fog: 0.0`)
|
|
||||||
- Clone, reorder, and manage sequences visually
|
|
||||||
- Conflict protection against external file modifications
|
|
||||||
- Snippet library for reusable prompt fragments
|
|
||||||
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td width="50%">
|
|
||||||
|
|
||||||
<h3>
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 20 20"><rect width="20" height="20" rx="4" fill="#533483"/><text x="10" y="14" text-anchor="middle" fill="#fff" font-size="11">T</text></svg>
|
|
||||||
Visual Timeline
|
|
||||||
</h3>
|
|
||||||
|
|
||||||
- Git-style branching with horizontal node graph
|
|
||||||
- Non-destructive: forking on old-version edits preserves all history
|
|
||||||
- Visual diff highlighting changes between any two versions
|
|
||||||
- Restore any past state with one click
|
|
||||||
|
|
||||||
</td>
|
|
||||||
<td width="50%">
|
|
||||||
|
|
||||||
<h3>
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 20 20"><rect width="20" height="20" rx="4" fill="#2b9348"/><text x="10" y="14" text-anchor="middle" fill="#fff" font-size="11">D</text></svg>
|
|
||||||
Dynamic Node (New)
|
|
||||||
</h3>
|
|
||||||
|
|
||||||
- Auto-discovers all JSON keys and exposes them as outputs
|
|
||||||
- No code changes needed when JSON structure evolves
|
|
||||||
- Preserves connections when keys are added on refresh
|
|
||||||
- Native type handling: `int`, `float`, `string`
|
|
||||||
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
### 1. Unraid / Docker (NiceGUI Manager)
|
|
||||||
|
|
||||||
|
1. **Prepare a Folder:** Create a folder on your server (e.g., `/mnt/user/appdata/ai-manager/`) and place the following files inside:
|
||||||
|
* `app.py`
|
||||||
|
* `utils.py`
|
||||||
|
* `history_tree.py` (New logic engine)
|
||||||
|
* `tab_single.py`
|
||||||
|
* `tab_batch.py`
|
||||||
|
* `tab_timeline.py`
|
||||||
|
* `tab_timeline_wip.py`
|
||||||
|
2. **Add Container in Unraid:**
|
||||||
|
* **Repository:** `python:3.12-slim`
|
||||||
|
* **Network:** `Bridge`
|
||||||
|
* **WebUI:** `http://[IP]:[PORT:8501]`
|
||||||
|
3. **Path Mappings:**
|
||||||
|
* **App Location:** Container `/app` ↔ Host `/mnt/user/appdata/ai-manager/`
|
||||||
|
* **Project Data:** Container `/mnt/user/` ↔ Host `/mnt/user/` (Your media/JSON location)
|
||||||
|
4. **Post Arguments (Crucial):**
|
||||||
|
Enable "Advanced View" and paste this command to install the required graph engines:
|
||||||
```bash
|
```bash
|
||||||
# Repository: python:3.12-slim
|
/bin/sh -c "apt-get update && apt-get install -y graphviz && pip install streamlit opencv-python-headless graphviz streamlit-agraph && cd /app && streamlit run app.py --server.headless true --server.port 8501"
|
||||||
# Network: Bridge
|
|
||||||
# WebUI: http://[IP]:[PORT:8080]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Path Mappings:**
|
### 2. ComfyUI Setup (The Nodes)
|
||||||
| Container | Host | Purpose |
|
1. Navigate to your ComfyUI installation: `ComfyUI/custom_nodes/`
|
||||||
|:---|:---|:---|
|
2. Create a folder named `ComfyUI-JSON-Loader`.
|
||||||
| `/app` | `/mnt/user/appdata/ai-manager/` | App files |
|
3. Place the `json_loader.py` file inside.
|
||||||
| `/mnt/user/` | `/mnt/user/` | Project data / JSON location |
|
4. Restart ComfyUI.
|
||||||
|
|
||||||
**Post Arguments:**
|
|
||||||
```bash
|
|
||||||
/bin/sh -c "apt-get update && apt-get install -y graphviz && \
|
|
||||||
pip install nicegui graphviz requests && \
|
|
||||||
cd /app && python main.py"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. ComfyUI (Custom Nodes)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ComfyUI/custom_nodes/
|
|
||||||
git clone <this-repo> ComfyUI-JSON-Manager
|
|
||||||
# Restart ComfyUI
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## ComfyUI Nodes
|
## 🖥️ Usage Guide
|
||||||
|
|
||||||
### Node Overview
|
### The Web Interface
|
||||||
|
Navigate to your container's IP (e.g., `http://192.168.1.100:8501`).
|
||||||
|
|
||||||
<!--
|
* **Custom Parameters:** Scroll to the bottom of the editor (Single or Batch) to find the "🔧 Custom Parameters" section. Type a Key (e.g., `strength`) and Value (e.g., `0.8`) and click "Add".
|
||||||
Diagram: shows JSON file flowing into different node types
|
* **Timeline:** Switch to the **Timeline Tab** to see your version history.
|
||||||
-->
|
* **Restore:** Select a node from the list or click on the graph (WIP tab) to view details. Click "Restore" to revert settings to that point.
|
||||||
<p align="center">
|
* **Branching:** If you restore an old node and click "Save/Snap", a new branch is created automatically.
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="720" height="280" viewBox="0 0 720 280">
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="nodeBg" x1="0%" y1="0%" x2="0%" y2="100%">
|
|
||||||
<stop offset="0%" style="stop-color:#2d2d3d" />
|
|
||||||
<stop offset="100%" style="stop-color:#1e1e2e" />
|
|
||||||
</linearGradient>
|
|
||||||
<filter id="shadow">
|
|
||||||
<feDropShadow dx="1" dy="2" stdDeviation="3" flood-opacity="0.3"/>
|
|
||||||
</filter>
|
|
||||||
</defs>
|
|
||||||
|
|
||||||
<!-- JSON File -->
|
### ComfyUI Workflow
|
||||||
<rect x="10" y="100" width="120" height="60" rx="8" fill="#0f3460" filter="url(#shadow)" />
|
Search for "JSON" in ComfyUI to find the new nodes.
|
||||||
<text x="70" y="125" text-anchor="middle" fill="#aaa" font-family="monospace" font-size="10">batch_prompt</text>
|
|
||||||
<text x="70" y="142" text-anchor="middle" fill="#fff" font-family="monospace" font-size="13" font-weight="bold">.json</text>
|
|
||||||
|
|
||||||
<!-- Arrow -->
|
<img width="1251" height="921" alt="image" src="https://github.com/user-attachments/assets/06d567f8-15ee-4011-9b86-d0b43ce1ba74" />
|
||||||
<line x1="130" y1="130" x2="170" y2="130" stroke="#555" stroke-width="2" marker-end="url(#arrowhead)"/>
|
|
||||||
<defs><marker id="arrowhead" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto"><polygon points="0 0, 8 3, 0 6" fill="#555"/></marker></defs>
|
|
||||||
|
|
||||||
<!-- Dynamic Node -->
|
#### Standard Nodes
|
||||||
<rect x="180" y="20" width="200" height="70" rx="10" fill="url(#nodeBg)" stroke="#2b9348" stroke-width="2" filter="url(#shadow)" />
|
| Node Name | Description |
|
||||||
<text x="280" y="44" text-anchor="middle" fill="#2b9348" font-family="sans-serif" font-size="12" font-weight="bold">JSON Loader (Dynamic)</text>
|
| :--- | :--- |
|
||||||
<text x="280" y="62" text-anchor="middle" fill="#888" font-family="monospace" font-size="10">auto-discovers keys</text>
|
| **JSON Loader (Standard/I2V)** | Outputs prompts, FLF, Seed, and paths for I2V. |
|
||||||
<text x="280" y="78" text-anchor="middle" fill="#666" font-family="monospace" font-size="9">click Refresh to populate</text>
|
| **JSON Loader (VACE Full)** | Outputs everything above plus VACE integers (frames to skip, schedule, etc.). |
|
||||||
|
| **JSON Loader (LoRAs Only)** | Outputs the 6 LoRA strings. |
|
||||||
|
|
||||||
<!-- Batch I2V Node -->
|
#### Universal Custom Nodes (New!)
|
||||||
<rect x="180" y="105" width="200" height="50" rx="10" fill="url(#nodeBg)" stroke="#e94560" stroke-width="2" filter="url(#shadow)" />
|
These nodes read *any* key you added in the "Custom Parameters" section. They work for both Single files (ignores sequence input) and Batch files (reads specific sequence).
|
||||||
<text x="280" y="127" text-anchor="middle" fill="#e94560" font-family="sans-serif" font-size="12" font-weight="bold">JSON Batch Loader (I2V)</text>
|
|
||||||
<text x="280" y="144" text-anchor="middle" fill="#888" font-family="monospace" font-size="10">prompts, flf, seed, paths</text>
|
|
||||||
|
|
||||||
<!-- Batch VACE Node -->
|
| Node Name | Description |
|
||||||
<rect x="180" y="170" width="200" height="50" rx="10" fill="url(#nodeBg)" stroke="#533483" stroke-width="2" filter="url(#shadow)" />
|
| :--- | :--- |
|
||||||
<text x="280" y="192" text-anchor="middle" fill="#533483" font-family="sans-serif" font-size="12" font-weight="bold">JSON Batch Loader (VACE)</text>
|
| **JSON Loader (Custom 1)** | Reads 1 custom key. Input the key name (e.g., "strength"), outputs the value string. |
|
||||||
<text x="280" y="209" text-anchor="middle" fill="#888" font-family="monospace" font-size="10">+ vace frames, schedule</text>
|
| **JSON Loader (Custom 3)** | Reads 3 custom keys. |
|
||||||
|
| **JSON Loader (Custom 6)** | Reads 6 custom keys. |
|
||||||
|
|
||||||
<!-- Custom Nodes -->
|
#### Batch Nodes
|
||||||
<rect x="180" y="235" width="200" height="40" rx="10" fill="url(#nodeBg)" stroke="#0f3460" stroke-width="2" filter="url(#shadow)" />
|
These nodes require an integer input (Primitive or Batch Indexer) for `sequence_number`.
|
||||||
<text x="280" y="260" text-anchor="middle" fill="#0f3460" font-family="sans-serif" font-size="12" font-weight="bold">JSON Loader (Custom 1/3/6)</text>
|
|
||||||
|
|
||||||
<!-- Output labels -->
|
| Node Name | Description |
|
||||||
<line x1="380" y1="55" x2="420" y2="55" stroke="#2b9348" stroke-width="1.5"/>
|
| :--- | :--- |
|
||||||
<text x="430" y="47" fill="#aaa" font-family="monospace" font-size="9">general_prompt</text>
|
| **JSON Batch Loader (I2V)** | Loads specific sequence data for I2V. |
|
||||||
<text x="430" y="59" fill="#aaa" font-family="monospace" font-size="9">seed (int)</text>
|
| **JSON Batch Loader (VACE)** | Loads specific sequence data for VACE. |
|
||||||
<text x="430" y="71" fill="#aaa" font-family="monospace" font-size="9">my_custom_key ...</text>
|
| **JSON Batch Loader (LoRAs)** | Loads specific LoRAs for that sequence. |
|
||||||
|
|
||||||
<line x1="380" y1="130" x2="420" y2="130" stroke="#e94560" stroke-width="1.5"/>
|
|
||||||
<text x="430" y="127" fill="#aaa" font-family="monospace" font-size="9">general_prompt, camera,</text>
|
|
||||||
<text x="430" y="139" fill="#aaa" font-family="monospace" font-size="9">flf, seed, paths ...</text>
|
|
||||||
|
|
||||||
<line x1="380" y1="195" x2="420" y2="195" stroke="#533483" stroke-width="1.5"/>
|
|
||||||
<text x="430" y="192" fill="#aaa" font-family="monospace" font-size="9">+ frame_to_skip, vace_schedule,</text>
|
|
||||||
<text x="430" y="204" fill="#aaa" font-family="monospace" font-size="9">input_a_frames ...</text>
|
|
||||||
|
|
||||||
<line x1="380" y1="255" x2="420" y2="255" stroke="#0f3460" stroke-width="1.5"/>
|
|
||||||
<text x="430" y="259" fill="#aaa" font-family="monospace" font-size="9">manual key lookup (1-6 slots)</text>
|
|
||||||
</svg>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
### Dynamic Node
|
|
||||||
|
|
||||||
The **JSON Loader (Dynamic)** node reads your JSON file and automatically creates output slots for every key it finds. No code changes needed when your JSON structure evolves.
|
|
||||||
|
|
||||||
**How it works:**
|
|
||||||
1. Enter a `json_path` and `sequence_number`
|
|
||||||
2. Click **Refresh Outputs**
|
|
||||||
3. Outputs appear named after JSON keys, with native types preserved
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="500" height="240" viewBox="0 0 500 240">
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="dynBg" x1="0%" y1="0%" x2="0%" y2="100%">
|
|
||||||
<stop offset="0%" style="stop-color:#353545" />
|
|
||||||
<stop offset="100%" style="stop-color:#252535" />
|
|
||||||
</linearGradient>
|
|
||||||
</defs>
|
|
||||||
|
|
||||||
<!-- Node body -->
|
|
||||||
<rect x="20" y="10" width="240" height="220" rx="10" fill="url(#dynBg)" stroke="#2b9348" stroke-width="2" />
|
|
||||||
<rect x="20" y="10" width="240" height="28" rx="10" fill="#2b9348" />
|
|
||||||
<rect x="20" y="28" width="240" height="10" fill="#2b9348" />
|
|
||||||
<text x="140" y="31" text-anchor="middle" fill="#fff" font-family="sans-serif" font-size="13" font-weight="bold">JSON Loader (Dynamic)</text>
|
|
||||||
|
|
||||||
<!-- Inputs -->
|
|
||||||
<text x="35" y="60" fill="#ccc" font-family="monospace" font-size="10">json_path: /data/prompt.json</text>
|
|
||||||
<text x="35" y="78" fill="#ccc" font-family="monospace" font-size="10">sequence_number: 1</text>
|
|
||||||
|
|
||||||
<!-- Refresh button -->
|
|
||||||
<rect x="45" y="88" width="190" height="24" rx="5" fill="#2b9348" opacity="0.3" stroke="#2b9348" stroke-width="1"/>
|
|
||||||
<text x="140" y="104" text-anchor="middle" fill="#2b9348" font-family="sans-serif" font-size="11" font-weight="bold">Refresh Outputs</text>
|
|
||||||
|
|
||||||
<!-- Output slots -->
|
|
||||||
<circle cx="260" cy="130" r="5" fill="#6bcb77"/>
|
|
||||||
<text x="245" y="134" text-anchor="end" fill="#ccc" font-family="monospace" font-size="10">general_prompt</text>
|
|
||||||
|
|
||||||
<circle cx="260" cy="150" r="5" fill="#6bcb77"/>
|
|
||||||
<text x="245" y="154" text-anchor="end" fill="#ccc" font-family="monospace" font-size="10">negative</text>
|
|
||||||
|
|
||||||
<circle cx="260" cy="170" r="5" fill="#4d96ff"/>
|
|
||||||
<text x="245" y="174" text-anchor="end" fill="#ccc" font-family="monospace" font-size="10">seed</text>
|
|
||||||
|
|
||||||
<circle cx="260" cy="190" r="5" fill="#ff6b6b"/>
|
|
||||||
<text x="245" y="194" text-anchor="end" fill="#ccc" font-family="monospace" font-size="10">flf</text>
|
|
||||||
|
|
||||||
<circle cx="260" cy="210" r="5" fill="#6bcb77"/>
|
|
||||||
<text x="245" y="214" text-anchor="end" fill="#ccc" font-family="monospace" font-size="10">camera</text>
|
|
||||||
|
|
||||||
<!-- Connection lines to downstream -->
|
|
||||||
<line x1="265" y1="130" x2="340" y2="130" stroke="#6bcb77" stroke-width="1.5"/>
|
|
||||||
<line x1="265" y1="170" x2="340" y2="165" stroke="#4d96ff" stroke-width="1.5"/>
|
|
||||||
|
|
||||||
<!-- Downstream node -->
|
|
||||||
<rect x="340" y="115" width="140" height="65" rx="8" fill="url(#dynBg)" stroke="#555" stroke-width="1.5" />
|
|
||||||
<text x="410" y="137" text-anchor="middle" fill="#aaa" font-family="sans-serif" font-size="11">KSampler</text>
|
|
||||||
<circle cx="340" cy="130" r="4" fill="#6bcb77"/>
|
|
||||||
<text x="350" y="150" fill="#777" font-family="monospace" font-size="9">positive</text>
|
|
||||||
<circle cx="340" cy="165" r="4" fill="#4d96ff"/>
|
|
||||||
<text x="350" y="170" fill="#777" font-family="monospace" font-size="9">seed</text>
|
|
||||||
|
|
||||||
<!-- Legend -->
|
|
||||||
<circle cx="30" y="248" r="4" fill="#6bcb77"/>
|
|
||||||
<text x="40" y="252" fill="#888" font-family="monospace" font-size="9">STRING</text>
|
|
||||||
<circle cx="100" y="248" r="4" fill="#4d96ff"/>
|
|
||||||
<text x="110" y="252" fill="#888" font-family="monospace" font-size="9">INT</text>
|
|
||||||
<circle cx="155" y="248" r="4" fill="#ff6b6b"/>
|
|
||||||
<text x="165" y="252" fill="#888" font-family="monospace" font-size="9">FLOAT</text>
|
|
||||||
</svg>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
**Type handling:** Values keep their native Python type — `int` stays `int`, `float` stays `float`, booleans become `"true"`/`"false"` strings, everything else becomes `string`. The `*` (any) output type allows connecting to any input.
|
|
||||||
|
|
||||||
**Refreshing is safe:** Clicking Refresh after adding new keys to your JSON preserves all existing connections. Only removed keys get disconnected.
|
|
||||||
|
|
||||||
### Standard & Batch Nodes
|
|
||||||
|
|
||||||
| Node | Outputs | Use Case |
|
|
||||||
|:---|:---|:---|
|
|
||||||
| **JSON Loader (Standard/I2V)** | prompts, flf, seed, paths | Single-file I2V workflows |
|
|
||||||
| **JSON Loader (VACE Full)** | above + VACE integers | Single-file VACE workflows |
|
|
||||||
| **JSON Loader (LoRAs Only)** | 6 LoRA strings | Single-file LoRA loading |
|
|
||||||
| **JSON Batch Loader (I2V)** | prompts, flf, seed, paths | Batch I2V with sequence_number |
|
|
||||||
| **JSON Batch Loader (VACE)** | above + VACE integers | Batch VACE with sequence_number |
|
|
||||||
| **JSON Batch Loader (LoRAs)** | 6 LoRA strings | Batch LoRA loading |
|
|
||||||
| **JSON Loader (Custom 1/3/6)** | 1, 3, or 6 string values | Manual key lookup by name |
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Web Interface Usage
|
## 📂 File Structure
|
||||||
|
|
||||||
Navigate to your container's IP (e.g., `http://192.168.1.100:8080`).
|
```text
|
||||||
|
/ai-manager
|
||||||
**Path navigation** supports case-insensitive matching — typing `/media/P5/myFolder` will resolve to `/media/p5/MyFolder` automatically.
|
├── app.py # Main entry point & Tab controller
|
||||||
|
├── utils.py # I/O logic, Config, and Defaults
|
||||||
- **Custom Parameters:** Scroll to "Custom Parameters" in any editor tab. Type a key and value, click Add.
|
├── history_tree.py # Graph logic, Branching engine, Graphviz generator
|
||||||
- **Timeline:** Switch to the Timeline tab to see version history as a graph. Restore any version, and new edits fork a branch automatically.
|
├── tab_single.py # Single Editor UI
|
||||||
- **Snippets:** Save reusable prompt fragments and append them with one click.
|
├── tab_batch.py # Batch Processor UI
|
||||||
|
├── tab_timeline.py # Stable Timeline UI (Compact Graphviz + Diff Inspector)
|
||||||
---
|
├── tab_timeline_wip.py # Interactive Timeline UI (Streamlit Agraph)
|
||||||
|
└── json_loader.py # ComfyUI Custom Node script
|
||||||
## JSON Format
|
|
||||||
|
|
||||||
```jsonc
|
|
||||||
{
|
|
||||||
"batch_data": [
|
|
||||||
{
|
|
||||||
"sequence_number": 1,
|
|
||||||
"general_prompt": "A cinematic scene...",
|
|
||||||
"negative": "blurry, low quality",
|
|
||||||
"seed": 42,
|
|
||||||
"flf": 0.5,
|
|
||||||
"camera": "pan_left",
|
|
||||||
"video file path": "/data/input.mp4",
|
|
||||||
"reference image path": "/data/ref.png",
|
|
||||||
"my_custom_key": "any value"
|
|
||||||
// ... any additional keys are auto-discovered by the Dynamic node
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## File Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
ComfyUI-JSON-Manager/
|
|
||||||
├── __init__.py # ComfyUI entry point, exports nodes + WEB_DIRECTORY
|
|
||||||
├── json_loader.py # All ComfyUI node classes + /json_manager/get_keys API
|
|
||||||
├── web/
|
|
||||||
│ └── json_dynamic.js # Frontend extension for Dynamic node (refresh, show/hide)
|
|
||||||
├── main.py # NiceGUI web UI entry point & navigator
|
|
||||||
├── state.py # Application state management
|
|
||||||
├── utils.py # I/O, config, defaults, case-insensitive path resolver
|
|
||||||
├── history_tree.py # Git-style branching engine
|
|
||||||
├── tab_batch_ng.py # Batch processor UI (NiceGUI)
|
|
||||||
├── tab_timeline_ng.py # Visual timeline UI (NiceGUI)
|
|
||||||
├── tab_comfy_ng.py # ComfyUI server monitor (NiceGUI)
|
|
||||||
├── tab_raw_ng.py # Raw JSON editor (NiceGUI)
|
|
||||||
└── tests/
|
|
||||||
├── test_json_loader.py
|
|
||||||
├── test_utils.py
|
|
||||||
└── test_history_tree.py
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
[Apache 2.0](LICENSE)
|
|
||||||
|
|||||||
+2
-7
@@ -1,8 +1,3 @@
|
|||||||
from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = PROJECT_NODE_CLASS_MAPPINGS
|
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web"
|
|
||||||
|
|
||||||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY']
|
|
||||||
|
|||||||
-176
@@ -1,176 +0,0 @@
|
|||||||
"""REST API endpoints for ComfyUI to query project data from JSON files.
|
|
||||||
|
|
||||||
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Query
|
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
from nicegui import app
|
|
||||||
|
|
||||||
from db import ProjectDB
|
|
||||||
from utils import load_json, load_config, resolve_path_case_insensitive, KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# The DB instance is set by register_api_routes()
|
|
||||||
_db: ProjectDB | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(db: ProjectDB) -> None:
|
|
||||||
"""Register all REST API routes with the NiceGUI/FastAPI app."""
|
|
||||||
global _db
|
|
||||||
_db = db
|
|
||||||
|
|
||||||
app.add_api_route("/api/projects", _list_projects, methods=["GET"])
|
|
||||||
app.add_api_route("/api/active-project", _get_active_project, methods=["GET"])
|
|
||||||
app.add_api_route("/api/projects/{name}", _get_project, methods=["GET"])
|
|
||||||
app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"])
|
|
||||||
app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"])
|
|
||||||
app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"])
|
|
||||||
app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"])
|
|
||||||
app.add_api_route("/api/image-preview", _serve_image, methods=["GET"])
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db() -> ProjectDB:
|
|
||||||
if _db is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
|
||||||
return _db
|
|
||||||
|
|
||||||
|
|
||||||
def _list_projects() -> dict[str, Any]:
|
|
||||||
db = _get_db()
|
|
||||||
projects = db.list_projects()
|
|
||||||
return {"projects": [p["name"] for p in projects]}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_active_project() -> dict[str, Any]:
|
|
||||||
config = load_config()
|
|
||||||
return {"project": config.get("current_project", "")}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_project(name: str) -> dict[str, Any]:
|
|
||||||
db = _get_db()
|
|
||||||
proj = db.get_project(name)
|
|
||||||
if not proj:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
|
||||||
folder_path = proj["folder_path"]
|
|
||||||
resolved = resolve_path_case_insensitive(folder_path)
|
|
||||||
if resolved:
|
|
||||||
folder_path = str(resolved)
|
|
||||||
# Apply configured path replacements (e.g. Docker mount casing differences)
|
|
||||||
config = load_config()
|
|
||||||
for rep in config.get("path_replacements", []):
|
|
||||||
src, dst = rep.get("from", ""), rep.get("to", "")
|
|
||||||
if src:
|
|
||||||
folder_path = folder_path.replace(src, dst)
|
|
||||||
return {"name": proj["name"], "folder_path": folder_path,
|
|
||||||
"description": proj.get("description", "")}
|
|
||||||
|
|
||||||
|
|
||||||
def _list_files(name: str) -> dict[str, Any]:
|
|
||||||
db = _get_db()
|
|
||||||
files = db.list_project_files(name)
|
|
||||||
return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]}
|
|
||||||
|
|
||||||
|
|
||||||
def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
|
||||||
db = _get_db()
|
|
||||||
seqs = db.list_project_sequences(name, file_name)
|
|
||||||
return {"sequences": seqs}
|
|
||||||
|
|
||||||
|
|
||||||
def _load_sequences(name: str, file_name: str) -> list[dict]:
|
|
||||||
"""Load the batch_data list directly from the JSON file."""
|
|
||||||
db = _get_db()
|
|
||||||
proj = db.get_project(name)
|
|
||||||
if not proj:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
|
||||||
json_path = Path(proj["folder_path"]) / f"{file_name}.json"
|
|
||||||
if not json_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
|
||||||
data, _ = load_json(json_path)
|
|
||||||
return data.get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
|
|
||||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
sequences = _load_sequences(name, file_name)
|
|
||||||
match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None)
|
|
||||||
if match is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
|
||||||
result = dict(match)
|
|
||||||
# Inject strength defaults if not yet saved to JSON
|
|
||||||
for key, default in (
|
|
||||||
("start frame high strength", 1.0),
|
|
||||||
("start frame low strength", 1.0),
|
|
||||||
("middle frame high strength", 1.0),
|
|
||||||
("middle frame low strength", 1.0),
|
|
||||||
("end frame high strength", 1.0),
|
|
||||||
("end frame low strength", 1.0),
|
|
||||||
):
|
|
||||||
result.setdefault(key, default)
|
|
||||||
# Computed stem names from frame paths
|
|
||||||
for out_key, src_key in (
|
|
||||||
("start_name", "start frame path"),
|
|
||||||
("middle_name", "middle frame path"),
|
|
||||||
("end_name", "end frame path"),
|
|
||||||
):
|
|
||||||
path_val = result.get(src_key, "")
|
|
||||||
result[out_key] = Path(path_val).stem if path_val else ""
|
|
||||||
logger.info("API _get_data %s/%s seq=%d (%d keys): %.3fs",
|
|
||||||
name, file_name, seq, len(result), time.perf_counter() - t0)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
sequences = _load_sequences(name, file_name)
|
|
||||||
match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None)
|
|
||||||
if match is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
|
||||||
keys = [k for k in match.keys() if k != KEY_SEQUENCE_NUMBER]
|
|
||||||
types = []
|
|
||||||
for k in keys:
|
|
||||||
v = match[k]
|
|
||||||
if isinstance(v, bool):
|
|
||||||
types.append("BOOLEAN")
|
|
||||||
elif isinstance(v, int):
|
|
||||||
types.append("INT")
|
|
||||||
elif isinstance(v, float):
|
|
||||||
types.append("FLOAT")
|
|
||||||
else:
|
|
||||||
types.append("STRING")
|
|
||||||
# Injected defaults — always present even if not yet saved to JSON
|
|
||||||
for key in (
|
|
||||||
"start frame high strength", "start frame low strength",
|
|
||||||
"middle frame high strength", "middle frame low strength",
|
|
||||||
"end frame high strength", "end frame low strength",
|
|
||||||
):
|
|
||||||
if key not in match:
|
|
||||||
keys.append(key)
|
|
||||||
types.append("FLOAT")
|
|
||||||
# Computed keys derived from frame paths
|
|
||||||
for out_key, src_key in (
|
|
||||||
("start_name", "start frame path"),
|
|
||||||
("middle_name", "middle frame path"),
|
|
||||||
("end_name", "end frame path"),
|
|
||||||
):
|
|
||||||
if src_key in match:
|
|
||||||
keys.append(out_key)
|
|
||||||
types.append("STRING")
|
|
||||||
total = len(sequences)
|
|
||||||
logger.info("API _get_keys %s/%s seq=%d (%d keys): %.3fs",
|
|
||||||
name, file_name, seq, len(keys), time.perf_counter() - t0)
|
|
||||||
return {"keys": keys, "types": types, "total_sequences": total}
|
|
||||||
|
|
||||||
|
|
||||||
def _serve_image(path: str = Query(...)) -> FileResponse:
|
|
||||||
p = Path(path)
|
|
||||||
if not p.exists() or not p.is_file():
|
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
|
||||||
return FileResponse(str(p))
|
|
||||||
@@ -0,0 +1,214 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# --- Import Custom Modules ---
|
||||||
|
from utils import (
|
||||||
|
load_config, save_config, load_snippets, save_snippets,
|
||||||
|
load_json, save_json, generate_templates, DEFAULTS
|
||||||
|
)
|
||||||
|
from tab_single import render_single_editor
|
||||||
|
from tab_batch import render_batch_processor
|
||||||
|
from tab_timeline import render_timeline_tab
|
||||||
|
from tab_timeline_wip import render_timeline_wip
|
||||||
|
from tab_comfy import render_comfy_monitor
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 1. PAGE CONFIGURATION
|
||||||
|
# ==========================================
|
||||||
|
st.set_page_config(layout="wide", page_title="AI Settings Manager")
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 2. SESSION STATE INITIALIZATION
|
||||||
|
# ==========================================
|
||||||
|
if 'config' not in st.session_state:
|
||||||
|
st.session_state.config = load_config()
|
||||||
|
st.session_state.current_dir = Path(st.session_state.config.get("last_dir", Path.cwd()))
|
||||||
|
|
||||||
|
if 'snippets' not in st.session_state:
|
||||||
|
st.session_state.snippets = load_snippets()
|
||||||
|
|
||||||
|
if 'loaded_file' not in st.session_state:
|
||||||
|
st.session_state.loaded_file = None
|
||||||
|
|
||||||
|
if 'last_mtime' not in st.session_state:
|
||||||
|
st.session_state.last_mtime = 0
|
||||||
|
|
||||||
|
if 'edit_history_idx' not in st.session_state:
|
||||||
|
st.session_state.edit_history_idx = None
|
||||||
|
|
||||||
|
if 'single_editor_cache' not in st.session_state:
|
||||||
|
st.session_state.single_editor_cache = DEFAULTS.copy()
|
||||||
|
|
||||||
|
if 'ui_reset_token' not in st.session_state:
|
||||||
|
st.session_state.ui_reset_token = 0
|
||||||
|
|
||||||
|
# Track the active tab state for programmatic switching
|
||||||
|
if 'active_tab_name' not in st.session_state:
|
||||||
|
st.session_state.active_tab_name = "📝 Single Editor"
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 3. SIDEBAR (NAVIGATOR & TOOLS)
|
||||||
|
# ==========================================
|
||||||
|
with st.sidebar:
|
||||||
|
st.header("📂 Navigator")
|
||||||
|
|
||||||
|
# --- Path Navigator ---
|
||||||
|
new_path = st.text_input("Current Path", value=str(st.session_state.current_dir))
|
||||||
|
if new_path != str(st.session_state.current_dir):
|
||||||
|
p = Path(new_path)
|
||||||
|
if p.exists() and p.is_dir():
|
||||||
|
st.session_state.current_dir = p
|
||||||
|
st.session_state.config['last_dir'] = str(p)
|
||||||
|
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- Favorites System ---
|
||||||
|
if st.button("📌 Pin Current Folder"):
|
||||||
|
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
|
||||||
|
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
|
||||||
|
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
fav_selection = st.radio(
|
||||||
|
"Jump to:",
|
||||||
|
["Select..."] + st.session_state.config['favorites'],
|
||||||
|
index=0,
|
||||||
|
label_visibility="collapsed"
|
||||||
|
)
|
||||||
|
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
|
||||||
|
st.session_state.current_dir = Path(fav_selection)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- Snippet Library ---
|
||||||
|
st.subheader("🧩 Snippet Library")
|
||||||
|
with st.expander("Add New Snippet"):
|
||||||
|
snip_name = st.text_input("Name", placeholder="e.g. Cinematic")
|
||||||
|
snip_content = st.text_area("Content", placeholder="4k, high quality...")
|
||||||
|
if st.button("Save Snippet"):
|
||||||
|
if snip_name and snip_content:
|
||||||
|
st.session_state.snippets[snip_name] = snip_content
|
||||||
|
save_snippets(st.session_state.snippets)
|
||||||
|
st.success(f"Saved '{snip_name}'")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if st.session_state.snippets:
|
||||||
|
st.caption("Click to Append to Prompt:")
|
||||||
|
for name, content in st.session_state.snippets.items():
|
||||||
|
col_s1, col_s2 = st.columns([4, 1])
|
||||||
|
if col_s1.button(f"➕ {name}", use_container_width=True):
|
||||||
|
st.session_state.append_prompt = content
|
||||||
|
st.rerun()
|
||||||
|
if col_s2.button("🗑️", key=f"del_snip_{name}"):
|
||||||
|
del st.session_state.snippets[name]
|
||||||
|
save_snippets(st.session_state.snippets)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- File List & Creation ---
|
||||||
|
json_files = sorted(list(st.session_state.current_dir.glob("*.json")))
|
||||||
|
json_files = [f for f in json_files if f.name != ".editor_config.json" and f.name != ".editor_snippets.json"]
|
||||||
|
|
||||||
|
if not json_files:
|
||||||
|
if st.button("Generate Templates"):
|
||||||
|
generate_templates(st.session_state.current_dir)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
with st.expander("Create New JSON"):
|
||||||
|
new_filename = st.text_input("Filename", placeholder="my_prompt_vace")
|
||||||
|
is_batch = st.checkbox("Is Batch File?")
|
||||||
|
if st.button("Create"):
|
||||||
|
if not new_filename.endswith(".json"): new_filename += ".json"
|
||||||
|
path = st.session_state.current_dir / new_filename
|
||||||
|
if is_batch:
|
||||||
|
data = {"batch_data": []}
|
||||||
|
else:
|
||||||
|
data = DEFAULTS.copy()
|
||||||
|
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
|
||||||
|
elif "i2v" in new_filename: data.update({"reference image path": "", "flf image path": ""})
|
||||||
|
save_json(path, data)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- File Selector ---
|
||||||
|
if 'file_selector' not in st.session_state:
|
||||||
|
st.session_state.file_selector = json_files[0].name if json_files else None
|
||||||
|
if st.session_state.file_selector not in [f.name for f in json_files] and json_files:
|
||||||
|
st.session_state.file_selector = json_files[0].name
|
||||||
|
|
||||||
|
selected_file_name = st.radio("Select File", [f.name for f in json_files], key="file_selector")
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 4. MAIN APP LOGIC
|
||||||
|
# ==========================================
|
||||||
|
if selected_file_name:
|
||||||
|
file_path = st.session_state.current_dir / selected_file_name
|
||||||
|
|
||||||
|
# --- FILE LOADING & AUTO-SWITCH LOGIC ---
|
||||||
|
if st.session_state.loaded_file != str(file_path):
|
||||||
|
data, mtime = load_json(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.session_state.last_mtime = mtime
|
||||||
|
st.session_state.loaded_file = str(file_path)
|
||||||
|
|
||||||
|
# Clear transient states
|
||||||
|
if 'append_prompt' in st.session_state: del st.session_state.append_prompt
|
||||||
|
if 'rand_seed' in st.session_state: del st.session_state.rand_seed
|
||||||
|
if 'restored_indicator' in st.session_state: del st.session_state.restored_indicator
|
||||||
|
st.session_state.edit_history_idx = None
|
||||||
|
|
||||||
|
# --- AUTO-SWITCH TAB LOGIC ---
|
||||||
|
# If the file has 'batch_data' or is a list, force Batch tab.
|
||||||
|
# Otherwise, force Single tab.
|
||||||
|
is_batch = "batch_data" in data or isinstance(data, list)
|
||||||
|
if is_batch:
|
||||||
|
st.session_state.active_tab_name = "🚀 Batch Processor"
|
||||||
|
else:
|
||||||
|
st.session_state.active_tab_name = "📝 Single Editor"
|
||||||
|
|
||||||
|
else:
|
||||||
|
data = st.session_state.data_cache
|
||||||
|
|
||||||
|
st.title(f"Editing: {selected_file_name}")
|
||||||
|
|
||||||
|
# --- CONTROLLED NAVIGATION (REPLACES ST.TABS) ---
|
||||||
|
# Using radio buttons allows us to change 'active_tab_name' programmatically above.
|
||||||
|
tabs_list = [
|
||||||
|
"📝 Single Editor",
|
||||||
|
"🚀 Batch Processor",
|
||||||
|
"🕒 Timeline",
|
||||||
|
"🧪 Interactive Timeline",
|
||||||
|
"🔌 Comfy Monitor"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ensure active tab is valid (safety check)
|
||||||
|
if st.session_state.active_tab_name not in tabs_list:
|
||||||
|
st.session_state.active_tab_name = tabs_list[0]
|
||||||
|
|
||||||
|
current_tab = st.radio(
|
||||||
|
"Navigation",
|
||||||
|
tabs_list,
|
||||||
|
key="active_tab_name", # Binds to session state
|
||||||
|
horizontal=True,
|
||||||
|
label_visibility="collapsed"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- RENDER SELECTED TAB ---
|
||||||
|
if current_tab == "📝 Single Editor":
|
||||||
|
render_single_editor(data, file_path)
|
||||||
|
|
||||||
|
elif current_tab == "🚀 Batch Processor":
|
||||||
|
render_batch_processor(data, file_path, json_files, st.session_state.current_dir, selected_file_name)
|
||||||
|
|
||||||
|
elif current_tab == "🕒 Timeline":
|
||||||
|
render_timeline_tab(data, file_path)
|
||||||
|
|
||||||
|
elif current_tab == "🧪 Interactive Timeline":
|
||||||
|
render_timeline_wip(data, file_path)
|
||||||
|
|
||||||
|
elif current_tab == "🔌 Comfy Monitor":
|
||||||
|
render_comfy_monitor()
|
||||||
@@ -1,596 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import sqlite3
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_DB_PATH = Path(__file__).parent / "projects.db"
|
|
||||||
|
|
||||||
SCHEMA_SQL = """
|
|
||||||
CREATE TABLE IF NOT EXISTS projects (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
name TEXT NOT NULL UNIQUE,
|
|
||||||
folder_path TEXT NOT NULL,
|
|
||||||
description TEXT NOT NULL DEFAULT '',
|
|
||||||
created_at REAL NOT NULL,
|
|
||||||
updated_at REAL NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS data_files (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
data_type TEXT NOT NULL DEFAULT 'generic',
|
|
||||||
top_level TEXT NOT NULL DEFAULT '{}',
|
|
||||||
created_at REAL NOT NULL,
|
|
||||||
updated_at REAL NOT NULL,
|
|
||||||
UNIQUE(project_id, name)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS sequences (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
|
|
||||||
sequence_number INTEGER NOT NULL,
|
|
||||||
data TEXT NOT NULL DEFAULT '{}',
|
|
||||||
updated_at REAL NOT NULL,
|
|
||||||
UNIQUE(data_file_id, sequence_number)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS history_trees (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
|
|
||||||
tree_data TEXT NOT NULL DEFAULT '{}',
|
|
||||||
updated_at REAL NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS history_snapshots (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
|
|
||||||
node_id TEXT NOT NULL,
|
|
||||||
snapshot_data TEXT NOT NULL,
|
|
||||||
updated_at REAL NOT NULL,
|
|
||||||
UNIQUE(data_file_id, node_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_data_files_project_id ON data_files(project_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_sequences_data_file_id ON sequences(data_file_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_history_snapshots_df ON history_snapshots(data_file_id);
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectDB:
|
|
||||||
"""SQLite database for project-based data management."""
|
|
||||||
|
|
||||||
def __init__(self, db_path: str | Path | None = None):
|
|
||||||
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
|
||||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.conn = sqlite3.connect(
|
|
||||||
str(self.db_path),
|
|
||||||
check_same_thread=False,
|
|
||||||
isolation_level=None, # autocommit — explicit BEGIN/COMMIT only
|
|
||||||
)
|
|
||||||
self.conn.row_factory = sqlite3.Row
|
|
||||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
|
||||||
self.conn.executescript(SCHEMA_SQL)
|
|
||||||
self._migrate_all_lora_data()
|
|
||||||
|
|
||||||
def _migrate_all_lora_data(self) -> None:
|
|
||||||
"""Bulk migration: split combined lora 'name:strength' into separate keys."""
|
|
||||||
rows = self.conn.execute("SELECT id, data FROM sequences").fetchall()
|
|
||||||
updated = 0
|
|
||||||
self.conn.execute("BEGIN")
|
|
||||||
try:
|
|
||||||
for row in rows:
|
|
||||||
data = json.loads(row["data"])
|
|
||||||
original = row["data"]
|
|
||||||
migrated = self._migrate_lora_keys(data)
|
|
||||||
new_json = json.dumps(migrated)
|
|
||||||
if new_json != original:
|
|
||||||
self.conn.execute(
|
|
||||||
"UPDATE sequences SET data = ? WHERE id = ?",
|
|
||||||
(new_json, row["id"]),
|
|
||||||
)
|
|
||||||
updated += 1
|
|
||||||
self.conn.execute("COMMIT")
|
|
||||||
except Exception:
|
|
||||||
self.conn.execute("ROLLBACK")
|
|
||||||
raise
|
|
||||||
if updated:
|
|
||||||
logger.info("Migrated lora keys in %d/%d sequences", updated, len(rows))
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Projects CRUD
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def create_project(self, name: str, folder_path: str, description: str = "") -> int:
|
|
||||||
now = time.time()
|
|
||||||
cur = self.conn.execute(
|
|
||||||
"INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?, ?)",
|
|
||||||
(name, folder_path, description, now, now),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
return cur.lastrowid
|
|
||||||
|
|
||||||
def list_projects(self) -> list[dict]:
|
|
||||||
rows = self.conn.execute(
|
|
||||||
"SELECT id, name, folder_path, description, created_at, updated_at "
|
|
||||||
"FROM projects ORDER BY name"
|
|
||||||
).fetchall()
|
|
||||||
return [dict(r) for r in rows]
|
|
||||||
|
|
||||||
def list_projects_with_file_counts(self) -> list[dict]:
|
|
||||||
"""List projects with data file counts in a single query."""
|
|
||||||
rows = self.conn.execute(
|
|
||||||
"SELECT p.id, p.name, p.folder_path, p.description, p.created_at, p.updated_at, "
|
|
||||||
"COUNT(df.id) AS file_count "
|
|
||||||
"FROM projects p LEFT JOIN data_files df ON df.project_id = p.id "
|
|
||||||
"GROUP BY p.id ORDER BY p.name"
|
|
||||||
).fetchall()
|
|
||||||
return [dict(r) for r in rows]
|
|
||||||
|
|
||||||
def get_project(self, name: str) -> dict | None:
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT id, name, folder_path, description, created_at, updated_at "
|
|
||||||
"FROM projects WHERE name = ?",
|
|
||||||
(name,),
|
|
||||||
).fetchone()
|
|
||||||
return dict(row) if row else None
|
|
||||||
|
|
||||||
def rename_project(self, old_name: str, new_name: str) -> bool:
|
|
||||||
now = time.time()
|
|
||||||
cur = self.conn.execute(
|
|
||||||
"UPDATE projects SET name = ?, updated_at = ? WHERE name = ?",
|
|
||||||
(new_name, now, old_name),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
return cur.rowcount > 0
|
|
||||||
|
|
||||||
def update_project_path(self, name: str, folder_path: str) -> bool:
|
|
||||||
now = time.time()
|
|
||||||
cur = self.conn.execute(
|
|
||||||
"UPDATE projects SET folder_path = ?, updated_at = ? WHERE name = ?",
|
|
||||||
(folder_path, now, name),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
return cur.rowcount > 0
|
|
||||||
|
|
||||||
def delete_project(self, name: str) -> bool:
|
|
||||||
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
|
|
||||||
self.conn.commit()
|
|
||||||
return cur.rowcount > 0
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Data files
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def create_data_file(
|
|
||||||
self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
|
|
||||||
) -> int:
|
|
||||||
now = time.time()
|
|
||||||
tl = json.dumps(top_level or {})
|
|
||||||
cur = self.conn.execute(
|
|
||||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(project_id, name, data_type, tl, now, now),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
return cur.lastrowid
|
|
||||||
|
|
||||||
def list_data_files(self, project_id: int) -> list[dict]:
|
|
||||||
rows = self.conn.execute(
|
|
||||||
"SELECT id, project_id, name, data_type, created_at, updated_at "
|
|
||||||
"FROM data_files WHERE project_id = ? ORDER BY name",
|
|
||||||
(project_id,),
|
|
||||||
).fetchall()
|
|
||||||
return [dict(r) for r in rows]
|
|
||||||
|
|
||||||
def count_data_files(self, project_id: int) -> int:
|
|
||||||
"""Return the number of data files for a project."""
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT COUNT(*) AS cnt FROM data_files WHERE project_id = ?",
|
|
||||||
(project_id,),
|
|
||||||
).fetchone()
|
|
||||||
return row["cnt"]
|
|
||||||
|
|
||||||
def get_data_file(self, project_id: int, name: str) -> dict | None:
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
|
|
||||||
"FROM data_files WHERE project_id = ? AND name = ?",
|
|
||||||
(project_id, name),
|
|
||||||
).fetchone()
|
|
||||||
if row is None:
|
|
||||||
return None
|
|
||||||
d = dict(row)
|
|
||||||
d["top_level"] = json.loads(d["top_level"])
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
|
|
||||||
"df.created_at, df.updated_at "
|
|
||||||
"FROM data_files df JOIN projects p ON df.project_id = p.id "
|
|
||||||
"WHERE p.name = ? AND df.name = ?",
|
|
||||||
(project_name, file_name),
|
|
||||||
).fetchone()
|
|
||||||
if row is None:
|
|
||||||
return None
|
|
||||||
d = dict(row)
|
|
||||||
d["top_level"] = json.loads(d["top_level"])
|
|
||||||
return d
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Sequences
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
|
|
||||||
now = time.time()
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
|
||||||
(data_file_id, sequence_number, json.dumps(data), now),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _migrate_lora_keys(data: dict) -> dict:
|
|
||||||
"""Split combined lora 'name:strength' into separate name and strength keys."""
|
|
||||||
for idx in range(1, 4):
|
|
||||||
for tier in ('high', 'low'):
|
|
||||||
name_key = f'lora {idx} {tier}'
|
|
||||||
str_key = f'lora {idx} {tier} strength'
|
|
||||||
raw = str(data.get(name_key, ''))
|
|
||||||
if raw.startswith('<lora:'):
|
|
||||||
inner = raw.replace('<lora:', '').replace('>', '')
|
|
||||||
if ':' in inner:
|
|
||||||
parts = inner.rsplit(':', 1)
|
|
||||||
data[name_key] = parts[0]
|
|
||||||
try:
|
|
||||||
data[str_key] = float(parts[1])
|
|
||||||
except ValueError:
|
|
||||||
data[str_key] = 1.0
|
|
||||||
else:
|
|
||||||
data[name_key] = inner
|
|
||||||
if str_key not in data:
|
|
||||||
data[str_key] = 1.0
|
|
||||||
elif ':' in raw and raw:
|
|
||||||
parts = raw.rsplit(':', 1)
|
|
||||||
try:
|
|
||||||
strength = float(parts[1])
|
|
||||||
data[name_key] = parts[0]
|
|
||||||
data[str_key] = strength
|
|
||||||
except ValueError:
|
|
||||||
if str_key not in data:
|
|
||||||
data[str_key] = 1.0
|
|
||||||
elif raw:
|
|
||||||
# Name exists without colon, ensure strength key exists
|
|
||||||
if str_key not in data:
|
|
||||||
data[str_key] = 1.0
|
|
||||||
# If name is empty, don't add a strength key
|
|
||||||
return data
|
|
||||||
|
|
||||||
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
|
|
||||||
(data_file_id, sequence_number),
|
|
||||||
).fetchone()
|
|
||||||
if not row:
|
|
||||||
return None
|
|
||||||
data = json.loads(row["data"])
|
|
||||||
return self._migrate_lora_keys(data)
|
|
||||||
|
|
||||||
def list_sequences(self, data_file_id: int) -> list[int]:
|
|
||||||
rows = self.conn.execute(
|
|
||||||
"SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
|
|
||||||
(data_file_id,),
|
|
||||||
).fetchall()
|
|
||||||
return [r["sequence_number"] for r in rows]
|
|
||||||
|
|
||||||
def count_sequences(self, data_file_id: int) -> int:
|
|
||||||
"""Return the number of sequences for a data file."""
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
|
|
||||||
(data_file_id,),
|
|
||||||
).fetchone()
|
|
||||||
return row["cnt"]
|
|
||||||
|
|
||||||
def query_total_sequences(self, project_name: str, file_name: str) -> int:
|
|
||||||
"""Return total sequence count by project and file names."""
|
|
||||||
df = self.get_data_file_by_names(project_name, file_name)
|
|
||||||
if not df:
|
|
||||||
return 0
|
|
||||||
return self.count_sequences(df["id"])
|
|
||||||
|
|
||||||
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
|
||||||
"""Returns (keys, types) for a sequence's data dict."""
|
|
||||||
data = self.get_sequence(data_file_id, sequence_number)
|
|
||||||
if not data:
|
|
||||||
return [], []
|
|
||||||
keys = []
|
|
||||||
types = []
|
|
||||||
for k, v in data.items():
|
|
||||||
keys.append(k)
|
|
||||||
if isinstance(v, bool):
|
|
||||||
types.append("STRING")
|
|
||||||
elif isinstance(v, int):
|
|
||||||
types.append("INT")
|
|
||||||
elif isinstance(v, float):
|
|
||||||
types.append("FLOAT")
|
|
||||||
else:
|
|
||||||
types.append("STRING")
|
|
||||||
return keys, types
|
|
||||||
|
|
||||||
def delete_sequences_for_file(self, data_file_id: int) -> None:
|
|
||||||
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# History trees
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
|
||||||
"""Save history tree, extracting snapshot data into separate table.
|
|
||||||
|
|
||||||
Supports both new format (snapshots dict) and old format (nodes dict).
|
|
||||||
"""
|
|
||||||
now = time.time()
|
|
||||||
if "snapshots" in tree_data:
|
|
||||||
entries = tree_data.get("snapshots", {})
|
|
||||||
entry_key = "snapshots"
|
|
||||||
else:
|
|
||||||
entries = tree_data.get("nodes", {})
|
|
||||||
entry_key = "nodes"
|
|
||||||
slim_tree = dict(tree_data)
|
|
||||||
slim_entries = {}
|
|
||||||
for eid, entry in entries.items():
|
|
||||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
|
||||||
slim_tree[entry_key] = slim_entries
|
|
||||||
|
|
||||||
self.conn.execute("BEGIN IMMEDIATE")
|
|
||||||
try:
|
|
||||||
for eid, entry in entries.items():
|
|
||||||
snap = entry.get("data")
|
|
||||||
if snap:
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
|
||||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
|
||||||
(data_file_id, eid, json.dumps(snap), now),
|
|
||||||
)
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
|
||||||
(data_file_id, json.dumps(slim_tree), now),
|
|
||||||
)
|
|
||||||
self.conn.execute("COMMIT")
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
self.conn.execute("ROLLBACK")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_history_tree(self, data_file_id: int) -> dict | None:
|
|
||||||
"""Load history tree metadata (without snapshot data)."""
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
|
|
||||||
(data_file_id,),
|
|
||||||
).fetchone()
|
|
||||||
return json.loads(row["tree_data"]) if row else None
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# History snapshots (per-node data, loaded on demand)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def get_node_snapshot(self, data_file_id: int, node_id: str) -> dict | None:
|
|
||||||
"""Load a single node's snapshot data on demand."""
|
|
||||||
row = self.conn.execute(
|
|
||||||
"SELECT snapshot_data FROM history_snapshots WHERE data_file_id = ? AND node_id = ?",
|
|
||||||
(data_file_id, node_id),
|
|
||||||
).fetchone()
|
|
||||||
return json.loads(row["snapshot_data"]) if row else None
|
|
||||||
|
|
||||||
def delete_node_snapshots(self, data_file_id: int, node_ids: set) -> None:
|
|
||||||
"""Delete snapshots for removed nodes."""
|
|
||||||
if not node_ids:
|
|
||||||
return
|
|
||||||
placeholders = ",".join("?" for _ in node_ids)
|
|
||||||
self.conn.execute(
|
|
||||||
f"DELETE FROM history_snapshots WHERE data_file_id = ? AND node_id IN ({placeholders})",
|
|
||||||
(data_file_id, *node_ids),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Import
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
|
||||||
"""Import a JSON file into the database, splitting batch_data into sequences.
|
|
||||||
|
|
||||||
Safe to call repeatedly — existing data_file is updated, sequences are
|
|
||||||
replaced, and history_tree is upserted. Atomic: all-or-nothing.
|
|
||||||
"""
|
|
||||||
json_path = Path(json_path)
|
|
||||||
data, _ = load_json(json_path)
|
|
||||||
file_name = json_path.stem
|
|
||||||
|
|
||||||
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
|
||||||
|
|
||||||
self.conn.execute("BEGIN IMMEDIATE")
|
|
||||||
try:
|
|
||||||
existing = self.conn.execute(
|
|
||||||
"SELECT id FROM data_files WHERE project_id = ? AND name = ?",
|
|
||||||
(project_id, file_name),
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
df_id = existing["id"]
|
|
||||||
now = time.time()
|
|
||||||
self.conn.execute(
|
|
||||||
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
|
|
||||||
(data_type, json.dumps(top_level), now, df_id),
|
|
||||||
)
|
|
||||||
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
|
||||||
else:
|
|
||||||
now = time.time()
|
|
||||||
cur = self.conn.execute(
|
|
||||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(project_id, file_name, data_type, json.dumps(top_level), now, now),
|
|
||||||
)
|
|
||||||
df_id = cur.lastrowid
|
|
||||||
|
|
||||||
# Import sequences from batch_data
|
|
||||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
|
||||||
if isinstance(batch_data, list):
|
|
||||||
for item in batch_data:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
seq_num = int(item.get("sequence_number", 0))
|
|
||||||
now = time.time()
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, seq_num, json.dumps(item), now),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import history tree (extract snapshots into separate table)
|
|
||||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
|
||||||
history_tree = data.get(KEY_HISTORY_TREE)
|
|
||||||
if history_tree and isinstance(history_tree, dict):
|
|
||||||
now = time.time()
|
|
||||||
if "snapshots" in history_tree:
|
|
||||||
entries = history_tree.get("snapshots", {})
|
|
||||||
entry_key = "snapshots"
|
|
||||||
else:
|
|
||||||
entries = history_tree.get("nodes", {})
|
|
||||||
entry_key = "nodes"
|
|
||||||
slim_tree = dict(history_tree)
|
|
||||||
slim_entries = {}
|
|
||||||
for eid, entry in entries.items():
|
|
||||||
snap = entry.get("data")
|
|
||||||
if snap:
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
|
||||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, eid, json.dumps(snap), now),
|
|
||||||
)
|
|
||||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
|
||||||
slim_tree[entry_key] = slim_entries
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, json.dumps(slim_tree), now),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conn.execute("COMMIT")
|
|
||||||
return df_id
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
self.conn.execute("ROLLBACK")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Full data reconstruction (replaces load_json for DB-backed files)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def load_full_data(self, project_name: str, file_name: str) -> dict | None:
|
|
||||||
"""Reconstruct the full data dict from DB, matching load_json format.
|
|
||||||
|
|
||||||
Returns None if the project or file doesn't exist in the DB.
|
|
||||||
Result has the same structure as a JSON file: top-level keys +
|
|
||||||
batch_data list + history_tree dict.
|
|
||||||
"""
|
|
||||||
t0 = time.time()
|
|
||||||
df = self.get_data_file_by_names(project_name, file_name)
|
|
||||||
if not df:
|
|
||||||
return None
|
|
||||||
t1 = time.time()
|
|
||||||
|
|
||||||
# Start with top-level keys
|
|
||||||
data = df.get("top_level", {})
|
|
||||||
if isinstance(data, str):
|
|
||||||
data = json.loads(data)
|
|
||||||
|
|
||||||
# Load all sequences as batch_data
|
|
||||||
# Group sub-segments (>=1000) after their parent: parent = seq_num / 1000
|
|
||||||
rows = self.conn.execute(
|
|
||||||
"SELECT data FROM sequences WHERE data_file_id = ? "
|
|
||||||
"ORDER BY CASE WHEN sequence_number >= 1000 THEN sequence_number / 1000 "
|
|
||||||
"ELSE sequence_number END, "
|
|
||||||
"CASE WHEN sequence_number >= 1000 THEN 1 ELSE 0 END, "
|
|
||||||
"sequence_number",
|
|
||||||
(df["id"],),
|
|
||||||
).fetchall()
|
|
||||||
batch_data = []
|
|
||||||
for row in rows:
|
|
||||||
seq = json.loads(row["data"])
|
|
||||||
self._migrate_lora_keys(seq)
|
|
||||||
batch_data.append(seq)
|
|
||||||
data["batch_data"] = batch_data
|
|
||||||
t2 = time.time()
|
|
||||||
|
|
||||||
# Load history tree (metadata only, no snapshot data)
|
|
||||||
tree = self.get_history_tree(df["id"])
|
|
||||||
if tree:
|
|
||||||
# Strip any residual snapshot data (supports both formats)
|
|
||||||
for entry in tree.get("snapshots", tree.get("nodes", {})).values():
|
|
||||||
entry.pop("data", None)
|
|
||||||
data["history_tree"] = tree
|
|
||||||
t3 = time.time()
|
|
||||||
|
|
||||||
logger.info("load_full_data %s/%s (%d seqs): lookup=%.3fs seqs=%.3fs tree=%.3fs total=%.3fs",
|
|
||||||
project_name, file_name, len(batch_data),
|
|
||||||
t1 - t0, t2 - t1, t3 - t2, t3 - t0)
|
|
||||||
return data
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Query helpers (for REST API)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
|
|
||||||
"""Query a single sequence by project name, file name, and sequence number."""
|
|
||||||
df = self.get_data_file_by_names(project_name, file_name)
|
|
||||||
if not df:
|
|
||||||
return None
|
|
||||||
return self.get_sequence(df["id"], sequence_number)
|
|
||||||
|
|
||||||
def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
|
|
||||||
"""Query keys and types for a sequence."""
|
|
||||||
df = self.get_data_file_by_names(project_name, file_name)
|
|
||||||
if not df:
|
|
||||||
return [], []
|
|
||||||
return self.get_sequence_keys(df["id"], sequence_number)
|
|
||||||
|
|
||||||
def list_project_files(self, project_name: str) -> list[dict]:
|
|
||||||
"""List data files for a project by name."""
|
|
||||||
proj = self.get_project(project_name)
|
|
||||||
if not proj:
|
|
||||||
return []
|
|
||||||
return self.list_data_files(proj["id"])
|
|
||||||
|
|
||||||
def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
|
|
||||||
"""List sequence numbers for a file in a project."""
|
|
||||||
df = self.get_data_file_by_names(project_name, file_name)
|
|
||||||
if not df:
|
|
||||||
return []
|
|
||||||
return self.list_sequences(df["id"])
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
# Resolution Series Design
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
|
|
||||||
When running ComfyUI loop nodes for multi-step upscaling (e.g. 3+ resolutions at different sizes),
|
|
||||||
managing portrait vs landscape width/height per iteration is tedious. Users need a structured way
|
|
||||||
to define N resolution pairs in the manager UI and retrieve them by loop index in ComfyUI.
|
|
||||||
|
|
||||||
## Design
|
|
||||||
|
|
||||||
### Data Model
|
|
||||||
|
|
||||||
Resolution series are stored as a JSON array under a user-chosen key in the sequence data:
|
|
||||||
|
|
||||||
```json
|
|
||||||
"upscale_resolutions": [[512, 512], [768, 1344], [1344, 768], [2048, 2048]]
|
|
||||||
```
|
|
||||||
|
|
||||||
- Each element is `[width, height]` (both INT)
|
|
||||||
- Key name is chosen by the user (any string)
|
|
||||||
- Number of entries is configurable (add/remove rows)
|
|
||||||
- Stored in the same project JSON file and sequence — no schema change required
|
|
||||||
- Index out of bounds → clamp to last entry
|
|
||||||
|
|
||||||
### NiceGUI UI (tab_batch_ng.py)
|
|
||||||
|
|
||||||
A resolution series editor is rendered in the left column of the sequence card, directly below
|
|
||||||
the "Specific Negative" textarea.
|
|
||||||
|
|
||||||
Layout:
|
|
||||||
|
|
||||||
```
|
|
||||||
── Resolution Series ──────────────────
|
|
||||||
key name: [upscale_resolutions ]
|
|
||||||
# Width Height
|
|
||||||
1 [2048] [2048] [x]
|
|
||||||
2 [768 ] [1344] [x]
|
|
||||||
3 [1344] [768 ] [x]
|
|
||||||
[+ Add row]
|
|
||||||
```
|
|
||||||
|
|
||||||
- Key name is editable (defaults to `resolutions`)
|
|
||||||
- Rows added/removed inline; each change calls `commit()` immediately
|
|
||||||
- Hidden behind an "Add Resolution Series" button when no resolution key exists yet
|
|
||||||
- A value is detected as a resolution series if it is a list of `[int, int]` pairs
|
|
||||||
|
|
||||||
### ComfyUI Node (`ProjectResolution`)
|
|
||||||
|
|
||||||
New node class in `project_loader.py`, sibling to `ProjectKey`.
|
|
||||||
|
|
||||||
**Inputs:**
|
|
||||||
- `source_label` (STRING) — references a `ProjectSource` by label
|
|
||||||
- `key_name` (STRING) — the resolution series key name
|
|
||||||
- `index` (INT, min 0) — wired from loop node's current index output
|
|
||||||
- `manager_url`, `project_name`, `file_name`, `sequence_number` — optional, synced from `ProjectSource` via JS
|
|
||||||
|
|
||||||
**Outputs:** `width` (INT), `height` (INT)
|
|
||||||
|
|
||||||
**Execution:** fetches the sequence data, reads `data[key_name]`, indexes into the array with
|
|
||||||
clamp-to-last on out-of-bounds, returns `(width, height)`.
|
|
||||||
|
|
||||||
**JS (`web/project_resolution.js`):**
|
|
||||||
- Same `_syncFromSource` mechanism as `project_key.js`
|
|
||||||
- `key_name` widget is replaced with a combo dropdown populated with keys whose value is a
|
|
||||||
resolution series (list of `[int, int]` pairs), detected via the existing keys API
|
|
||||||
- Registered in `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS`
|
|
||||||
|
|
||||||
### API
|
|
||||||
|
|
||||||
No new endpoints. Uses existing:
|
|
||||||
- `/json_manager/get_project_keys` — for key discovery (JS combo population)
|
|
||||||
- `_fetch_data()` — for execution-time data fetch
|
|
||||||
|
|
||||||
### Files Changed
|
|
||||||
|
|
||||||
| File | Change |
|
|
||||||
|------|--------|
|
|
||||||
| `project_loader.py` | Add `ProjectResolution` class + register in mappings |
|
|
||||||
| `web/project_resolution.js` | New JS extension for the node |
|
|
||||||
| `tab_batch_ng.py` | Resolution series editor below Specific Negative |
|
|
||||||
| `__init__.py` | Register new JS file if needed |
|
|
||||||
@@ -1,640 +0,0 @@
|
|||||||
# Resolution Series Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** Add a `ProjectResolution` ComfyUI node and NiceGUI editor that let users define N `(width, height)` pairs per sequence and retrieve them by loop index.
|
|
||||||
|
|
||||||
**Architecture:** Resolution series are stored as a JSON array of `[width, height]` pairs under a user-chosen key in sequence data (e.g. `"upscale_resolutions": [[512,512],[768,1344]]`). A new `ProjectResolution` ComfyUI node (sibling of `ProjectKey`) accepts a `source_label`, `key_name`, and `index` INT from a loop node, and returns `width` + `height`. The NiceGUI sequence card gets an inline table editor placed directly below the "Specific Negative" textarea.
|
|
||||||
|
|
||||||
**Tech Stack:** Python (ComfyUI node), NiceGUI (UI), JavaScript (ComfyUI frontend extension), pytest
|
|
||||||
|
|
||||||
**Branch:** Create and work on `feat/resolution-series` branched from `main`:
|
|
||||||
```bash
|
|
||||||
git checkout main && git checkout -b feat/resolution-series
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 0: Fix pre-existing test failures on `main`
|
|
||||||
|
|
||||||
When `file_name` was added as a second output to `ProjectSource`, two tests were not updated.
|
|
||||||
They fail on `main` before any new code is written.
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `tests/test_project_loader.py` (`TestProjectSource` class, lines ~216-231)
|
|
||||||
|
|
||||||
**Step 1: Update the two broken tests**
|
|
||||||
|
|
||||||
```python
|
|
||||||
def test_outputs_sequence_number(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
assert ProjectSource.RETURN_TYPES == ("INT", "STRING",)
|
|
||||||
assert ProjectSource.RETURN_NAMES == ("sequence_number", "file_name",)
|
|
||||||
|
|
||||||
def test_hold_config_returns_sequence_number(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
node = ProjectSource()
|
|
||||||
result = node.hold_config(
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=42,
|
|
||||||
label="my_source"
|
|
||||||
)
|
|
||||||
assert result == (42, "batch_i2v")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Verify they now pass**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/test_project_loader.py::TestProjectSource -v
|
|
||||||
```
|
|
||||||
Expected: all 4 PASS
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add tests/test_project_loader.py
|
|
||||||
git commit -m "fix: update ProjectSource tests for file_name output"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: Python node — `ProjectResolution`
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `project_loader.py` (after the `ProjectKey` class, before `# --- Mappings ---`)
|
|
||||||
- Modify: `tests/test_project_loader.py` (add `TestProjectResolution` class)
|
|
||||||
|
|
||||||
**Step 1: Write failing tests**
|
|
||||||
|
|
||||||
Add this class to `tests/test_project_loader.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TestProjectResolution:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
inputs = ProjectResolution.INPUT_TYPES()
|
|
||||||
assert "source_label" in inputs["required"]
|
|
||||||
assert "key_name" in inputs["required"]
|
|
||||||
assert "index" in inputs["required"]
|
|
||||||
assert inputs["required"]["index"][0] == "INT"
|
|
||||||
|
|
||||||
def test_two_outputs(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
assert ProjectResolution.RETURN_TYPES == ("INT", "INT")
|
|
||||||
assert ProjectResolution.RETURN_NAMES == ("width", "height")
|
|
||||||
|
|
||||||
def test_fetch_resolution_basic(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512], [768, 1344], [1344, 768]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=1,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (768, 1344)
|
|
||||||
|
|
||||||
def test_fetch_resolution_index_zero(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512], [1024, 1024]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512)
|
|
||||||
|
|
||||||
def test_fetch_resolution_clamps_on_out_of_bounds(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512], [1024, 1024]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=99,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (1024, 1024) # last entry
|
|
||||||
|
|
||||||
def test_fetch_resolution_missing_key_returns_defaults(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="nonexistent", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512)
|
|
||||||
|
|
||||||
def test_fetch_resolution_network_error_returns_defaults(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512)
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
assert ProjectResolution.CATEGORY == "utils/json/project"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Run tests to verify they fail**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/test_project_loader.py::TestProjectResolution -v
|
|
||||||
```
|
|
||||||
Expected: `ImportError: cannot import name 'ProjectResolution'`
|
|
||||||
|
|
||||||
**Step 3: Implement `ProjectResolution` in `project_loader.py`**
|
|
||||||
|
|
||||||
Insert this class after `ProjectKey` (line ~294), before `# --- Mappings ---`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ProjectResolution:
|
|
||||||
"""Fetches a (width, height) pair from a resolution series by loop index."""
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"source_label": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_name": ("STRING", {"default": "resolutions", "multiline": False}),
|
|
||||||
"index": ("INT", {"default": 0, "min": 0, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("INT", "INT")
|
|
||||||
RETURN_NAMES = ("width", "height")
|
|
||||||
FUNCTION = "fetch_resolution"
|
|
||||||
CATEGORY = "utils/json/project"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def IS_CHANGED(cls, **kwargs):
|
|
||||||
return float("nan")
|
|
||||||
|
|
||||||
def fetch_resolution(self, source_label, key_name, index,
|
|
||||||
manager_url="http://localhost:8080", project_name="",
|
|
||||||
file_name="", sequence_number=1):
|
|
||||||
sequence_number = int(sequence_number)
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message"))
|
|
||||||
return (512, 512)
|
|
||||||
|
|
||||||
series = data.get(key_name)
|
|
||||||
if not isinstance(series, list) or len(series) == 0:
|
|
||||||
logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name)
|
|
||||||
return (512, 512)
|
|
||||||
|
|
||||||
clamped = min(index, len(series) - 1)
|
|
||||||
entry = series[clamped]
|
|
||||||
if not isinstance(entry, (list, tuple)) or len(entry) < 2:
|
|
||||||
logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry)
|
|
||||||
return (512, 512)
|
|
||||||
|
|
||||||
return (to_int(entry[0]), to_int(entry[1]))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: Run tests to verify they pass**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/test_project_loader.py::TestProjectResolution -v
|
|
||||||
```
|
|
||||||
Expected: all 7 tests PASS
|
|
||||||
|
|
||||||
**Step 5: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add project_loader.py tests/test_project_loader.py
|
|
||||||
git commit -m "feat: add ProjectResolution node"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 2: Register `ProjectResolution` in mappings + fix mapping tests
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `project_loader.py` (mappings section, lines ~297-307)
|
|
||||||
- Modify: `tests/test_project_loader.py` (`TestNodeMappings` class)
|
|
||||||
|
|
||||||
**Step 1: Update mappings in `project_loader.py`**
|
|
||||||
|
|
||||||
Change the mappings at the bottom of the file:
|
|
||||||
|
|
||||||
```python
|
|
||||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
|
||||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
|
||||||
"ProjectSource": ProjectSource,
|
|
||||||
"ProjectKey": ProjectKey,
|
|
||||||
"ProjectResolution": ProjectResolution,
|
|
||||||
}
|
|
||||||
|
|
||||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
|
||||||
"ProjectSource": "Project Source",
|
|
||||||
"ProjectKey": "Project Key",
|
|
||||||
"ProjectResolution": "Project Resolution",
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Update the mapping test**
|
|
||||||
|
|
||||||
In `tests/test_project_loader.py`, update `TestNodeMappings.test_mappings_exist`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TestNodeMappings:
|
|
||||||
def test_mappings_exist(self):
|
|
||||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
|
||||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 4
|
|
||||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Run all project_loader tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/test_project_loader.py -v
|
|
||||||
```
|
|
||||||
Expected: all tests PASS
|
|
||||||
|
|
||||||
**Step 4: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add project_loader.py tests/test_project_loader.py
|
|
||||||
git commit -m "feat: register ProjectResolution in node mappings"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 3: NiceGUI resolution series editor in `tab_batch_ng.py`
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `tab_batch_ng.py`
|
|
||||||
|
|
||||||
The resolution series editor goes inside `splitter.before`, directly after the "Specific Negative" textarea (currently line ~552-553). No new file needed.
|
|
||||||
|
|
||||||
**Step 1: Add the helper function**
|
|
||||||
|
|
||||||
Add this function near the other helper functions at the top of the render section (before `_render_sequence_card`):
|
|
||||||
|
|
||||||
```python
|
|
||||||
def _is_resolution_series(val) -> bool:
|
|
||||||
"""Return True if val is a list of [width, height] int pairs."""
|
|
||||||
if not isinstance(val, list) or len(val) == 0:
|
|
||||||
return False
|
|
||||||
return all(
|
|
||||||
isinstance(entry, (list, tuple)) and len(entry) == 2
|
|
||||||
and all(isinstance(v, (int, float)) for v in entry)
|
|
||||||
for entry in val
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Note: `Any` is intentionally omitted — `tab_batch_ng.py` does not import `typing.Any`.
|
|
||||||
|
|
||||||
**Step 2: Add the resolution series render section**
|
|
||||||
|
|
||||||
After the "Specific Negative" textarea in `splitter.before` (after line ~553), add:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# --- Resolution Series ---
|
|
||||||
res_keys = [k for k, v in seq.items() if _is_resolution_series(v)]
|
|
||||||
if res_keys:
|
|
||||||
ui.label('Resolution Series').classes('text-caption text-weight-bold q-mt-md')
|
|
||||||
for res_key in res_keys:
|
|
||||||
series: list = seq[res_key]
|
|
||||||
with ui.card().classes('w-full q-pa-sm q-mt-xs').props('flat bordered'):
|
|
||||||
with ui.row().classes('items-center q-mb-xs'):
|
|
||||||
ui.label(res_key).classes('text-caption col')
|
|
||||||
def del_series(k=res_key):
|
|
||||||
del seq[k]
|
|
||||||
commit()
|
|
||||||
ui.button(icon='delete', on_click=del_series).props(
|
|
||||||
'flat dense round size=xs color=negative')
|
|
||||||
with ui.row().classes('text-caption text-grey q-mb-xs'):
|
|
||||||
ui.label('#').style('width:24px')
|
|
||||||
ui.label('Width').classes('col')
|
|
||||||
ui.label('Height').classes('col')
|
|
||||||
ui.label('').style('width:28px')
|
|
||||||
for idx, entry in enumerate(series):
|
|
||||||
with ui.row().classes('items-center w-full'):
|
|
||||||
ui.label(str(idx + 1)).classes('text-caption').style('width:24px')
|
|
||||||
w_inp = ui.number(value=int(entry[0]), min=1, step=1).classes(
|
|
||||||
'col').props('outlined dense hide-bottom-space')
|
|
||||||
h_inp = ui.number(value=int(entry[1]), min=1, step=1).classes(
|
|
||||||
'col').props('outlined dense hide-bottom-space')
|
|
||||||
|
|
||||||
def _sync_wh(i=idx, k=res_key, wi=w_inp, hi=h_inp):
|
|
||||||
seq[k][i] = [
|
|
||||||
int(wi.value) if wi.value else 512,
|
|
||||||
int(hi.value) if hi.value else 512,
|
|
||||||
]
|
|
||||||
commit()
|
|
||||||
|
|
||||||
w_inp.on('blur', lambda _, s=_sync_wh: s())
|
|
||||||
h_inp.on('blur', lambda _, s=_sync_wh: s())
|
|
||||||
|
|
||||||
def del_row(i=idx, k=res_key):
|
|
||||||
seq[k].pop(i)
|
|
||||||
commit()
|
|
||||||
ui.button(icon='remove', on_click=del_row).props(
|
|
||||||
'flat dense round size=xs')
|
|
||||||
|
|
||||||
def add_row(k=res_key):
|
|
||||||
seq[k].append([512, 512])
|
|
||||||
commit()
|
|
||||||
ui.button('+ Add row', icon='add', on_click=add_row).props(
|
|
||||||
'flat dense size=sm').classes('q-mt-xs')
|
|
||||||
|
|
||||||
with ui.expansion('Add Resolution Series', icon='straighten').classes('w-full q-mt-sm'):
|
|
||||||
new_res_key = ui.input('Key name', value='resolutions').props('outlined dense')
|
|
||||||
def add_res_series():
|
|
||||||
k = new_res_key.value.strip()
|
|
||||||
if k and k not in seq:
|
|
||||||
seq[k] = [[512, 512], [1024, 1024]]
|
|
||||||
commit()
|
|
||||||
ui.button('Add', icon='add', on_click=add_res_series).props('outlined dense')
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Run all tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/ -q
|
|
||||||
```
|
|
||||||
Expected: all tests PASS (no Python tests cover the NiceGUI render path, but no regressions)
|
|
||||||
|
|
||||||
**Important:** Also update the `custom_keys` filter in `_render_sequence_card` (line ~648) to exclude
|
|
||||||
resolution series keys — otherwise they'd render in both the resolution editor AND "Custom Parameters":
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Find this line:
|
|
||||||
custom_keys = [k for k in seq.keys() if k not in standard_keys]
|
|
||||||
# Replace with:
|
|
||||||
custom_keys = [k for k in seq.keys() if k not in standard_keys and not _is_resolution_series(seq.get(k))]
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add tab_batch_ng.py
|
|
||||||
git commit -m "feat: resolution series editor in sequence card"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 4: JS extension `web/project_resolution.js`
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `web/project_resolution.js`
|
|
||||||
|
|
||||||
This file mirrors `web/project_key.js` exactly, with two differences:
|
|
||||||
1. It targets `"ProjectResolution"` instead of `"ProjectKey"`
|
|
||||||
2. `_refreshKeys` filters to only show keys whose value is a resolution series (list of `[int, int]` pairs) — but since the keys API only returns key names (not values), the filter is done by naming convention or we just show all keys and let the user pick. For simplicity, show all keys (same as ProjectKey) and let the user pick.
|
|
||||||
3. The `index` widget is **not** hidden — the user wires it from a loop node
|
|
||||||
4. The node has two outputs (`width`, `height`) so no output slot name update is needed
|
|
||||||
|
|
||||||
**Step 1: Create `web/project_resolution.js`**
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.project.resolution",
|
|
||||||
|
|
||||||
async beforeQueuePrompt() {
|
|
||||||
if (!app.graph?._nodes) return;
|
|
||||||
for (const node of app.graph._nodes) {
|
|
||||||
if (node.type === "ProjectResolution" && node._syncFromSource) {
|
|
||||||
node._syncFromSource();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "ProjectResolution") return;
|
|
||||||
|
|
||||||
function hideWidget(widget) {
|
|
||||||
if (widget.origType === undefined) widget.origType = widget.type;
|
|
||||||
widget.type = "hidden";
|
|
||||||
widget.hidden = true;
|
|
||||||
widget.computeSize = () => [0, -4];
|
|
||||||
}
|
|
||||||
|
|
||||||
function replaceWithCombo(node, name, values, callback) {
|
|
||||||
const idx = node.widgets?.findIndex(w => w.name === name);
|
|
||||||
if (idx === -1 || idx === undefined) return null;
|
|
||||||
const oldWidget = node.widgets[idx];
|
|
||||||
const savedValue = oldWidget.value || "";
|
|
||||||
const comboValues = values.length > 0 ? values : [""];
|
|
||||||
if (savedValue && !comboValues.includes(savedValue)) {
|
|
||||||
comboValues.unshift(savedValue);
|
|
||||||
}
|
|
||||||
const defaultValue = savedValue || comboValues[0];
|
|
||||||
node.widgets.splice(idx, 1);
|
|
||||||
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
|
||||||
if (node.widgets.length > 1) {
|
|
||||||
node.widgets.splice(node.widgets.length - 1, 1);
|
|
||||||
node.widgets.splice(idx, 0, combo);
|
|
||||||
}
|
|
||||||
return combo;
|
|
||||||
}
|
|
||||||
|
|
||||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
origOnNodeCreated?.apply(this, arguments);
|
|
||||||
this._configured = false;
|
|
||||||
|
|
||||||
// Hide synced config widgets (index stays visible — user wires it)
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
const node = this;
|
|
||||||
const sourceLabels = this._getSourceLabels?.() || [];
|
|
||||||
const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
if (srcCombo) srcCombo.value = sourceLabels[0] || "";
|
|
||||||
|
|
||||||
const keyCombo = replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
});
|
|
||||||
if (keyCombo) keyCombo.value = "";
|
|
||||||
|
|
||||||
queueMicrotask(() => {
|
|
||||||
if (!this._configured) {
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._getSourceLabels = function () {
|
|
||||||
const seen = new Set();
|
|
||||||
const labels = [];
|
|
||||||
if (!this.graph) return labels;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value && !seen.has(lw.value)) {
|
|
||||||
seen.add(lw.value);
|
|
||||||
labels.push(lw.value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return labels;
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._findSource = function (label) {
|
|
||||||
if (!this.graph || !label) return null;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value === label) return node;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._syncFromSource = function () {
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
const source = this._findSource(srcWidget?.value);
|
|
||||||
if (!source) return;
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const dst = this.widgets?.find(w => w.name === name);
|
|
||||||
const src = source.widgets?.find(w => w.name === name);
|
|
||||||
if (dst && src) dst.value = src.value;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._refreshKeys = async function () {
|
|
||||||
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
|
||||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
const fileW = this.widgets?.find(w => w.name === "file_name");
|
|
||||||
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
|
||||||
if (!urlW?.value || !projW?.value || !fileW?.value) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const resp = await api.fetchApi(
|
|
||||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}`
|
|
||||||
);
|
|
||||||
if (!resp.ok) return;
|
|
||||||
const data = await resp.json();
|
|
||||||
if (data.error || !Array.isArray(data.keys)) return;
|
|
||||||
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget) {
|
|
||||||
keyWidget.options.values = data.keys.length > 0 ? data.keys : [""];
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error("[ProjectResolution] Failed to refresh keys:", e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
|
||||||
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
|
||||||
origOnMouseDown?.apply(this, arguments);
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget) srcWidget.options.values = this._getSourceLabels();
|
|
||||||
this._syncFromSource();
|
|
||||||
};
|
|
||||||
|
|
||||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
|
||||||
nodeType.prototype.onConfigure = function (info) {
|
|
||||||
origOnConfigure?.apply(this, arguments);
|
|
||||||
this._configured = true;
|
|
||||||
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget && srcWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
} else if (srcWidget) {
|
|
||||||
srcWidget.options.values = this._getSourceLabels();
|
|
||||||
}
|
|
||||||
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget && keyWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const finalKeyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (finalKeyWidget?.value) {
|
|
||||||
this.title = `Resolution: ${finalKeyWidget.value}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
|
|
||||||
const node = this;
|
|
||||||
queueMicrotask(() => {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Run all tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/ -q
|
|
||||||
```
|
|
||||||
Expected: all tests PASS (JS has no Python tests)
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add web/project_resolution.js
|
|
||||||
git commit -m "feat: ProjectResolution JS extension for ComfyUI frontend"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 5: Final verification and push
|
|
||||||
|
|
||||||
**Step 1: Run full test suite**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/ -v
|
|
||||||
```
|
|
||||||
Expected: all tests PASS
|
|
||||||
|
|
||||||
**Step 2: Push branch**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git push origin HEAD
|
|
||||||
```
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
# BinaryIndexDecoder Node — Design
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
A standalone ComfyUI utility node that converts an integer index into 3 boolean
|
|
||||||
outputs using binary (bit-field) encoding. Intended for use with loop counters to
|
|
||||||
gate multiple processing branches simultaneously.
|
|
||||||
|
|
||||||
## Node Spec
|
|
||||||
|
|
||||||
| Field | Value |
|
|
||||||
|---|---|
|
|
||||||
| Class name | `BinaryIndexDecoder` |
|
|
||||||
| Display name | `Binary Index Decoder` |
|
|
||||||
| Category | `JSON Manager/utils` |
|
|
||||||
| Function | `decode` |
|
|
||||||
|
|
||||||
### Inputs
|
|
||||||
|
|
||||||
| Name | Type | Default | Range |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `index` | INT | 0 | 0–7 |
|
|
||||||
|
|
||||||
### Outputs
|
|
||||||
|
|
||||||
| Name | Type |
|
|
||||||
|---|---|
|
|
||||||
| `flag_0` | BOOLEAN |
|
|
||||||
| `flag_1` | BOOLEAN |
|
|
||||||
| `flag_2` | BOOLEAN |
|
|
||||||
|
|
||||||
### Logic
|
|
||||||
|
|
||||||
```
|
|
||||||
flag_0 = bool((index >> 0) & 1)
|
|
||||||
flag_1 = bool((index >> 1) & 1)
|
|
||||||
flag_2 = bool((index >> 2) & 1)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Truth table
|
|
||||||
|
|
||||||
| index | flag_0 | flag_1 | flag_2 |
|
|
||||||
|---|---|---|---|
|
|
||||||
| 0 | F | F | F |
|
|
||||||
| 1 | T | F | F |
|
|
||||||
| 2 | F | T | F |
|
|
||||||
| 3 | T | T | F |
|
|
||||||
| 4 | F | F | T |
|
|
||||||
| 5 | T | F | T |
|
|
||||||
| 6 | F | T | T |
|
|
||||||
| 7 | T | T | T |
|
|
||||||
|
|
||||||
## Implementation Notes
|
|
||||||
|
|
||||||
- Lives in `project_loader.py` alongside other project nodes
|
|
||||||
- Added to `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS`
|
|
||||||
- No JavaScript extension needed (no source sync, no dynamic widgets)
|
|
||||||
- No NiceGUI UI changes needed
|
|
||||||
- `IS_CHANGED` not needed (output is deterministic from input)
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
9 tests in `tests/test_project_loader.py::TestBinaryIndexDecoder`:
|
|
||||||
- Input types include `index` as INT
|
|
||||||
- All 8 index values (0–7) produce correct boolean tuple
|
|
||||||
- Out-of-range index (e.g. 8) clamps to 0–7 or wraps gracefully
|
|
||||||
- `NodeMappings` test updated: 5 nodes, mappings length == 5
|
|
||||||
@@ -1,166 +0,0 @@
|
|||||||
# BinaryIndexDecoder Node — Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** Add a standalone ComfyUI node `BinaryIndexDecoder` that converts an integer index to 3 boolean outputs using binary (bit-field) encoding.
|
|
||||||
|
|
||||||
**Architecture:** Single class in `project_loader.py`, no JS extension needed, no NiceGUI changes. Takes `index` INT, returns `(flag_0, flag_1, flag_2)` as BOOLEAN using bit-shift logic. Added to existing node mappings.
|
|
||||||
|
|
||||||
**Tech Stack:** Python, ComfyUI node API, pytest
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: Write failing tests for BinaryIndexDecoder
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `tests/test_project_loader.py` (append new test class at end, before `TestNodeMappings`)
|
|
||||||
- Modify: `tests/test_project_loader.py` — update `TestNodeMappings.test_mappings_exist` to expect 5 nodes
|
|
||||||
|
|
||||||
**Step 1: Add the test class**
|
|
||||||
|
|
||||||
Append this class before `TestNodeMappings` in `tests/test_project_loader.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TestBinaryIndexDecoder:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
inputs = BinaryIndexDecoder.INPUT_TYPES()
|
|
||||||
assert "index" in inputs["required"]
|
|
||||||
assert inputs["required"]["index"][0] == "INT"
|
|
||||||
|
|
||||||
def test_three_boolean_outputs(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder.RETURN_TYPES == ("BOOLEAN", "BOOLEAN", "BOOLEAN")
|
|
||||||
assert BinaryIndexDecoder.RETURN_NAMES == ("flag_0", "flag_1", "flag_2")
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder.CATEGORY == "JSON Manager/utils"
|
|
||||||
|
|
||||||
def test_index_0(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(0) == (False, False, False)
|
|
||||||
|
|
||||||
def test_index_1(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(1) == (True, False, False)
|
|
||||||
|
|
||||||
def test_index_2(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(2) == (False, True, False)
|
|
||||||
|
|
||||||
def test_index_3(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(3) == (True, True, False)
|
|
||||||
|
|
||||||
def test_index_4(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(4) == (False, False, True)
|
|
||||||
|
|
||||||
def test_index_7(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(7) == (True, True, True)
|
|
||||||
```
|
|
||||||
|
|
||||||
Also update `TestNodeMappings.test_mappings_exist`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def test_mappings_exist(self):
|
|
||||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
|
||||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "BinaryIndexDecoder" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 5
|
|
||||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 5
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Run tests to verify they fail**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m pytest tests/test_project_loader.py::TestBinaryIndexDecoder -v
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: FAIL with `ImportError: cannot import name 'BinaryIndexDecoder'`
|
|
||||||
|
|
||||||
**Step 3: Commit the failing tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add tests/test_project_loader.py
|
|
||||||
git commit -m "test: add failing tests for BinaryIndexDecoder node"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 2: Implement BinaryIndexDecoder
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `project_loader.py` — add class after `ProjectResolution`, update mappings
|
|
||||||
|
|
||||||
**Step 1: Add the class**
|
|
||||||
|
|
||||||
Insert after the `ProjectResolution` class (before `# --- Mappings ---`) in `project_loader.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class BinaryIndexDecoder:
|
|
||||||
"""Decodes an integer index into 3 boolean flags using binary (bit-field) encoding.
|
|
||||||
|
|
||||||
index 0 → (False, False, False)
|
|
||||||
index 1 → (True, False, False) # bit 0
|
|
||||||
index 2 → (False, True, False) # bit 1
|
|
||||||
index 3 → (True, True, False) # bits 0+1
|
|
||||||
index 4 → (False, False, True) # bit 2
|
|
||||||
...
|
|
||||||
index 7 → (True, True, True)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"index": ("INT", {"default": 0, "min": 0, "max": 7}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("BOOLEAN", "BOOLEAN", "BOOLEAN")
|
|
||||||
RETURN_NAMES = ("flag_0", "flag_1", "flag_2")
|
|
||||||
FUNCTION = "decode"
|
|
||||||
CATEGORY = "JSON Manager/utils"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
def decode(self, index: int):
|
|
||||||
return (
|
|
||||||
bool((index >> 0) & 1),
|
|
||||||
bool((index >> 1) & 1),
|
|
||||||
bool((index >> 2) & 1),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Update mappings**
|
|
||||||
|
|
||||||
In `PROJECT_NODE_CLASS_MAPPINGS`, add:
|
|
||||||
```python
|
|
||||||
"BinaryIndexDecoder": BinaryIndexDecoder,
|
|
||||||
```
|
|
||||||
|
|
||||||
In `PROJECT_NODE_DISPLAY_NAME_MAPPINGS`, add:
|
|
||||||
```python
|
|
||||||
"BinaryIndexDecoder": "Binary Index Decoder",
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Run all tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m pytest tests/test_project_loader.py -v
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: all tests PASS (42 existing + 10 new = 52 total)
|
|
||||||
|
|
||||||
**Step 4: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add project_loader.py tests/test_project_loader.py
|
|
||||||
git commit -m "feat: add BinaryIndexDecoder node (INT index → 3 BOOLEANs, binary encoding)"
|
|
||||||
git push
|
|
||||||
```
|
|
||||||
+33
-147
@@ -1,27 +1,19 @@
|
|||||||
import html
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
KEY_PROMPT_HISTORY = "prompt_history"
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryTree:
|
class HistoryTree:
|
||||||
def __init__(self, raw_data: dict[str, Any]) -> None:
|
def __init__(self, raw_data):
|
||||||
self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {})
|
self.nodes = raw_data.get("nodes", {})
|
||||||
self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None})
|
self.branches = raw_data.get("branches", {"main": None})
|
||||||
self.head_id: str | None = raw_data.get("head_id", None)
|
self.head_id = raw_data.get("head_id", None)
|
||||||
|
|
||||||
if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes:
|
if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes:
|
||||||
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
self._migrate_legacy(raw_data["prompt_history"])
|
||||||
|
|
||||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
def _migrate_legacy(self, old_list):
|
||||||
parent = None
|
parent = None
|
||||||
for item in reversed(old_list):
|
for item in reversed(old_list):
|
||||||
for _ in range(10):
|
|
||||||
node_id = str(uuid.uuid4())[:8]
|
node_id = str(uuid.uuid4())[:8]
|
||||||
if node_id not in self.nodes:
|
|
||||||
break
|
|
||||||
self.nodes[node_id] = {
|
self.nodes[node_id] = {
|
||||||
"id": node_id, "parent": parent, "timestamp": time.time(),
|
"id": node_id, "parent": parent, "timestamp": time.time(),
|
||||||
"data": item, "note": item.get("note", "Legacy Import")
|
"data": item, "note": item.get("note", "Legacy Import")
|
||||||
@@ -30,25 +22,8 @@ class HistoryTree:
|
|||||||
self.branches["main"] = parent
|
self.branches["main"] = parent
|
||||||
self.head_id = parent
|
self.head_id = parent
|
||||||
|
|
||||||
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
|
def commit(self, data, note="Snapshot"):
|
||||||
# Generate unique node ID with collision check
|
|
||||||
for _ in range(10):
|
|
||||||
new_id = str(uuid.uuid4())[:8]
|
new_id = str(uuid.uuid4())[:8]
|
||||||
if new_id not in self.nodes:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise ValueError("Failed to generate unique node ID after 10 attempts")
|
|
||||||
|
|
||||||
# Cycle detection: walk parent chain from head to verify no cycle
|
|
||||||
if self.head_id:
|
|
||||||
visited = set()
|
|
||||||
current = self.head_id
|
|
||||||
while current:
|
|
||||||
if current in visited:
|
|
||||||
raise ValueError(f"Cycle detected in history tree at node {current}")
|
|
||||||
visited.add(current)
|
|
||||||
node = self.nodes.get(current)
|
|
||||||
current = node.get("parent") if node else None
|
|
||||||
|
|
||||||
active_branch = None
|
active_branch = None
|
||||||
for b_name, tip_id in self.branches.items():
|
for b_name, tip_id in self.branches.items():
|
||||||
@@ -70,159 +45,70 @@ class HistoryTree:
|
|||||||
self.head_id = new_id
|
self.head_id = new_id
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
def checkout(self, node_id: str) -> dict[str, Any] | None:
|
def checkout(self, node_id):
|
||||||
if node_id in self.nodes:
|
if node_id in self.nodes:
|
||||||
self.head_id = node_id
|
self.head_id = node_id
|
||||||
return self.nodes[node_id]["data"]
|
return self.nodes[node_id]["data"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def strip_snapshots(self) -> None:
|
def to_dict(self):
|
||||||
"""Remove snapshot data from all nodes to free memory."""
|
|
||||||
for node in self.nodes.values():
|
|
||||||
node.pop("data", None)
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
|
||||||
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
||||||
|
|
||||||
# --- UPDATED GRAPH GENERATOR ---
|
# --- UPDATED GRAPH GENERATOR ---
|
||||||
def generate_graph(self, direction: str = "LR") -> str:
|
def generate_graph(self, direction="LR"):
|
||||||
"""
|
"""
|
||||||
Generates Graphviz source.
|
Generates Graphviz source.
|
||||||
direction: "LR" (Horizontal) or "TB" (Vertical)
|
direction: "LR" (Horizontal) or "TB" (Vertical)
|
||||||
"""
|
"""
|
||||||
node_count = len(self.nodes)
|
|
||||||
is_vertical = direction == "TB"
|
|
||||||
|
|
||||||
# Vertical mode uses much tighter spacing
|
|
||||||
if is_vertical:
|
|
||||||
if node_count <= 5:
|
|
||||||
nodesep, ranksep = 0.3, 0.2
|
|
||||||
elif node_count <= 15:
|
|
||||||
nodesep, ranksep = 0.2, 0.15
|
|
||||||
else:
|
|
||||||
nodesep, ranksep = 0.1, 0.1
|
|
||||||
else:
|
|
||||||
if node_count <= 5:
|
|
||||||
nodesep, ranksep = 0.5, 0.6
|
|
||||||
elif node_count <= 15:
|
|
||||||
nodesep, ranksep = 0.3, 0.4
|
|
||||||
else:
|
|
||||||
nodesep, ranksep = 0.15, 0.25
|
|
||||||
|
|
||||||
# Build reverse lookup: branch tip -> branch name(s)
|
|
||||||
tip_to_branches: dict[str, list[str]] = {}
|
|
||||||
for b_name, tip_id in self.branches.items():
|
|
||||||
if tip_id:
|
|
||||||
tip_to_branches.setdefault(tip_id, []).append(b_name)
|
|
||||||
|
|
||||||
dot = [
|
dot = [
|
||||||
'digraph History {',
|
'digraph History {',
|
||||||
f' rankdir={direction};',
|
f' rankdir={direction};', # Dynamic Direction
|
||||||
' bgcolor="white";',
|
' bgcolor="white";',
|
||||||
' splines=polyline;',
|
' splines=ortho;',
|
||||||
f' nodesep={nodesep};',
|
|
||||||
f' ranksep={ranksep};',
|
# TIGHT SPACING
|
||||||
|
' nodesep=0.2;',
|
||||||
|
' ranksep=0.3;',
|
||||||
|
|
||||||
|
# GLOBAL STYLES
|
||||||
' node [shape=plain, fontname="Arial"];',
|
' node [shape=plain, fontname="Arial"];',
|
||||||
' edge [color="#888888", arrowsize=0.6, penwidth=1.0];'
|
' edge [color="#888888", arrowsize=0.6, penwidth=1.0];'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Build reverse lookup: node_id -> branch name (walk each branch ancestry)
|
|
||||||
node_to_branch: dict[str, str] = {}
|
|
||||||
for b_name, tip_id in self.branches.items():
|
|
||||||
visited = set()
|
|
||||||
current = tip_id
|
|
||||||
while current and current in self.nodes:
|
|
||||||
if current in visited:
|
|
||||||
break
|
|
||||||
if current in node_to_branch:
|
|
||||||
break # this node and all ancestors already assigned
|
|
||||||
visited.add(current)
|
|
||||||
node_to_branch[current] = b_name
|
|
||||||
current = self.nodes[current].get('parent')
|
|
||||||
|
|
||||||
# Per-branch color palette (bg, border) — cycles for many branches
|
|
||||||
_branch_palette = [
|
|
||||||
('#f9f9f9', '#999999'), # grey (default/main)
|
|
||||||
('#eef4ff', '#6699cc'), # blue
|
|
||||||
('#f5eeff', '#9977cc'), # purple
|
|
||||||
('#fff0ee', '#cc7766'), # coral
|
|
||||||
('#eefff5', '#66aa88'), # teal
|
|
||||||
('#fff8ee', '#ccaa55'), # sand
|
|
||||||
]
|
|
||||||
branch_names = list(self.branches.keys())
|
|
||||||
branch_colors = {
|
|
||||||
b: _branch_palette[i % len(_branch_palette)]
|
|
||||||
for i, b in enumerate(branch_names)
|
|
||||||
}
|
|
||||||
|
|
||||||
sorted_nodes = sorted(self.nodes.values(), key=lambda x: x["timestamp"])
|
sorted_nodes = sorted(self.nodes.values(), key=lambda x: x["timestamp"])
|
||||||
|
|
||||||
# Font sizes and padding - smaller for vertical
|
|
||||||
if is_vertical:
|
|
||||||
note_font_size = 8
|
|
||||||
meta_font_size = 7
|
|
||||||
cell_padding = 2
|
|
||||||
max_note_len = 18
|
|
||||||
else:
|
|
||||||
note_font_size = 10
|
|
||||||
meta_font_size = 8
|
|
||||||
cell_padding = 4
|
|
||||||
max_note_len = 25
|
|
||||||
|
|
||||||
for n in sorted_nodes:
|
for n in sorted_nodes:
|
||||||
nid = n["id"]
|
nid = n["id"]
|
||||||
full_note = n.get('note', 'Step')
|
full_note = n.get('note', 'Step')
|
||||||
|
|
||||||
display_note = (full_note[:max_note_len] + '..') if len(full_note) > max_note_len else full_note
|
display_note = (full_note[:15] + '..') if len(full_note) > 15 else full_note
|
||||||
display_note = html.escape(display_note)
|
|
||||||
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
# COLORS
|
||||||
|
bg_color = "#f9f9f9"
|
||||||
# Branch label for tip nodes
|
border_color = "#999999"
|
||||||
branch_label = ""
|
|
||||||
if nid in tip_to_branches:
|
|
||||||
branch_label = html.escape(", ".join(tip_to_branches[nid]))
|
|
||||||
|
|
||||||
# COLORS — per-branch tint, overridden for HEAD and tips
|
|
||||||
b_name = node_to_branch.get(nid)
|
|
||||||
bg_color, border_color = branch_colors.get(
|
|
||||||
b_name, _branch_palette[0])
|
|
||||||
border_width = "1"
|
border_width = "1"
|
||||||
|
|
||||||
if nid == self.head_id:
|
if nid == self.head_id:
|
||||||
bg_color = "#fff6cd"
|
bg_color = "#fff6cd" # Yellow for Current
|
||||||
border_color = "#eebb00"
|
border_color = "#eebb00"
|
||||||
border_width = "2"
|
border_width = "2"
|
||||||
elif nid in self.branches.values():
|
elif nid in self.branches.values():
|
||||||
bg_color = "#e6ffe6"
|
bg_color = "#e6ffe6" # Green for Tips
|
||||||
border_color = "#66aa66"
|
border_color = "#66aa66"
|
||||||
|
|
||||||
# HTML LABEL
|
# HTML LABEL
|
||||||
rows = [
|
|
||||||
f'<TR><TD><B><FONT POINT-SIZE="{note_font_size}">{display_note}</FONT></B></TD></TR>',
|
|
||||||
f'<TR><TD><FONT POINT-SIZE="{meta_font_size}" COLOR="#555555">{ts} • {nid[:4]}</FONT></TD></TR>',
|
|
||||||
]
|
|
||||||
if branch_label:
|
|
||||||
rows.append(f'<TR><TD><FONT POINT-SIZE="{meta_font_size}" COLOR="#4488cc"><I>{branch_label}</I></FONT></TD></TR>')
|
|
||||||
|
|
||||||
label = (
|
label = (
|
||||||
f'<<TABLE BORDER="{border_width}" CELLBORDER="0" CELLSPACING="0" CELLPADDING="{cell_padding}" BGCOLOR="{bg_color}" COLOR="{border_color}">'
|
f'<<TABLE BORDER="{border_width}" CELLBORDER="0" CELLSPACING="0" CELLPADDING="4" BGCOLOR="{bg_color}" COLOR="{border_color}">'
|
||||||
+ "".join(rows)
|
f'<TR><TD><B><FONT POINT-SIZE="10">{display_note}</FONT></B></TD></TR>'
|
||||||
+ '</TABLE>>'
|
f'<TR><TD><FONT POINT-SIZE="8" COLOR="#555555">{nid[:4]}</FONT></TD></TR>'
|
||||||
|
f'</TABLE>>'
|
||||||
)
|
)
|
||||||
|
|
||||||
safe_tooltip = (full_note
|
safe_tooltip = full_note.replace('"', "'")
|
||||||
.replace('\\', '\\\\')
|
dot.append(f' "{nid}" [label={label}, tooltip="{safe_tooltip}"];')
|
||||||
.replace('"', '\\"')
|
|
||||||
.replace('\n', ' ')
|
|
||||||
.replace('\r', '')
|
|
||||||
.replace(']', ']'))
|
|
||||||
safe_nid = nid.replace('"', '_')
|
|
||||||
dot.append(f' "{safe_nid}" [label={label}, tooltip="{safe_tooltip}"];')
|
|
||||||
|
|
||||||
if n.get("parent") and n["parent"] in self.nodes:
|
if n["parent"] and n["parent"] in self.nodes:
|
||||||
safe_parent = n["parent"].replace('"', '_')
|
dot.append(f' "{n["parent"]}" -> "{nid}";')
|
||||||
dot.append(f' "{safe_parent}" -> "{safe_nid}";')
|
|
||||||
|
|
||||||
dot.append("}")
|
dot.append("}")
|
||||||
return "\n".join(dot)
|
return "\n".join(dot)
|
||||||
|
|||||||
+296
@@ -0,0 +1,296 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
# --- Shared Helper ---
|
||||||
|
def read_json_data(json_path):
|
||||||
|
if not os.path.exists(json_path):
|
||||||
|
print(f"[JSON Loader] Warning: File not found at {json_path}")
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with open(json_path, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[JSON Loader] Error: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 1. STANDARD NODES (Single File)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
class JSONLoaderLoRA:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
|
||||||
|
FUNCTION = "load_loras"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_loras(self, json_path):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
return (
|
||||||
|
str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
|
||||||
|
str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
|
||||||
|
str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
class JSONLoaderStandard:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
|
||||||
|
FUNCTION = "load_standard"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_standard(self, json_path):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
def to_float(val):
|
||||||
|
try: return float(val)
|
||||||
|
except: return 0.0
|
||||||
|
def to_int(val):
|
||||||
|
try: return int(float(val))
|
||||||
|
except: return 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
||||||
|
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
||||||
|
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
||||||
|
to_int(data.get("seed", 0)), str(data.get("video file path", "")),
|
||||||
|
str(data.get("reference image path", "")), str(data.get("flf image path", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
class JSONLoaderVACE:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
|
||||||
|
FUNCTION = "load_vace"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_vace(self, json_path):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
def to_float(val):
|
||||||
|
try: return float(val)
|
||||||
|
except: return 0.0
|
||||||
|
def to_int(val):
|
||||||
|
try: return int(float(val))
|
||||||
|
except: return 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
||||||
|
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
||||||
|
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
||||||
|
to_int(data.get("seed", 0)),
|
||||||
|
to_int(data.get("frame_to_skip", 81)), to_int(data.get("input_a_frames", 0)),
|
||||||
|
to_int(data.get("input_b_frames", 0)), str(data.get("reference path", "")),
|
||||||
|
to_int(data.get("reference switch", 1)), to_int(data.get("vace schedule", 1)),
|
||||||
|
str(data.get("video file path", "")), str(data.get("reference image path", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 2. BATCH NODES
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
class JSONLoaderBatchLoRA:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
|
||||||
|
FUNCTION = "load_batch_loras"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_batch_loras(self, json_path, sequence_number):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
return (
|
||||||
|
str(target_data.get("lora 1 high", "")), str(target_data.get("lora 1 low", "")),
|
||||||
|
str(target_data.get("lora 2 high", "")), str(target_data.get("lora 2 low", "")),
|
||||||
|
str(target_data.get("lora 3 high", "")), str(target_data.get("lora 3 low", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
class JSONLoaderBatchI2V:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
|
||||||
|
FUNCTION = "load_batch_i2v"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_batch_i2v(self, json_path, sequence_number):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
def to_float(val):
|
||||||
|
try: return float(val)
|
||||||
|
except: return 0.0
|
||||||
|
def to_int(val):
|
||||||
|
try: return int(float(val))
|
||||||
|
except: return 0
|
||||||
|
return (
|
||||||
|
str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")),
|
||||||
|
str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")),
|
||||||
|
str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)),
|
||||||
|
to_int(target_data.get("seed", 0)), str(target_data.get("video file path", "")),
|
||||||
|
str(target_data.get("reference image path", "")), str(target_data.get("flf image path", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
class JSONLoaderBatchVACE:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
|
||||||
|
FUNCTION = "load_batch_vace"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_batch_vace(self, json_path, sequence_number):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
def to_float(val):
|
||||||
|
try: return float(val)
|
||||||
|
except: return 0.0
|
||||||
|
def to_int(val):
|
||||||
|
try: return int(float(val))
|
||||||
|
except: return 0
|
||||||
|
return (
|
||||||
|
str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")),
|
||||||
|
str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")),
|
||||||
|
str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)),
|
||||||
|
to_int(target_data.get("seed", 0)), to_int(target_data.get("frame_to_skip", 81)),
|
||||||
|
to_int(target_data.get("input_a_frames", 0)), to_int(target_data.get("input_b_frames", 0)),
|
||||||
|
str(target_data.get("reference path", "")), to_int(target_data.get("reference switch", 1)),
|
||||||
|
to_int(target_data.get("vace schedule", 1)), str(target_data.get("video file path", "")),
|
||||||
|
str(target_data.get("reference image path", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 3. UNIVERSAL CUSTOM NODES (1, 3, 6 Slots)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
class JSONLoaderCustom1:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"json_path": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
},
|
||||||
|
"optional": { "key_1": ("STRING", {"default": "", "multiline": False}) }
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("val_1",)
|
||||||
|
FUNCTION = "load_custom"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_custom(self, json_path, sequence_number, key_1=""):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
return (str(target_data.get(key_1, "")),)
|
||||||
|
|
||||||
|
class JSONLoaderCustom3:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"json_path": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"key_1": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_2": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_3": ("STRING", {"default": "", "multiline": False})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("val_1", "val_2", "val_3")
|
||||||
|
FUNCTION = "load_custom"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3=""):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
return (
|
||||||
|
str(target_data.get(key_1, "")),
|
||||||
|
str(target_data.get(key_2, "")),
|
||||||
|
str(target_data.get(key_3, ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
class JSONLoaderCustom6:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"json_path": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"key_1": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_2": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_3": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_4": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_5": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_6": ("STRING", {"default": "", "multiline": False})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
||||||
|
RETURN_NAMES = ("val_1", "val_2", "val_3", "val_4", "val_5", "val_6")
|
||||||
|
FUNCTION = "load_custom"
|
||||||
|
CATEGORY = "utils/json"
|
||||||
|
|
||||||
|
def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3="", key_4="", key_5="", key_6=""):
|
||||||
|
data = read_json_data(json_path)
|
||||||
|
target_data = data
|
||||||
|
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||||
|
idx = (sequence_number - 1) % len(data["batch_data"])
|
||||||
|
target_data = data["batch_data"][idx]
|
||||||
|
return (
|
||||||
|
str(target_data.get(key_1, "")), str(target_data.get(key_2, "")),
|
||||||
|
str(target_data.get(key_3, "")), str(target_data.get(key_4, "")),
|
||||||
|
str(target_data.get(key_5, "")), str(target_data.get(key_6, ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Mappings ---
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"JSONLoaderLoRA": JSONLoaderLoRA,
|
||||||
|
"JSONLoaderStandard": JSONLoaderStandard,
|
||||||
|
"JSONLoaderVACE": JSONLoaderVACE,
|
||||||
|
"JSONLoaderBatchLoRA": JSONLoaderBatchLoRA,
|
||||||
|
"JSONLoaderBatchI2V": JSONLoaderBatchI2V,
|
||||||
|
"JSONLoaderBatchVACE": JSONLoaderBatchVACE,
|
||||||
|
"JSONLoaderCustom1": JSONLoaderCustom1,
|
||||||
|
"JSONLoaderCustom3": JSONLoaderCustom3,
|
||||||
|
"JSONLoaderCustom6": JSONLoaderCustom6
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"JSONLoaderLoRA": "JSON Loader (LoRAs Only)",
|
||||||
|
"JSONLoaderStandard": "JSON Loader (Standard/I2V)",
|
||||||
|
"JSONLoaderVACE": "JSON Loader (VACE Full)",
|
||||||
|
"JSONLoaderBatchLoRA": "JSON Batch Loader (LoRAs)",
|
||||||
|
"JSONLoaderBatchI2V": "JSON Batch Loader (I2V)",
|
||||||
|
"JSONLoaderBatchVACE": "JSON Batch Loader (VACE)",
|
||||||
|
"JSONLoaderCustom1": "JSON Loader (Custom 1)",
|
||||||
|
"JSONLoaderCustom3": "JSON Loader (Custom 3)",
|
||||||
|
"JSONLoaderCustom6": "JSON Loader (Custom 6)"
|
||||||
|
}
|
||||||
@@ -1,582 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from utils import (
|
|
||||||
load_config, save_config, load_snippets, save_snippets,
|
|
||||||
load_json, save_json, generate_templates, DEFAULTS,
|
|
||||||
KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER,
|
|
||||||
resolve_path_case_insensitive, sync_to_db,
|
|
||||||
)
|
|
||||||
from tab_batch_ng import render_batch_processor
|
|
||||||
from tab_timeline_ng import render_timeline_tab
|
|
||||||
from tab_raw_ng import render_raw_editor
|
|
||||||
from tab_comfy_ng import render_comfy_monitor
|
|
||||||
from tab_projects_ng import render_projects_tab
|
|
||||||
from db import ProjectDB
|
|
||||||
from api_routes import register_api_routes
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Single shared DB instance for both the UI and API routes
|
|
||||||
_shared_db: ProjectDB | None = None
|
|
||||||
try:
|
|
||||||
_shared_db = ProjectDB()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@ui.page('/')
|
|
||||||
def index():
|
|
||||||
ui.dark_mode(True)
|
|
||||||
ui.colors(primary='#F59E0B')
|
|
||||||
ui.add_head_html(
|
|
||||||
'<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap">'
|
|
||||||
)
|
|
||||||
ui.add_css('''
|
|
||||||
/* === Dark Theme with Depth Palette === */
|
|
||||||
:root {
|
|
||||||
--bg-page: #0B0E14;
|
|
||||||
--bg-surface-1: #13161E;
|
|
||||||
--bg-surface-2: #1A1E2A;
|
|
||||||
--bg-surface-3: #242836;
|
|
||||||
--border: rgba(255,255,255,0.08);
|
|
||||||
--text-primary: #EAECF0;
|
|
||||||
--text-secondary: rgba(234,236,240,0.55);
|
|
||||||
--accent: #F59E0B;
|
|
||||||
--accent-subtle: rgba(245,158,11,0.12);
|
|
||||||
--negative: #EF4444;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Backgrounds */
|
|
||||||
body.body--dark,
|
|
||||||
.q-page.body--dark,
|
|
||||||
.body--dark .q-page { background: var(--bg-page) !important; }
|
|
||||||
.body--dark .q-drawer { background: var(--bg-surface-1) !important; }
|
|
||||||
.body--dark .q-card {
|
|
||||||
background: var(--bg-surface-2) !important;
|
|
||||||
border: 1px solid var(--border);
|
|
||||||
border-radius: 0.75rem;
|
|
||||||
}
|
|
||||||
.body--dark .q-tab-panels { background: transparent !important; }
|
|
||||||
.body--dark .q-tab-panel { background: transparent !important; }
|
|
||||||
.body--dark .q-expansion-item { background: transparent !important; }
|
|
||||||
|
|
||||||
/* Text */
|
|
||||||
.body--dark { color: var(--text-primary) !important; }
|
|
||||||
.body--dark .q-field__label { color: var(--text-secondary) !important; }
|
|
||||||
.body--dark .text-caption { color: var(--text-secondary) !important; }
|
|
||||||
.body--dark .text-subtitle1,
|
|
||||||
.body--dark .text-subtitle2 { color: var(--text-primary) !important; }
|
|
||||||
|
|
||||||
/* Inputs & textareas */
|
|
||||||
.body--dark .q-field--outlined .q-field__control {
|
|
||||||
background: var(--bg-surface-3) !important;
|
|
||||||
border-radius: 0.5rem !important;
|
|
||||||
}
|
|
||||||
.body--dark .q-field--outlined .q-field__control:before {
|
|
||||||
border-color: var(--border) !important;
|
|
||||||
border-radius: 0.5rem !important;
|
|
||||||
}
|
|
||||||
.body--dark .q-field--outlined.q-field--focused .q-field__control:after {
|
|
||||||
border-color: var(--accent) !important;
|
|
||||||
}
|
|
||||||
.body--dark .q-field__native,
|
|
||||||
.body--dark .q-field__input { color: var(--text-primary) !important; }
|
|
||||||
|
|
||||||
/* Sidebar inputs get page bg */
|
|
||||||
.body--dark .q-drawer .q-field--outlined .q-field__control {
|
|
||||||
background: var(--bg-page) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Buttons */
|
|
||||||
.body--dark .q-btn--standard { border-radius: 0.5rem !important; }
|
|
||||||
.body--dark .q-btn--outline {
|
|
||||||
transition: background 0.15s ease;
|
|
||||||
}
|
|
||||||
.body--dark .q-btn--outline:hover {
|
|
||||||
background: var(--accent-subtle) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Tabs */
|
|
||||||
.body--dark .q-tab--active { color: var(--accent) !important; }
|
|
||||||
.body--dark .q-tab__indicator { background: var(--accent) !important; }
|
|
||||||
|
|
||||||
/* Separators */
|
|
||||||
.body--dark .q-separator { background: var(--border) !important; }
|
|
||||||
|
|
||||||
/* Expansion items */
|
|
||||||
.body--dark .q-expansion-item__content { padding: 12px 16px; }
|
|
||||||
.body--dark .q-item { border-radius: 0.5rem; }
|
|
||||||
|
|
||||||
/* Splitter */
|
|
||||||
.body--dark .q-splitter__separator { background: var(--border) !important; }
|
|
||||||
.body--dark .q-splitter__before,
|
|
||||||
.body--dark .q-splitter__after { padding: 0 8px; }
|
|
||||||
|
|
||||||
/* Action row wrap */
|
|
||||||
.action-row { flex-wrap: wrap !important; gap: 8px !important; }
|
|
||||||
|
|
||||||
/* Notifications */
|
|
||||||
.body--dark .q-notification { border-radius: 0.5rem; }
|
|
||||||
|
|
||||||
/* Font */
|
|
||||||
body { font-family: "Inter", "Source Sans Pro", "Source Sans 3", sans-serif !important; }
|
|
||||||
|
|
||||||
/* Surface utility classes (need .body--dark to beat .body--dark .q-card specificity) */
|
|
||||||
.body--dark .surface-1 { background: var(--bg-surface-1) !important; }
|
|
||||||
.body--dark .surface-2 { background: var(--bg-surface-2) !important; }
|
|
||||||
.body--dark .surface-3 { background: var(--bg-surface-3) !important; }
|
|
||||||
|
|
||||||
/* Typography utility classes */
|
|
||||||
.section-header {
|
|
||||||
font-size: 0.8rem;
|
|
||||||
font-weight: 600;
|
|
||||||
text-transform: uppercase;
|
|
||||||
letter-spacing: 0.05em;
|
|
||||||
color: var(--text-secondary) !important;
|
|
||||||
}
|
|
||||||
.subsection-header {
|
|
||||||
font-size: 0.85rem;
|
|
||||||
font-weight: 500;
|
|
||||||
color: var(--text-primary) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Scrollbar */
|
|
||||||
::-webkit-scrollbar { width: 6px; height: 6px; }
|
|
||||||
::-webkit-scrollbar-track { background: transparent; }
|
|
||||||
::-webkit-scrollbar-thumb {
|
|
||||||
background: rgba(255,255,255,0.12);
|
|
||||||
border-radius: 3px;
|
|
||||||
}
|
|
||||||
::-webkit-scrollbar-thumb:hover {
|
|
||||||
background: rgba(255,255,255,0.2);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Sub-sequence accent colors (per sub-index, cycling) */
|
|
||||||
.body--dark .subsegment-color-0 > .q-expansion-item__container > .q-item { border-left: 6px solid #06B6D4; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-0 .q-expansion-item__toggle-icon { color: #06B6D4 !important; }
|
|
||||||
.body--dark .subsegment-color-1 > .q-expansion-item__container > .q-item { border-left: 6px solid #A78BFA; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-1 .q-expansion-item__toggle-icon { color: #A78BFA !important; }
|
|
||||||
.body--dark .subsegment-color-2 > .q-expansion-item__container > .q-item { border-left: 6px solid #34D399; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-2 .q-expansion-item__toggle-icon { color: #34D399 !important; }
|
|
||||||
.body--dark .subsegment-color-3 > .q-expansion-item__container > .q-item { border-left: 6px solid #F472B6; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-3 .q-expansion-item__toggle-icon { color: #F472B6 !important; }
|
|
||||||
.body--dark .subsegment-color-4 > .q-expansion-item__container > .q-item { border-left: 6px solid #FBBF24; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-4 .q-expansion-item__toggle-icon { color: #FBBF24 !important; }
|
|
||||||
.body--dark .subsegment-color-5 > .q-expansion-item__container > .q-item { border-left: 6px solid #FB923C; padding-left: 10px; }
|
|
||||||
.body--dark .subsegment-color-5 .q-expansion-item__toggle-icon { color: #FB923C !important; }
|
|
||||||
|
|
||||||
/* Secondary pane teal accent */
|
|
||||||
.pane-secondary .q-field--outlined.q-field--focused .q-field__control:after {
|
|
||||||
border-color: #06B6D4 !important;
|
|
||||||
}
|
|
||||||
.pane-secondary .q-btn.bg-primary { background-color: #06B6D4 !important; }
|
|
||||||
.pane-secondary .section-header { color: rgba(6,182,212,0.7) !important; }
|
|
||||||
''')
|
|
||||||
|
|
||||||
config = load_config()
|
|
||||||
state = AppState(
|
|
||||||
config=config,
|
|
||||||
current_dir=Path(config.get('last_dir', str(Path.cwd()))),
|
|
||||||
snippets=load_snippets(),
|
|
||||||
db_enabled=config.get('db_enabled', False),
|
|
||||||
current_project=config.get('current_project', ''),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the shared DB instance
|
|
||||||
state.db = _shared_db
|
|
||||||
|
|
||||||
dual_pane = {'active': False, 'state': None}
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Define helpers FIRST (before sidebar, which needs them)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_main_content():
|
|
||||||
import time as _time
|
|
||||||
_t0 = _time.perf_counter()
|
|
||||||
logger.info("render_main_content START")
|
|
||||||
max_w = '2400px' if dual_pane['active'] else '1200px'
|
|
||||||
with ui.column().classes('w-full q-pa-md').style(f'max-width: {max_w}; margin: 0 auto'):
|
|
||||||
if not state.file_path or not state.file_path.exists():
|
|
||||||
ui.label('Select a file from the sidebar to begin.').classes(
|
|
||||||
'text-subtitle1 q-pa-lg')
|
|
||||||
return
|
|
||||||
|
|
||||||
ui.label(f'Editing: {state.file_path.name}').classes('text-h5 q-mb-lg').style('font-weight: 600')
|
|
||||||
|
|
||||||
with ui.tabs().classes('w-full').style('border-bottom: 1px solid var(--border)') as tabs:
|
|
||||||
ui.tab('batch', label='Batch Processor')
|
|
||||||
ui.tab('timeline', label='Timeline')
|
|
||||||
ui.tab('raw', label='Raw Editor')
|
|
||||||
ui.tab('projects', label='Projects')
|
|
||||||
|
|
||||||
with ui.tab_panels(tabs, value='batch').classes('w-full'):
|
|
||||||
with ui.tab_panel('batch'):
|
|
||||||
_render_batch_tab_content()
|
|
||||||
with ui.tab_panel('timeline'):
|
|
||||||
render_timeline_tab(state)
|
|
||||||
with ui.tab_panel('raw'):
|
|
||||||
render_raw_editor(state)
|
|
||||||
with ui.tab_panel('projects'):
|
|
||||||
render_projects_tab(state)
|
|
||||||
|
|
||||||
if state.show_comfy_monitor:
|
|
||||||
ui.separator()
|
|
||||||
with ui.expansion('ComfyUI Monitor', icon='dns').classes('w-full'):
|
|
||||||
render_comfy_monitor(state)
|
|
||||||
|
|
||||||
logger.info("render_main_content END (%.3fs)", _time.perf_counter() - _t0)
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def _render_batch_tab_content():
|
|
||||||
def on_toggle(e):
|
|
||||||
dual_pane['active'] = e.value
|
|
||||||
if e.value and dual_pane['state'] is None:
|
|
||||||
s2 = state.create_secondary()
|
|
||||||
s2._render_main = _render_batch_tab_content
|
|
||||||
dual_pane['state'] = s2
|
|
||||||
render_main_content.refresh()
|
|
||||||
|
|
||||||
ui.switch('Dual Pane', value=dual_pane['active'], on_change=on_toggle)
|
|
||||||
|
|
||||||
if not dual_pane['active']:
|
|
||||||
render_batch_processor(state)
|
|
||||||
else:
|
|
||||||
s2 = dual_pane['state']
|
|
||||||
with ui.row().classes('w-full gap-4'):
|
|
||||||
with ui.column().classes('col'):
|
|
||||||
ui.label('Pane A').classes('section-header q-mb-sm')
|
|
||||||
_render_pane_file_selector(state)
|
|
||||||
render_batch_processor(state)
|
|
||||||
with ui.column().classes('col pane-secondary'):
|
|
||||||
ui.label('Pane B').classes('section-header q-mb-sm')
|
|
||||||
_render_pane_file_selector(s2)
|
|
||||||
if s2.file_path and s2.file_path.exists():
|
|
||||||
render_batch_processor(s2)
|
|
||||||
else:
|
|
||||||
ui.label('Select a file above to begin.').classes(
|
|
||||||
'text-caption q-pa-md')
|
|
||||||
|
|
||||||
def _render_pane_file_selector(pane_state: AppState):
|
|
||||||
if not pane_state.current_dir.exists():
|
|
||||||
ui.label('Directory not found.').classes('text-warning')
|
|
||||||
return
|
|
||||||
json_files = sorted(pane_state.current_dir.glob('*.json'))
|
|
||||||
json_files = [f for f in json_files if f.name not in (
|
|
||||||
'.editor_config.json', '.editor_snippets.json')]
|
|
||||||
file_names = [f.name for f in json_files]
|
|
||||||
|
|
||||||
current_val = pane_state.file_path.name if pane_state.file_path else None
|
|
||||||
|
|
||||||
async def on_select(e):
|
|
||||||
if not e.value:
|
|
||||||
return
|
|
||||||
import time as _time
|
|
||||||
_t0 = _time.perf_counter()
|
|
||||||
logger.info("on_select START: %s", e.value)
|
|
||||||
fp = pane_state.current_dir / e.value
|
|
||||||
file_stem = fp.stem
|
|
||||||
data = None
|
|
||||||
if pane_state.db and pane_state.db_enabled and pane_state.current_project:
|
|
||||||
data = await asyncio.to_thread(
|
|
||||||
pane_state.db.load_full_data, pane_state.current_project, file_stem)
|
|
||||||
if data is None:
|
|
||||||
data, _ = await asyncio.to_thread(load_json, fp)
|
|
||||||
if pane_state.db and pane_state.db_enabled and pane_state.current_project:
|
|
||||||
await asyncio.to_thread(
|
|
||||||
sync_to_db, pane_state.db, pane_state.current_project, fp, data)
|
|
||||||
tree = data.get('history_tree')
|
|
||||||
if tree and isinstance(tree, dict):
|
|
||||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
|
||||||
entry.pop('data', None)
|
|
||||||
for backup in data.get('history_tree_backup', []):
|
|
||||||
if isinstance(backup, dict):
|
|
||||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
|
||||||
entry.pop('data', None)
|
|
||||||
pane_state.data_cache = data
|
|
||||||
pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
|
||||||
pane_state.loaded_file = str(fp)
|
|
||||||
pane_state.file_path = fp
|
|
||||||
pane_state.restored_indicator = None
|
|
||||||
pane_state._src_cache = {'data': None, 'batch': [], 'name': None}
|
|
||||||
_render_batch_tab_content.refresh()
|
|
||||||
logger.info("on_select END (%.3fs)", _time.perf_counter() - _t0)
|
|
||||||
|
|
||||||
ui.select(
|
|
||||||
file_names,
|
|
||||||
value=current_val,
|
|
||||||
label='File',
|
|
||||||
on_change=on_select,
|
|
||||||
).classes('w-full')
|
|
||||||
|
|
||||||
async def load_file(file_name: str):
|
|
||||||
"""Load data from DB (fast) with JSON fallback, and refresh the main content."""
|
|
||||||
import time as _time
|
|
||||||
_t0 = _time.perf_counter()
|
|
||||||
logger.info("load_file START: %s", file_name)
|
|
||||||
fp = state.current_dir / file_name
|
|
||||||
if state.loaded_file == str(fp):
|
|
||||||
return
|
|
||||||
file_stem = fp.stem
|
|
||||||
data = None
|
|
||||||
if state.db and state.db_enabled and state.current_project:
|
|
||||||
data = await asyncio.to_thread(
|
|
||||||
state.db.load_full_data, state.current_project, file_stem)
|
|
||||||
if data is None:
|
|
||||||
data, _ = await asyncio.to_thread(load_json, fp)
|
|
||||||
# When loading from JSON fallback and DB is enabled, sync to DB
|
|
||||||
# so snapshots are persisted, then strip from memory
|
|
||||||
if state.db and state.db_enabled and state.current_project:
|
|
||||||
await asyncio.to_thread(
|
|
||||||
sync_to_db, state.db, state.current_project, fp, data)
|
|
||||||
tree = data.get('history_tree')
|
|
||||||
if tree and isinstance(tree, dict):
|
|
||||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
|
||||||
entry.pop('data', None)
|
|
||||||
# Strip snapshot data from history_tree_backup to prevent RAM/disk bloat
|
|
||||||
for backup in data.get('history_tree_backup', []):
|
|
||||||
if isinstance(backup, dict):
|
|
||||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
|
||||||
entry.pop('data', None)
|
|
||||||
state.data_cache = data
|
|
||||||
state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
|
||||||
state.loaded_file = str(fp)
|
|
||||||
state.file_path = fp
|
|
||||||
state.restored_indicator = None
|
|
||||||
state._src_cache = {'data': None, 'batch': [], 'name': None}
|
|
||||||
if state._main_rendered:
|
|
||||||
render_main_content.refresh()
|
|
||||||
logger.info("load_file END (%.3fs)", _time.perf_counter() - _t0)
|
|
||||||
|
|
||||||
# Attach helpers to state so sidebar can call them
|
|
||||||
state._load_file = load_file
|
|
||||||
state._render_main = render_main_content
|
|
||||||
state._main_rendered = False
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Sidebar (rendered AFTER helpers are attached)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
with ui.left_drawer(value=True).classes('q-pa-md').style('width: 320px'):
|
|
||||||
render_sidebar(state, dual_pane)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Main content area
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
render_main_content()
|
|
||||||
state._main_rendered = True
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Sidebar
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def render_sidebar(state: AppState, dual_pane: dict):
|
|
||||||
ui.label('Navigator').classes('text-h6')
|
|
||||||
|
|
||||||
# --- Path input + Pin ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
path_input = ui.input(
|
|
||||||
'Current Path',
|
|
||||||
value=str(state.current_dir),
|
|
||||||
).classes('w-full')
|
|
||||||
|
|
||||||
def on_path_enter():
|
|
||||||
p = resolve_path_case_insensitive(path_input.value)
|
|
||||||
if p is not None and p.is_dir():
|
|
||||||
state.current_dir = p
|
|
||||||
if dual_pane['state']:
|
|
||||||
dual_pane['state'].current_dir = state.current_dir
|
|
||||||
dual_pane['state'].file_path = None
|
|
||||||
dual_pane['state'].loaded_file = None
|
|
||||||
dual_pane['state'].data_cache = {}
|
|
||||||
state.config['last_dir'] = str(p)
|
|
||||||
save_config(state.current_dir, state.config['favorites'], state.config)
|
|
||||||
state.loaded_file = None
|
|
||||||
state.file_path = None
|
|
||||||
path_input.set_value(str(p))
|
|
||||||
render_file_list.refresh()
|
|
||||||
# Auto-load inside render_file_list already refreshed main content
|
|
||||||
# if files exist; only refresh here for the empty-directory case.
|
|
||||||
if not state.loaded_file:
|
|
||||||
state._render_main.refresh()
|
|
||||||
|
|
||||||
path_input.on('keydown.enter', lambda _: on_path_enter())
|
|
||||||
|
|
||||||
def pin_folder():
|
|
||||||
d = str(state.current_dir)
|
|
||||||
if d not in state.config['favorites']:
|
|
||||||
state.config['favorites'].append(d)
|
|
||||||
save_config(state.current_dir, state.config['favorites'], state.config)
|
|
||||||
render_favorites.refresh()
|
|
||||||
|
|
||||||
ui.button('Pin Folder', icon='push_pin', on_click=pin_folder).classes('w-full')
|
|
||||||
|
|
||||||
# --- Favorites ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
ui.label('Favorites').classes('section-header')
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_favorites():
|
|
||||||
for fav in list(state.config['favorites']):
|
|
||||||
with ui.row().classes('w-full items-center'):
|
|
||||||
ui.button(
|
|
||||||
fav,
|
|
||||||
on_click=lambda f=fav: _jump_to(f),
|
|
||||||
).props('flat dense').classes('col')
|
|
||||||
ui.button(
|
|
||||||
icon='close',
|
|
||||||
on_click=lambda f=fav: _unpin(f),
|
|
||||||
).props('flat dense color=negative')
|
|
||||||
|
|
||||||
def _jump_to(fav: str):
|
|
||||||
state.current_dir = Path(fav)
|
|
||||||
if dual_pane['state']:
|
|
||||||
dual_pane['state'].current_dir = state.current_dir
|
|
||||||
dual_pane['state'].file_path = None
|
|
||||||
dual_pane['state'].loaded_file = None
|
|
||||||
dual_pane['state'].data_cache = {}
|
|
||||||
state.config['last_dir'] = fav
|
|
||||||
save_config(state.current_dir, state.config['favorites'], state.config)
|
|
||||||
state.loaded_file = None
|
|
||||||
state.file_path = None
|
|
||||||
path_input.set_value(fav)
|
|
||||||
render_file_list.refresh()
|
|
||||||
if not state.loaded_file:
|
|
||||||
state._render_main.refresh()
|
|
||||||
|
|
||||||
def _unpin(fav: str):
|
|
||||||
if fav in state.config['favorites']:
|
|
||||||
state.config['favorites'].remove(fav)
|
|
||||||
save_config(state.current_dir, state.config['favorites'], state.config)
|
|
||||||
render_favorites.refresh()
|
|
||||||
|
|
||||||
render_favorites()
|
|
||||||
|
|
||||||
# --- Snippet Library ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
ui.label('Snippet Library').classes('section-header')
|
|
||||||
|
|
||||||
with ui.expansion('Add New Snippet'):
|
|
||||||
snip_name_input = ui.input('Name', placeholder='e.g. Cinematic').classes('w-full')
|
|
||||||
snip_content_input = ui.textarea('Content', placeholder='4k, high quality...').classes('w-full')
|
|
||||||
|
|
||||||
def save_snippet():
|
|
||||||
name = snip_name_input.value
|
|
||||||
content = snip_content_input.value
|
|
||||||
if name and content:
|
|
||||||
state.snippets[name] = content
|
|
||||||
save_snippets(state.snippets)
|
|
||||||
snip_name_input.set_value('')
|
|
||||||
snip_content_input.set_value('')
|
|
||||||
ui.notify(f"Saved '{name}'")
|
|
||||||
render_snippet_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Save Snippet', on_click=save_snippet).classes('w-full')
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_snippet_list():
|
|
||||||
if not state.snippets:
|
|
||||||
return
|
|
||||||
ui.label('Click to copy snippet text:').classes('text-caption')
|
|
||||||
for name, content in list(state.snippets.items()):
|
|
||||||
with ui.row().classes('w-full items-center'):
|
|
||||||
async def copy_snippet(c=content):
|
|
||||||
await ui.run_javascript(
|
|
||||||
f'navigator.clipboard.writeText({json.dumps(c)})', timeout=3.0)
|
|
||||||
ui.notify('Copied to clipboard')
|
|
||||||
|
|
||||||
ui.button(
|
|
||||||
f'{name}',
|
|
||||||
on_click=copy_snippet,
|
|
||||||
).props('flat dense').classes('col')
|
|
||||||
ui.button(
|
|
||||||
icon='delete',
|
|
||||||
on_click=lambda n=name: _del_snippet(n),
|
|
||||||
).props('flat dense color=negative')
|
|
||||||
|
|
||||||
def _del_snippet(name: str):
|
|
||||||
if name in state.snippets:
|
|
||||||
del state.snippets[name]
|
|
||||||
save_snippets(state.snippets)
|
|
||||||
render_snippet_list.refresh()
|
|
||||||
|
|
||||||
render_snippet_list()
|
|
||||||
|
|
||||||
# --- File List ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
@ui.refreshable
|
|
||||||
def render_file_list():
|
|
||||||
if not state.current_dir.exists():
|
|
||||||
ui.label('Directory not found.').classes('text-warning')
|
|
||||||
return
|
|
||||||
json_files = sorted(state.current_dir.glob('*.json'))
|
|
||||||
json_files = [f for f in json_files if f.name not in ('.editor_config.json', '.editor_snippets.json')]
|
|
||||||
|
|
||||||
if not json_files:
|
|
||||||
ui.label('No JSON files in this folder.').classes('text-caption')
|
|
||||||
ui.button('Generate Templates', on_click=lambda: _gen_templates()).classes('w-full')
|
|
||||||
return
|
|
||||||
|
|
||||||
with ui.expansion('Create New JSON'):
|
|
||||||
new_fn_input = ui.input('Filename', placeholder='my_prompt_vace').classes('w-full')
|
|
||||||
|
|
||||||
async def create_new():
|
|
||||||
fn = new_fn_input.value
|
|
||||||
if not fn:
|
|
||||||
return
|
|
||||||
if not fn.endswith('.json'):
|
|
||||||
fn += '.json'
|
|
||||||
path = state.current_dir / fn
|
|
||||||
first_item = copy.deepcopy(DEFAULTS)
|
|
||||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
|
||||||
await asyncio.to_thread(save_json, path, {KEY_BATCH_DATA: [first_item]})
|
|
||||||
new_fn_input.set_value('')
|
|
||||||
render_file_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Create', on_click=create_new).classes('w-full')
|
|
||||||
|
|
||||||
ui.label('Select File').classes('subsection-header q-mt-sm')
|
|
||||||
file_names = [f.name for f in json_files]
|
|
||||||
current = Path(state.loaded_file).name if state.loaded_file else None
|
|
||||||
selected = current if current in file_names else (file_names[0] if file_names else None)
|
|
||||||
async def _on_radio(e):
|
|
||||||
if e.value:
|
|
||||||
await state._load_file(e.value)
|
|
||||||
|
|
||||||
ui.radio(
|
|
||||||
file_names,
|
|
||||||
value=selected,
|
|
||||||
on_change=_on_radio,
|
|
||||||
).classes('w-full')
|
|
||||||
|
|
||||||
# Auto-load first file if nothing loaded yet
|
|
||||||
if file_names and not state.loaded_file:
|
|
||||||
asyncio.ensure_future(state._load_file(file_names[0]))
|
|
||||||
|
|
||||||
def _gen_templates():
|
|
||||||
generate_templates(state.current_dir)
|
|
||||||
render_file_list.refresh()
|
|
||||||
|
|
||||||
render_file_list()
|
|
||||||
|
|
||||||
# --- Comfy Monitor toggle ---
|
|
||||||
def on_monitor_toggle(e):
|
|
||||||
state.show_comfy_monitor = e.value
|
|
||||||
state._render_main.refresh()
|
|
||||||
|
|
||||||
ui.checkbox('Show Comfy Monitor', value=state.show_comfy_monitor, on_change=on_monitor_toggle)
|
|
||||||
|
|
||||||
|
|
||||||
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
|
||||||
if _shared_db is not None:
|
|
||||||
register_api_routes(_shared_db)
|
|
||||||
|
|
||||||
ui.run(title='AI Settings Manager', port=8080, reload=False)
|
|
||||||
@@ -1,420 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import urllib.parse
|
|
||||||
import urllib.request
|
|
||||||
import urllib.error
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
MAX_DYNAMIC_OUTPUTS = 32
|
|
||||||
|
|
||||||
|
|
||||||
class AnyType(str):
|
|
||||||
"""Universal connector type that matches any ComfyUI type."""
|
|
||||||
def __ne__(self, __value: object) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
any_type = AnyType("*")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from server import PromptServer
|
|
||||||
from aiohttp import web
|
|
||||||
except ImportError:
|
|
||||||
PromptServer = None
|
|
||||||
|
|
||||||
|
|
||||||
def to_float(val: Any) -> float:
|
|
||||||
try:
|
|
||||||
return float(val)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def to_int(val: Any) -> int:
|
|
||||||
try:
|
|
||||||
return int(float(val))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_json(url: str) -> dict:
|
|
||||||
"""Fetch JSON from a URL using stdlib urllib.
|
|
||||||
|
|
||||||
On error, returns a dict with an "error" key describing the failure.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(url, timeout=5) as resp:
|
|
||||||
return json.loads(resp.read())
|
|
||||||
except urllib.error.HTTPError as e:
|
|
||||||
# HTTPError is a subclass of URLError — must be caught first
|
|
||||||
body = ""
|
|
||||||
try:
|
|
||||||
raw = e.read()
|
|
||||||
detail = json.loads(raw)
|
|
||||||
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
|
|
||||||
except Exception:
|
|
||||||
body = str(e)
|
|
||||||
logger.warning(f"HTTP {e.code} from {url}: {body}")
|
|
||||||
return {"error": "http_error", "status": e.code, "message": body}
|
|
||||||
except (urllib.error.URLError, OSError) as e:
|
|
||||||
reason = str(e.reason) if hasattr(e, "reason") else str(e)
|
|
||||||
logger.warning(f"Network error fetching {url}: {reason}")
|
|
||||||
return {"error": "network_error", "message": reason}
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning(f"Invalid JSON from {url}: {e}")
|
|
||||||
return {"error": "parse_error", "message": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_project(manager_url: str, project: str) -> dict:
|
|
||||||
"""Fetch project details (including folder_path) from the NiceGUI REST API."""
|
|
||||||
p = urllib.parse.quote(project, safe='')
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{p}"
|
|
||||||
return _fetch_json(url)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
|
||||||
"""Fetch sequence data from the NiceGUI REST API."""
|
|
||||||
p = urllib.parse.quote(project, safe='')
|
|
||||||
f = urllib.parse.quote(file, safe='')
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}"
|
|
||||||
return _fetch_json(url)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
|
|
||||||
"""Fetch keys/types from the NiceGUI REST API."""
|
|
||||||
p = urllib.parse.quote(project, safe='')
|
|
||||||
f = urllib.parse.quote(file, safe='')
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}"
|
|
||||||
return _fetch_json(url)
|
|
||||||
|
|
||||||
|
|
||||||
# --- ComfyUI-side proxy endpoints (for frontend JS) ---
|
|
||||||
if PromptServer is not None:
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/list_projects")
|
|
||||||
async def list_projects_proxy(request):
|
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects"
|
|
||||||
data = await asyncio.to_thread(_fetch_json, url)
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
|
||||||
async def list_project_files_proxy(request):
|
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
|
||||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
|
||||||
data = await asyncio.to_thread(_fetch_json, url)
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
|
||||||
async def list_project_sequences_proxy(request):
|
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
|
||||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
|
||||||
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
|
||||||
data = await asyncio.to_thread(_fetch_json, url)
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
|
|
||||||
async def get_project_keys_proxy(request):
|
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
|
||||||
project = request.query.get("project", "")
|
|
||||||
file_name = request.query.get("file", "")
|
|
||||||
try:
|
|
||||||
seq = int(request.query.get("seq", "1"))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
seq = 1
|
|
||||||
data = await asyncio.to_thread(_fetch_keys, manager_url, project, file_name, seq)
|
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
status = data.get("status", 502)
|
|
||||||
return web.json_response(data, status=status)
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 0. DYNAMIC NODE (Project-based)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class ProjectLoaderDynamic:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
"refresh": (["off", "on"],),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"output_keys": ("STRING", {"default": ""}),
|
|
||||||
"output_types": ("STRING", {"default": ""}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
|
||||||
RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
|
||||||
FUNCTION = "load_dynamic"
|
|
||||||
CATEGORY = "JSON Manager/project"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
|
||||||
refresh="off", output_keys="", output_types=""):
|
|
||||||
# Fetch keys metadata (includes total_sequences count)
|
|
||||||
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
|
||||||
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
msg = keys_meta.get("message", "Unknown error")
|
|
||||||
raise RuntimeError(f"Failed to fetch project keys: {msg}")
|
|
||||||
total_sequences = keys_meta.get("total_sequences", 0)
|
|
||||||
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
msg = data.get("message", "Unknown error")
|
|
||||||
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
|
|
||||||
|
|
||||||
# Parse keys — try JSON array first, fall back to comma-split for compat
|
|
||||||
keys = []
|
|
||||||
if output_keys:
|
|
||||||
try:
|
|
||||||
keys = json.loads(output_keys)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
keys = [k.strip() for k in output_keys.split(",") if k.strip()]
|
|
||||||
|
|
||||||
# Parse types for coercion
|
|
||||||
types = []
|
|
||||||
if output_types:
|
|
||||||
try:
|
|
||||||
types = json.loads(output_types)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
types = [t.strip() for t in output_types.split(",")]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i, key in enumerate(keys):
|
|
||||||
val = data.get(key, "")
|
|
||||||
declared_type = types[i] if i < len(types) else ""
|
|
||||||
# Coerce based on declared output type when possible
|
|
||||||
if declared_type == "INT":
|
|
||||||
results.append(to_int(val))
|
|
||||||
elif declared_type == "FLOAT":
|
|
||||||
results.append(to_float(val))
|
|
||||||
elif isinstance(val, bool):
|
|
||||||
results.append(str(val).lower())
|
|
||||||
elif isinstance(val, int):
|
|
||||||
results.append(val)
|
|
||||||
elif isinstance(val, float):
|
|
||||||
results.append(val)
|
|
||||||
else:
|
|
||||||
results.append(str(val))
|
|
||||||
|
|
||||||
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
|
||||||
results.append("")
|
|
||||||
|
|
||||||
return (total_sequences,) + tuple(results)
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectSource:
|
|
||||||
"""Config node — holds project connection settings, outputs sequence_number."""
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
"label": ("STRING", {"default": "source", "multiline": False}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("INT", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("sequence_number", "file_name", "project_path")
|
|
||||||
FUNCTION = "hold_config"
|
|
||||||
CATEGORY = "JSON Manager/project"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
def hold_config(self, manager_url, project_name, file_name, sequence_number, label):
|
|
||||||
name = project_name.strip()
|
|
||||||
if not name:
|
|
||||||
active = _fetch_json(f"{manager_url.rstrip('/')}/api/active-project")
|
|
||||||
name = active.get("project", "") if "error" not in active else ""
|
|
||||||
folder_path = ""
|
|
||||||
if name:
|
|
||||||
proj = _fetch_project(manager_url, name)
|
|
||||||
folder_path = proj.get("folder_path", "") if "error" not in proj else ""
|
|
||||||
if folder_path and not folder_path.endswith("/"):
|
|
||||||
folder_path += "/"
|
|
||||||
return (sequence_number, file_name, folder_path)
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectKey:
|
|
||||||
"""Single-output relay — fetches one key from a ProjectSource."""
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"source_label": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_type": ("STRING", {"default": "STRING", "multiline": False}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (any_type,)
|
|
||||||
RETURN_NAMES = ("value",)
|
|
||||||
FUNCTION = "fetch_key"
|
|
||||||
CATEGORY = "JSON Manager/project"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def IS_CHANGED(cls, **kwargs):
|
|
||||||
return float("nan") # Always re-fetch from API
|
|
||||||
|
|
||||||
def fetch_key(self, source_label, key_name, key_type,
|
|
||||||
manager_url="http://localhost:8080", project_name="",
|
|
||||||
file_name="", sequence_number=1):
|
|
||||||
# source_label is used by JS to identify which ProjectSource to sync
|
|
||||||
# config from. The actual config arrives via the optional widgets below.
|
|
||||||
sequence_number = int(sequence_number)
|
|
||||||
logger.info("ProjectKey.fetch_key: source=%s key=%s url=%s project=%s file=%s seq=%s",
|
|
||||||
source_label, key_name, manager_url, project_name, file_name, sequence_number)
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
msg = data.get("message", "Unknown error")
|
|
||||||
logger.warning("ProjectKey.fetch_key failed: %s", msg)
|
|
||||||
# Return empty/default instead of crashing the workflow
|
|
||||||
if key_type == "INT":
|
|
||||||
return (0,)
|
|
||||||
elif key_type == "FLOAT":
|
|
||||||
return (0.0,)
|
|
||||||
else:
|
|
||||||
return ("",)
|
|
||||||
|
|
||||||
val = data.get(key_name, "")
|
|
||||||
|
|
||||||
if key_type == "INT":
|
|
||||||
result = to_int(val)
|
|
||||||
return {"ui": {"value": [str(result)]}, "result": (result,)}
|
|
||||||
elif key_type == "FLOAT":
|
|
||||||
result = to_float(val)
|
|
||||||
return {"ui": {"value": [f"{result:.4g}"]}, "result": (result,)}
|
|
||||||
elif isinstance(val, bool):
|
|
||||||
return {"ui": {"value": [str(val).lower()]}, "result": (str(val).lower(),)}
|
|
||||||
elif isinstance(val, (int, float)):
|
|
||||||
return {"ui": {"value": [str(val)]}, "result": (val,)}
|
|
||||||
else:
|
|
||||||
return (str(val),)
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectResolution:
|
|
||||||
"""Fetches a (width, height) pair from a resolution series by loop index."""
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"source_label": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_name": ("STRING", {"default": "resolutions", "multiline": False}),
|
|
||||||
"index": ("INT", {"default": 0, "min": 0, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("INT", "INT", "INT")
|
|
||||||
RETURN_NAMES = ("width", "height", "seed")
|
|
||||||
FUNCTION = "fetch_resolution"
|
|
||||||
CATEGORY = "JSON Manager/project"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def IS_CHANGED(cls, **kwargs):
|
|
||||||
return float("nan")
|
|
||||||
|
|
||||||
def fetch_resolution(self, source_label, key_name, index,
|
|
||||||
manager_url="http://localhost:8080", project_name="",
|
|
||||||
file_name="", sequence_number=1):
|
|
||||||
sequence_number = int(sequence_number)
|
|
||||||
logger.info("ProjectResolution.fetch_resolution: source=%s key=%s url=%s project=%s file=%s seq=%s index=%s",
|
|
||||||
source_label, key_name, manager_url, project_name, file_name, sequence_number, index)
|
|
||||||
# source_label is used by JS to identify which ProjectSource to sync
|
|
||||||
# config from. The actual config arrives via the optional widgets below.
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
|
||||||
logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message"))
|
|
||||||
return (512, 512, 0)
|
|
||||||
|
|
||||||
series = data.get(key_name)
|
|
||||||
if not isinstance(series, list) or len(series) == 0:
|
|
||||||
logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name)
|
|
||||||
return (512, 512, 0)
|
|
||||||
|
|
||||||
clamped = max(0, min(index, len(series) - 1))
|
|
||||||
entry = series[clamped]
|
|
||||||
if not isinstance(entry, (list, tuple)) or len(entry) < 2:
|
|
||||||
logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry)
|
|
||||||
return (512, 512, 0)
|
|
||||||
|
|
||||||
seed = to_int(entry[2]) if len(entry) >= 3 else 0
|
|
||||||
return (to_int(entry[0]), to_int(entry[1]), seed)
|
|
||||||
|
|
||||||
|
|
||||||
class BinaryIndexDecoder:
|
|
||||||
"""Decodes an integer index into 3 boolean flags using binary (bit-field) encoding.
|
|
||||||
|
|
||||||
index 0 → (False, False, False)
|
|
||||||
index 1 → (True, False, False) # bit 0
|
|
||||||
index 2 → (False, True, False) # bit 1
|
|
||||||
index 3 → (True, True, False) # bits 0+1
|
|
||||||
index 4 → (False, False, True) # bit 2
|
|
||||||
...
|
|
||||||
index 7 → (True, True, True)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"index": ("INT", {"default": 0, "min": 0, "max": 7}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("BOOLEAN", "BOOLEAN", "BOOLEAN")
|
|
||||||
RETURN_NAMES = ("flag_0", "flag_1", "flag_2")
|
|
||||||
FUNCTION = "decode"
|
|
||||||
CATEGORY = "JSON Manager/utils"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
def decode(self, index: int):
|
|
||||||
f0 = bool((index >> 0) & 1)
|
|
||||||
f1 = bool((index >> 1) & 1)
|
|
||||||
f2 = bool((index >> 2) & 1)
|
|
||||||
return {"ui": {"values": [str(f0).lower(), str(f1).lower(), str(f2).lower()]},
|
|
||||||
"result": (f0, f1, f2)}
|
|
||||||
|
|
||||||
|
|
||||||
# --- Mappings ---
|
|
||||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
|
||||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
|
||||||
"ProjectSource": ProjectSource,
|
|
||||||
"ProjectKey": ProjectKey,
|
|
||||||
"ProjectResolution": ProjectResolution,
|
|
||||||
"BinaryIndexDecoder": BinaryIndexDecoder,
|
|
||||||
}
|
|
||||||
|
|
||||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
|
||||||
"ProjectSource": "Project Source",
|
|
||||||
"ProjectKey": "Project Key",
|
|
||||||
"ProjectResolution": "Project Resolution",
|
|
||||||
"BinaryIndexDecoder": "Binary Index Decoder",
|
|
||||||
}
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
KEY_PROMPT_HISTORY = "prompt_history"
|
|
||||||
|
|
||||||
|
|
||||||
class SnapshotTimeline:
|
|
||||||
"""Flat chronological snapshot list — replaces the old HistoryTree DAG."""
|
|
||||||
|
|
||||||
def __init__(self, raw_data: dict[str, Any]) -> None:
|
|
||||||
# Detect and migrate old HistoryTree format
|
|
||||||
if "nodes" in raw_data and "branches" in raw_data:
|
|
||||||
self._migrate_from_tree(raw_data)
|
|
||||||
elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list):
|
|
||||||
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
|
||||||
else:
|
|
||||||
self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {})
|
|
||||||
self.current_id: str | None = raw_data.get("current_id", None)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Migration
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None:
|
|
||||||
"""Flatten old HistoryTree nodes into snapshot list, discarding DAG info."""
|
|
||||||
self.snapshots = {}
|
|
||||||
nodes = raw_data.get("nodes", {})
|
|
||||||
for nid, node in nodes.items():
|
|
||||||
self.snapshots[nid] = {
|
|
||||||
"id": nid,
|
|
||||||
"timestamp": node.get("timestamp", time.time()),
|
|
||||||
"note": node.get("note", "Migrated"),
|
|
||||||
"pinned": False,
|
|
||||||
"auto": False,
|
|
||||||
"seq_count": self._count_seqs(node.get("data")),
|
|
||||||
}
|
|
||||||
# Preserve snapshot data if present
|
|
||||||
if "data" in node and node["data"]:
|
|
||||||
self.snapshots[nid]["data"] = node["data"]
|
|
||||||
self.current_id = raw_data.get("head_id")
|
|
||||||
|
|
||||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
|
||||||
"""Convert ancient prompt_history list into snapshots."""
|
|
||||||
self.snapshots = {}
|
|
||||||
self.current_id = None
|
|
||||||
for item in reversed(old_list):
|
|
||||||
sid = self._make_id()
|
|
||||||
self.snapshots[sid] = {
|
|
||||||
"id": sid,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"note": item.get("note", "Legacy Import"),
|
|
||||||
"pinned": False,
|
|
||||||
"auto": False,
|
|
||||||
"seq_count": self._count_seqs(item),
|
|
||||||
"data": item,
|
|
||||||
}
|
|
||||||
self.current_id = sid
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Core operations
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def record(self, data: dict[str, Any], note: str = "Snapshot",
|
|
||||||
auto: bool = False) -> str:
|
|
||||||
"""Create a new snapshot and return its ID."""
|
|
||||||
sid = self._make_id()
|
|
||||||
self.snapshots[sid] = {
|
|
||||||
"id": sid,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"note": note,
|
|
||||||
"pinned": False,
|
|
||||||
"auto": auto,
|
|
||||||
"seq_count": self._count_seqs(data),
|
|
||||||
"data": data,
|
|
||||||
}
|
|
||||||
self.current_id = sid
|
|
||||||
return sid
|
|
||||||
|
|
||||||
def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None:
|
|
||||||
"""Return the inline snapshot data if present."""
|
|
||||||
snap = self.snapshots.get(snapshot_id)
|
|
||||||
if snap:
|
|
||||||
return snap.get("data")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def toggle_pin(self, snapshot_id: str) -> bool:
|
|
||||||
"""Toggle pinned state, return new value."""
|
|
||||||
snap = self.snapshots.get(snapshot_id)
|
|
||||||
if snap:
|
|
||||||
snap["pinned"] = not snap.get("pinned", False)
|
|
||||||
return snap["pinned"]
|
|
||||||
return False
|
|
||||||
|
|
||||||
def delete(self, snapshot_id: str) -> None:
|
|
||||||
"""Remove a snapshot."""
|
|
||||||
self.snapshots.pop(snapshot_id, None)
|
|
||||||
if self.current_id == snapshot_id:
|
|
||||||
# Fall back to most recent remaining
|
|
||||||
if self.snapshots:
|
|
||||||
self.current_id = max(
|
|
||||||
self.snapshots.values(), key=lambda s: s["timestamp"]
|
|
||||||
)["id"]
|
|
||||||
else:
|
|
||||||
self.current_id = None
|
|
||||||
|
|
||||||
def strip_snapshots(self) -> None:
|
|
||||||
"""Remove inline data from all snapshots (for slim JSON storage)."""
|
|
||||||
for snap in self.snapshots.values():
|
|
||||||
snap.pop("data", None)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Serialization
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"snapshots": self.snapshots,
|
|
||||||
"current_id": self.current_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _make_id(self) -> str:
|
|
||||||
for _ in range(10):
|
|
||||||
sid = str(uuid.uuid4())[:8]
|
|
||||||
if sid not in self.snapshots:
|
|
||||||
return sid
|
|
||||||
raise ValueError("Failed to generate unique snapshot ID after 10 attempts")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _count_seqs(data: dict | None) -> int:
|
|
||||||
if not data:
|
|
||||||
return 0
|
|
||||||
from utils import KEY_BATCH_DATA
|
|
||||||
batch = data.get(KEY_BATCH_DATA, [])
|
|
||||||
return len(batch) if isinstance(batch, list) else 0
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Diff function
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]:
|
|
||||||
"""Compare two batch lists by sequence_number, return per-sequence diffs.
|
|
||||||
|
|
||||||
Returns a list of dicts:
|
|
||||||
{
|
|
||||||
"seq_num": int,
|
|
||||||
"status": "unchanged" | "changed" | "added" | "removed",
|
|
||||||
"changes": [{"field": str, "old": Any, "new": Any}],
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
from utils import KEY_SEQUENCE_NUMBER
|
|
||||||
|
|
||||||
old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch}
|
|
||||||
new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch}
|
|
||||||
|
|
||||||
all_seqs = sorted(set(old_by_seq) | set(new_by_seq))
|
|
||||||
result = []
|
|
||||||
|
|
||||||
for seq_num in all_seqs:
|
|
||||||
old_item = old_by_seq.get(seq_num)
|
|
||||||
new_item = new_by_seq.get(seq_num)
|
|
||||||
|
|
||||||
if old_item and not new_item:
|
|
||||||
result.append({"seq_num": seq_num, "status": "removed", "changes": []})
|
|
||||||
elif new_item and not old_item:
|
|
||||||
result.append({"seq_num": seq_num, "status": "added", "changes": []})
|
|
||||||
else:
|
|
||||||
# Both exist — field-by-field comparison
|
|
||||||
all_keys = sorted(set(old_item) | set(new_item))
|
|
||||||
changes = []
|
|
||||||
for k in all_keys:
|
|
||||||
old_val = old_item.get(k)
|
|
||||||
new_val = new_item.get(k)
|
|
||||||
if old_val != new_val:
|
|
||||||
changes.append({"field": k, "old": old_val, "new": new_val})
|
|
||||||
status = "changed" if changes else "unchanged"
|
|
||||||
result.append({"seq_num": seq_num, "status": status, "changes": changes})
|
|
||||||
|
|
||||||
return result
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AppState:
|
|
||||||
config: dict
|
|
||||||
current_dir: Path
|
|
||||||
loaded_file: str | None = None
|
|
||||||
last_mtime: float = 0
|
|
||||||
data_cache: dict = field(default_factory=dict)
|
|
||||||
snippets: dict = field(default_factory=dict)
|
|
||||||
file_path: Path | None = None
|
|
||||||
restored_indicator: str | None = None
|
|
||||||
timeline_selected_id: str | None = None
|
|
||||||
live_toggles: dict = field(default_factory=dict)
|
|
||||||
show_comfy_monitor: bool = True
|
|
||||||
|
|
||||||
# Project DB fields
|
|
||||||
db: Any = None
|
|
||||||
current_project: str = ""
|
|
||||||
db_enabled: bool = False
|
|
||||||
|
|
||||||
# Set at runtime by main.py / tab_comfy_ng.py
|
|
||||||
_render_main: Any = None
|
|
||||||
_load_file: Callable | None = None
|
|
||||||
_main_rendered: bool = False
|
|
||||||
_live_checkboxes: dict = field(default_factory=dict)
|
|
||||||
_live_refreshables: dict = field(default_factory=dict)
|
|
||||||
_src_cache: dict = field(default_factory=lambda: {'data': None, 'batch': [], 'name': None})
|
|
||||||
|
|
||||||
def create_secondary(self) -> 'AppState':
|
|
||||||
return AppState(
|
|
||||||
config=self.config,
|
|
||||||
current_dir=self.current_dir,
|
|
||||||
snippets=self.snippets,
|
|
||||||
db=self.db,
|
|
||||||
current_project=self.current_project,
|
|
||||||
db_enabled=self.db_enabled,
|
|
||||||
)
|
|
||||||
+304
@@ -0,0 +1,304 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import random
|
||||||
|
from utils import DEFAULTS, save_json, load_json
|
||||||
|
from history_tree import HistoryTree
|
||||||
|
|
||||||
|
def create_batch_callback(original_filename, current_data, current_dir):
|
||||||
|
new_name = f"batch_{original_filename}"
|
||||||
|
new_path = current_dir / new_name
|
||||||
|
|
||||||
|
if new_path.exists():
|
||||||
|
st.toast(f"File {new_name} already exists!", icon="⚠️")
|
||||||
|
return
|
||||||
|
|
||||||
|
first_item = current_data.copy()
|
||||||
|
if "prompt_history" in first_item: del first_item["prompt_history"]
|
||||||
|
if "history_tree" in first_item: del first_item["history_tree"]
|
||||||
|
|
||||||
|
first_item["sequence_number"] = 1
|
||||||
|
|
||||||
|
new_data = {
|
||||||
|
"batch_data": [first_item],
|
||||||
|
"history_tree": {},
|
||||||
|
"prompt_history": []
|
||||||
|
}
|
||||||
|
|
||||||
|
save_json(new_path, new_data)
|
||||||
|
st.toast(f"Created {new_name}", icon="✨")
|
||||||
|
st.session_state.file_selector = new_name
|
||||||
|
|
||||||
|
|
||||||
|
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
||||||
|
is_batch_file = "batch_data" in data or isinstance(data, list)
|
||||||
|
|
||||||
|
if not is_batch_file:
|
||||||
|
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
||||||
|
st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir))
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
||||||
|
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
||||||
|
|
||||||
|
batch_list = data.get("batch_data", [])
|
||||||
|
|
||||||
|
# --- ADD NEW SEQUENCE AREA ---
|
||||||
|
st.subheader("Add New Sequence")
|
||||||
|
ac1, ac2 = st.columns(2)
|
||||||
|
|
||||||
|
with ac1:
|
||||||
|
file_options = [f.name for f in json_files]
|
||||||
|
d_idx = file_options.index(selected_file_name) if selected_file_name in file_options else 0
|
||||||
|
src_name = st.selectbox("Source File:", file_options, index=d_idx, key="batch_src_file")
|
||||||
|
src_data, _ = load_json(current_dir / src_name)
|
||||||
|
|
||||||
|
with ac2:
|
||||||
|
src_hist = src_data.get("prompt_history", [])
|
||||||
|
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
||||||
|
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
|
||||||
|
|
||||||
|
bc1, bc2, bc3 = st.columns(3)
|
||||||
|
|
||||||
|
def add_sequence(new_item):
|
||||||
|
max_seq = 0
|
||||||
|
for s in batch_list:
|
||||||
|
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
||||||
|
new_item["sequence_number"] = max_seq + 1
|
||||||
|
|
||||||
|
for k in ["prompt_history", "history_tree", "note", "loras"]:
|
||||||
|
if k in new_item: del new_item[k]
|
||||||
|
|
||||||
|
batch_list.append(new_item)
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if bc1.button("➕ Add Empty", use_container_width=True):
|
||||||
|
add_sequence(DEFAULTS.copy())
|
||||||
|
|
||||||
|
if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"):
|
||||||
|
item = DEFAULTS.copy()
|
||||||
|
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||||||
|
item.update(flat)
|
||||||
|
add_sequence(item)
|
||||||
|
|
||||||
|
if bc3.button("➕ From History", use_container_width=True, disabled=not src_hist):
|
||||||
|
if sel_hist:
|
||||||
|
idx = int(sel_hist.split(":")[0].replace("#", "")) - 1
|
||||||
|
item = DEFAULTS.copy()
|
||||||
|
h_item = src_hist[idx]
|
||||||
|
item.update(h_item)
|
||||||
|
if "loras" in h_item and isinstance(h_item["loras"], dict):
|
||||||
|
item.update(h_item["loras"])
|
||||||
|
add_sequence(item)
|
||||||
|
|
||||||
|
# --- RENDER LIST ---
|
||||||
|
st.markdown("---")
|
||||||
|
st.info(f"Batch contains {len(batch_list)} sequences.")
|
||||||
|
|
||||||
|
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
||||||
|
standard_keys = {
|
||||||
|
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||||
|
"camera", "flf", "sequence_number"
|
||||||
|
}
|
||||||
|
standard_keys.update(lora_keys)
|
||||||
|
standard_keys.update([
|
||||||
|
"frame_to_skip", "input_a_frames", "input_b_frames", "reference switch", "vace schedule",
|
||||||
|
"reference path", "video file path", "reference image path", "flf image path"
|
||||||
|
])
|
||||||
|
|
||||||
|
for i, seq in enumerate(batch_list):
|
||||||
|
seq_num = seq.get("sequence_number", i+1)
|
||||||
|
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
||||||
|
|
||||||
|
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
|
||||||
|
# --- NEW: ACTION ROW WITH CLONING ---
|
||||||
|
act_c1, act_c2, act_c3, act_c4 = st.columns([1.2, 1.8, 1.2, 0.5])
|
||||||
|
|
||||||
|
# 1. Copy Source
|
||||||
|
with act_c1:
|
||||||
|
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
||||||
|
item = DEFAULTS.copy()
|
||||||
|
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||||||
|
item.update(flat)
|
||||||
|
item["sequence_number"] = seq_num
|
||||||
|
for k in ["prompt_history", "history_tree"]:
|
||||||
|
if k in item: del item[k]
|
||||||
|
batch_list[i] = item
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.toast("Copied!", icon="📥")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# 2. Cloning Tools (Next / End)
|
||||||
|
with act_c2:
|
||||||
|
cl_1, cl_2 = st.columns(2)
|
||||||
|
|
||||||
|
# Clone Next
|
||||||
|
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
||||||
|
new_seq = seq.copy()
|
||||||
|
# Calculate new max sequence number
|
||||||
|
max_sn = 0
|
||||||
|
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||||||
|
new_seq["sequence_number"] = max_sn + 1
|
||||||
|
|
||||||
|
batch_list.insert(i + 1, new_seq)
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.toast("Cloned to Next!", icon="👯")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# Clone End
|
||||||
|
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
||||||
|
new_seq = seq.copy()
|
||||||
|
max_sn = 0
|
||||||
|
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||||||
|
new_seq["sequence_number"] = max_sn + 1
|
||||||
|
|
||||||
|
batch_list.append(new_seq)
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.toast("Cloned to End!", icon="⏬")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# 3. Promote
|
||||||
|
with act_c3:
|
||||||
|
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
||||||
|
single_data = seq.copy()
|
||||||
|
single_data["prompt_history"] = data.get("prompt_history", [])
|
||||||
|
single_data["history_tree"] = data.get("history_tree", {})
|
||||||
|
if "sequence_number" in single_data: del single_data["sequence_number"]
|
||||||
|
save_json(file_path, single_data)
|
||||||
|
st.toast("Converted to Single!", icon="✅")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# 4. Remove
|
||||||
|
with act_c4:
|
||||||
|
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
||||||
|
batch_list.pop(i)
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
c1, c2 = st.columns([2, 1])
|
||||||
|
with c1:
|
||||||
|
seq["general_prompt"] = st.text_area("General Prompt", value=seq.get("general_prompt", ""), height=60, key=f"{prefix}_gp")
|
||||||
|
seq["general_negative"] = st.text_area("General Negative", value=seq.get("general_negative", ""), height=60, key=f"{prefix}_gn")
|
||||||
|
seq["current_prompt"] = st.text_area("Specific Prompt", value=seq.get("current_prompt", ""), height=100, key=f"{prefix}_sp")
|
||||||
|
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
||||||
|
|
||||||
|
with c2:
|
||||||
|
seq["sequence_number"] = st.number_input("Seq Num", value=int(seq_num), key=f"{prefix}_sn_val")
|
||||||
|
|
||||||
|
s_row1, s_row2 = st.columns([3, 1])
|
||||||
|
seed_key = f"{prefix}_seed"
|
||||||
|
with s_row2:
|
||||||
|
st.write("")
|
||||||
|
st.write("")
|
||||||
|
if st.button("🎲", key=f"{prefix}_rand"):
|
||||||
|
st.session_state[seed_key] = random.randint(0, 999999999999)
|
||||||
|
st.rerun()
|
||||||
|
with s_row1:
|
||||||
|
current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
|
||||||
|
val = st.number_input("Seed", value=current_seed, key=seed_key)
|
||||||
|
seq["seed"] = val
|
||||||
|
|
||||||
|
seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam")
|
||||||
|
seq["flf"] = st.text_input("FLF", value=str(seq.get("flf", DEFAULTS["flf"])), key=f"{prefix}_flf")
|
||||||
|
|
||||||
|
if "video file path" in seq or "vace" in selected_file_name:
|
||||||
|
seq["video file path"] = st.text_input("Video Path", value=seq.get("video file path", ""), key=f"{prefix}_vid")
|
||||||
|
with st.expander("VACE Settings"):
|
||||||
|
seq["frame_to_skip"] = st.number_input("Skip", value=int(seq.get("frame_to_skip", 81)), key=f"{prefix}_fts")
|
||||||
|
seq["input_a_frames"] = st.number_input("In A", value=int(seq.get("input_a_frames", 0)), key=f"{prefix}_ia")
|
||||||
|
seq["input_b_frames"] = st.number_input("In B", value=int(seq.get("input_b_frames", 0)), key=f"{prefix}_ib")
|
||||||
|
seq["reference switch"] = st.number_input("Switch", value=int(seq.get("reference switch", 1)), key=f"{prefix}_rsw")
|
||||||
|
seq["vace schedule"] = st.number_input("Sched", value=int(seq.get("vace schedule", 1)), key=f"{prefix}_vsc")
|
||||||
|
seq["reference path"] = st.text_input("Ref Path", value=seq.get("reference path", ""), key=f"{prefix}_rp")
|
||||||
|
seq["reference image path"] = st.text_input("Ref Img", value=seq.get("reference image path", ""), key=f"{prefix}_rip")
|
||||||
|
|
||||||
|
if "i2v" in selected_file_name and "vace" not in selected_file_name:
|
||||||
|
seq["reference image path"] = st.text_input("Ref Img", value=seq.get("reference image path", ""), key=f"{prefix}_ri2")
|
||||||
|
seq["flf image path"] = st.text_input("FLF Img", value=seq.get("flf image path", ""), key=f"{prefix}_flfi")
|
||||||
|
|
||||||
|
# --- LoRA Settings (Reverted to plain text) ---
|
||||||
|
with st.expander("💊 LoRA Settings"):
|
||||||
|
lc1, lc2, lc3 = st.columns(3)
|
||||||
|
with lc1:
|
||||||
|
seq["lora 1 high"] = st.text_input("LoRA 1 Name", value=seq.get("lora 1 high", ""), key=f"{prefix}_l1h")
|
||||||
|
seq["lora 1 low"] = st.text_input("LoRA 1 Strength", value=str(seq.get("lora 1 low", "")), key=f"{prefix}_l1l")
|
||||||
|
with lc2:
|
||||||
|
seq["lora 2 high"] = st.text_input("LoRA 2 Name", value=seq.get("lora 2 high", ""), key=f"{prefix}_l2h")
|
||||||
|
seq["lora 2 low"] = st.text_input("LoRA 2 Strength", value=str(seq.get("lora 2 low", "")), key=f"{prefix}_l2l")
|
||||||
|
with lc3:
|
||||||
|
seq["lora 3 high"] = st.text_input("LoRA 3 Name", value=seq.get("lora 3 high", ""), key=f"{prefix}_l3h")
|
||||||
|
seq["lora 3 low"] = st.text_input("LoRA 3 Strength", value=str(seq.get("lora 3 low", "")), key=f"{prefix}_l3l")
|
||||||
|
|
||||||
|
# --- CUSTOM PARAMETERS ---
|
||||||
|
st.markdown("---")
|
||||||
|
st.caption("🔧 Custom Parameters")
|
||||||
|
|
||||||
|
custom_keys = [k for k in seq.keys() if k not in standard_keys]
|
||||||
|
keys_to_remove = []
|
||||||
|
|
||||||
|
if custom_keys:
|
||||||
|
for k in custom_keys:
|
||||||
|
ck1, ck2, ck3 = st.columns([1, 2, 0.5])
|
||||||
|
ck1.text_input("Key", value=k, disabled=True, key=f"{prefix}_ck_lbl_{k}", label_visibility="collapsed")
|
||||||
|
val = ck2.text_input("Value", value=str(seq[k]), key=f"{prefix}_cv_{k}", label_visibility="collapsed")
|
||||||
|
seq[k] = val
|
||||||
|
|
||||||
|
if ck3.button("🗑️", key=f"{prefix}_cdel_{k}"):
|
||||||
|
keys_to_remove.append(k)
|
||||||
|
|
||||||
|
with st.expander("➕ Add Parameter"):
|
||||||
|
nk_col, nv_col = st.columns(2)
|
||||||
|
new_k = nk_col.text_input("Key", key=f"{prefix}_new_k")
|
||||||
|
new_v = nv_col.text_input("Value", key=f"{prefix}_new_v")
|
||||||
|
|
||||||
|
if st.button("Add", key=f"{prefix}_add_cust"):
|
||||||
|
if new_k and new_k not in seq:
|
||||||
|
seq[new_k] = new_v
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if keys_to_remove:
|
||||||
|
for k in keys_to_remove:
|
||||||
|
del seq[k]
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- SAVE ACTIONS WITH HISTORY COMMIT ---
|
||||||
|
col_save, col_note = st.columns([1, 2])
|
||||||
|
|
||||||
|
with col_note:
|
||||||
|
commit_msg = st.text_input("Change Note (Optional)", placeholder="e.g. Added sequence 3")
|
||||||
|
|
||||||
|
with col_save:
|
||||||
|
if st.button("💾 Save & Snap", use_container_width=True):
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
|
||||||
|
tree_data = data.get("history_tree", {})
|
||||||
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
|
snapshot_payload = data.copy()
|
||||||
|
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
|
||||||
|
|
||||||
|
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
||||||
|
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
|
||||||
|
if 'restored_indicator' in st.session_state:
|
||||||
|
del st.session_state.restored_indicator
|
||||||
|
|
||||||
|
st.toast("Batch Saved & Snapshot Created!", icon="🚀")
|
||||||
|
st.rerun()
|
||||||
-974
@@ -1,974 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from utils import (
|
|
||||||
DEFAULTS, save_json, load_json, sync_to_db,
|
|
||||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
|
||||||
)
|
|
||||||
from snapshot_timeline import SnapshotTimeline
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
|
|
||||||
_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots
|
|
||||||
_last_auto_snap: dict[str, float] = {} # file_path -> timestamp
|
|
||||||
SUB_SEGMENT_MULTIPLIER = 1000
|
|
||||||
SUB_SEGMENT_NUM_COLORS = 6
|
|
||||||
FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip']
|
|
||||||
|
|
||||||
VACE_MODES = [
|
|
||||||
'End Extend', 'Pre Extend', 'Middle Extend', 'Edge Extend',
|
|
||||||
'Join Extend', 'Bidirectional Extend', 'Frame Interpolation',
|
|
||||||
'Replace/Inpaint', 'Video Inpaint', 'Keyframe',
|
|
||||||
]
|
|
||||||
VACE_FORMULAS = [
|
|
||||||
'base + A', 'base + B', 'base + A + B', 'base + A + B',
|
|
||||||
'base + A + B', 'base + A + B', '(B-1) * step',
|
|
||||||
'snap(source)', 'snap(source)', 'base + A + B',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# --- Sub-segment helpers (same as original) ---
|
|
||||||
|
|
||||||
def is_subsegment(seq_num):
|
|
||||||
return int(seq_num) >= SUB_SEGMENT_MULTIPLIER
|
|
||||||
|
|
||||||
def parent_of(seq_num):
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
return seq_num // SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else seq_num
|
|
||||||
|
|
||||||
def sub_index_of(seq_num):
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
return seq_num % SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else 0
|
|
||||||
|
|
||||||
def format_seq_label(seq_num):
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
return f'Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)}'
|
|
||||||
return f'Sequence #{seq_num}'
|
|
||||||
|
|
||||||
def next_sub_segment_number(batch_list, parent_seq_num):
|
|
||||||
parent_seq_num = int(parent_seq_num)
|
|
||||||
max_sub = 0
|
|
||||||
for s in batch_list:
|
|
||||||
sn = int(s.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if is_subsegment(sn) and parent_of(sn) == parent_seq_num:
|
|
||||||
max_sub = max(max_sub, sub_index_of(sn))
|
|
||||||
return parent_seq_num * SUB_SEGMENT_MULTIPLIER + max_sub + 1
|
|
||||||
|
|
||||||
def max_main_seq_number(batch_list):
|
|
||||||
"""Highest non-subsegment sequence number in the batch."""
|
|
||||||
return max(
|
|
||||||
(int(x.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
for x in batch_list if not is_subsegment(x.get(KEY_SEQUENCE_NUMBER, 0))),
|
|
||||||
default=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def find_insert_position(batch_list, parent_index, parent_seq_num):
|
|
||||||
parent_seq_num = int(parent_seq_num)
|
|
||||||
pos = parent_index + 1
|
|
||||||
while pos < len(batch_list):
|
|
||||||
sn = int(batch_list[pos].get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if is_subsegment(sn) and parent_of(sn) == parent_seq_num:
|
|
||||||
pos += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
# --- Auto change note ---
|
|
||||||
|
|
||||||
def _auto_change_note(timeline, batch_list, state=None, file_path=None):
|
|
||||||
"""Compare current batch_list against last snapshot and describe changes."""
|
|
||||||
# Get previous batch data from the current snapshot
|
|
||||||
if not timeline.current_id or timeline.current_id not in timeline.snapshots:
|
|
||||||
return f'Initial save ({len(batch_list)} sequences)'
|
|
||||||
|
|
||||||
# Load previous snapshot from inline data or DB
|
|
||||||
prev_data = timeline.get_snapshot_data(timeline.current_id)
|
|
||||||
if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path:
|
|
||||||
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
|
||||||
if df:
|
|
||||||
prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id)
|
|
||||||
prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch}
|
|
||||||
curr_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in batch_list}
|
|
||||||
|
|
||||||
added = sorted(set(curr_by_seq) - set(prev_by_seq))
|
|
||||||
removed = sorted(set(prev_by_seq) - set(curr_by_seq))
|
|
||||||
|
|
||||||
changed_keys = set()
|
|
||||||
for seq_num in sorted(set(curr_by_seq) & set(prev_by_seq)):
|
|
||||||
old, new = prev_by_seq[seq_num], curr_by_seq[seq_num]
|
|
||||||
all_keys = set(old) | set(new)
|
|
||||||
for k in all_keys:
|
|
||||||
if old.get(k) != new.get(k):
|
|
||||||
changed_keys.add(k)
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
if added:
|
|
||||||
parts.append(f'Added seq {", ".join(str(s) for s in added)}')
|
|
||||||
if removed:
|
|
||||||
parts.append(f'Removed seq {", ".join(str(s) for s in removed)}')
|
|
||||||
if changed_keys:
|
|
||||||
# Show up to 4 changed field names
|
|
||||||
keys_list = sorted(changed_keys)
|
|
||||||
if len(keys_list) > 4:
|
|
||||||
keys_str = ', '.join(keys_list[:4]) + f' +{len(keys_list) - 4} more'
|
|
||||||
else:
|
|
||||||
keys_str = ', '.join(keys_list)
|
|
||||||
parts.append(f'Changed: {keys_str}')
|
|
||||||
|
|
||||||
return '; '.join(parts) if parts else 'No changes detected'
|
|
||||||
|
|
||||||
|
|
||||||
# --- Helper for repetitive dict-bound inputs ---
|
|
||||||
|
|
||||||
def dict_input(element_fn, label, seq, key, **kwargs):
|
|
||||||
"""Create an input element bound to seq[key] via blur and model-value update."""
|
|
||||||
val = seq.get(key, '')
|
|
||||||
if isinstance(val, (int, float)):
|
|
||||||
val = str(val) if element_fn != ui.number else val
|
|
||||||
el = element_fn(label, value=val, **kwargs)
|
|
||||||
|
|
||||||
def _sync(k=key):
|
|
||||||
seq[k] = el.value
|
|
||||||
|
|
||||||
el.on('blur', lambda _: _sync())
|
|
||||||
el.on('update:model-value', lambda _: _sync())
|
|
||||||
return el
|
|
||||||
|
|
||||||
|
|
||||||
def dict_number(label, seq, key, default=0, **kwargs):
|
|
||||||
"""Number input bound to seq[key] via blur and model-value update."""
|
|
||||||
val = seq.get(key, default)
|
|
||||||
try:
|
|
||||||
# Try float first to handle "1.5" strings, then check if it's a clean int
|
|
||||||
fval = float(val)
|
|
||||||
if not math.isfinite(fval):
|
|
||||||
fval = float(default)
|
|
||||||
val = int(fval) if fval == int(fval) else fval
|
|
||||||
except (ValueError, TypeError, OverflowError):
|
|
||||||
val = default
|
|
||||||
el = ui.number(label, value=val, **kwargs)
|
|
||||||
|
|
||||||
def _sync(k=key, d=default):
|
|
||||||
v = el.value
|
|
||||||
if v is None:
|
|
||||||
v = d
|
|
||||||
elif isinstance(v, float):
|
|
||||||
if not math.isfinite(v):
|
|
||||||
v = d
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
v = int(v) if v == int(v) else v
|
|
||||||
except (OverflowError, ValueError):
|
|
||||||
v = d
|
|
||||||
seq[k] = v
|
|
||||||
|
|
||||||
el.on('blur', lambda _: _sync())
|
|
||||||
el.on('update:model-value', lambda _: _sync())
|
|
||||||
return el
|
|
||||||
|
|
||||||
|
|
||||||
def dict_textarea(label, seq, key, **kwargs):
|
|
||||||
"""Textarea bound to seq[key] via blur and model-value update."""
|
|
||||||
el = ui.textarea(label, value=seq.get(key, ''), **kwargs)
|
|
||||||
|
|
||||||
def _sync(k=key):
|
|
||||||
seq[k] = el.value
|
|
||||||
|
|
||||||
el.on('blur', lambda _: _sync())
|
|
||||||
el.on('update:model-value', lambda _: _sync())
|
|
||||||
return el
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Main render function
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def render_batch_processor(state: AppState):
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
logger.info("render_batch_processor START")
|
|
||||||
data = state.data_cache
|
|
||||||
file_path = state.file_path
|
|
||||||
if isinstance(data, list):
|
|
||||||
data = {KEY_BATCH_DATA: data}
|
|
||||||
state.data_cache = data
|
|
||||||
is_batch_file = KEY_BATCH_DATA in data
|
|
||||||
|
|
||||||
if not is_batch_file:
|
|
||||||
ui.label('This is a Single file. To use Batch mode, create a copy.').classes(
|
|
||||||
'text-warning')
|
|
||||||
|
|
||||||
async def create_batch():
|
|
||||||
new_name = f'batch_{file_path.name}'
|
|
||||||
new_path = file_path.parent / new_name
|
|
||||||
if new_path.exists():
|
|
||||||
ui.notify(f'File {new_name} already exists!', type='warning')
|
|
||||||
return
|
|
||||||
first_item = copy.deepcopy(data)
|
|
||||||
first_item.pop(KEY_PROMPT_HISTORY, None)
|
|
||||||
first_item.pop(KEY_HISTORY_TREE, None)
|
|
||||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
|
||||||
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
|
||||||
KEY_PROMPT_HISTORY: []}
|
|
||||||
await asyncio.to_thread(save_json, new_path, new_data)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, new_path, new_data)
|
|
||||||
ui.notify(f'Created {new_name}', type='positive')
|
|
||||||
|
|
||||||
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
|
||||||
return
|
|
||||||
|
|
||||||
if state.restored_indicator:
|
|
||||||
ui.label(f'Editing Restored Version: {state.restored_indicator}').classes(
|
|
||||||
'text-info q-pa-sm')
|
|
||||||
|
|
||||||
batch_list = data.get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
# Source file data for importing
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-lg'):
|
|
||||||
with ui.expansion('Add New Sequence from Source File', icon='playlist_add').classes('w-full'):
|
|
||||||
json_files = sorted(state.current_dir.glob('*.json'))
|
|
||||||
json_files = [f for f in json_files if f.name not in (
|
|
||||||
'.editor_config.json', '.editor_snippets.json')]
|
|
||||||
file_options = {f.name: f.name for f in json_files}
|
|
||||||
|
|
||||||
src_file_select = ui.select(
|
|
||||||
file_options,
|
|
||||||
value=file_path.name,
|
|
||||||
label='Source File:',
|
|
||||||
).classes('w-64')
|
|
||||||
|
|
||||||
src_seq_select = ui.select([], label='Source Sequence:').classes('w-64')
|
|
||||||
|
|
||||||
# Track loaded source data (on state so it's cleared on file switch)
|
|
||||||
_src_cache = state._src_cache
|
|
||||||
|
|
||||||
def _update_src():
|
|
||||||
name = src_file_select.value
|
|
||||||
if name and name != _src_cache['name']:
|
|
||||||
# Reuse current data if source is the same file
|
|
||||||
if name == file_path.name:
|
|
||||||
src_data = data
|
|
||||||
else:
|
|
||||||
src_data, _ = load_json(state.current_dir / name)
|
|
||||||
_src_cache['data'] = src_data
|
|
||||||
_src_cache['batch'] = src_data.get(KEY_BATCH_DATA, [])
|
|
||||||
_src_cache['name'] = name
|
|
||||||
if _src_cache['batch']:
|
|
||||||
opts = {i: format_seq_label(s.get(KEY_SEQUENCE_NUMBER, i+1))
|
|
||||||
for i, s in enumerate(_src_cache['batch'])}
|
|
||||||
src_seq_select.set_options(opts, value=0)
|
|
||||||
else:
|
|
||||||
src_seq_select.set_options({})
|
|
||||||
|
|
||||||
src_file_select.on_value_change(lambda _: _update_src())
|
|
||||||
_update_src()
|
|
||||||
|
|
||||||
async def _add_sequence(new_item):
|
|
||||||
new_item[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
|
|
||||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, 'note', 'loras']:
|
|
||||||
new_item.pop(k, None)
|
|
||||||
batch_list.append(new_item)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
render_sequence_list.refresh()
|
|
||||||
|
|
||||||
with ui.row().classes('q-mt-sm'):
|
|
||||||
async def add_empty():
|
|
||||||
await _add_sequence(copy.deepcopy(DEFAULTS))
|
|
||||||
|
|
||||||
async def add_from_source():
|
|
||||||
item = copy.deepcopy(DEFAULTS)
|
|
||||||
src_batch = _src_cache['batch']
|
|
||||||
sel_idx = src_seq_select.value
|
|
||||||
if src_batch and sel_idx is not None and int(sel_idx) < len(src_batch):
|
|
||||||
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
|
||||||
elif _src_cache['data']:
|
|
||||||
item.update(copy.deepcopy(_src_cache['data']))
|
|
||||||
await _add_sequence(item)
|
|
||||||
|
|
||||||
ui.button('Add Empty', icon='add', on_click=add_empty)
|
|
||||||
ui.button('From Source', icon='file_download', on_click=add_from_source)
|
|
||||||
|
|
||||||
# --- Standard / LoRA / VACE key sets ---
|
|
||||||
lora_keys = ['lora 1 high', 'lora 1 high strength', 'lora 1 low', 'lora 1 low strength',
|
|
||||||
'lora 2 high', 'lora 2 high strength', 'lora 2 low', 'lora 2 low strength',
|
|
||||||
'lora 3 high', 'lora 3 high strength', 'lora 3 low', 'lora 3 low strength']
|
|
||||||
standard_keys = {
|
|
||||||
'name', 'mode', 'general_prompt', 'general_negative', 'current_prompt', 'negative', 'prompt',
|
|
||||||
'seed', 'camera', KEY_SEQUENCE_NUMBER,
|
|
||||||
'frame_to_skip', 'logic index', 'transition', 'vace_length',
|
|
||||||
'input_a_frames', 'input_b_frames', 'reference switch', 'vace schedule',
|
|
||||||
'start frame path', 'start frame high strength', 'start frame low strength',
|
|
||||||
'middle frame path', 'middle frame high strength', 'middle frame low strength',
|
|
||||||
'end frame path', 'end frame high strength', 'end frame low strength',
|
|
||||||
'video file path',
|
|
||||||
}
|
|
||||||
standard_keys.update(lora_keys)
|
|
||||||
|
|
||||||
async def sort_by_number():
|
|
||||||
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
ui.notify('Sorted by sequence number!', type='positive')
|
|
||||||
render_sequence_list.refresh()
|
|
||||||
|
|
||||||
# --- Sequence list + mass update (inside refreshable so they stay in sync) ---
|
|
||||||
@ui.refreshable
|
|
||||||
def render_sequence_list():
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
logger.info("render_sequence_list START (%d sequences)", len(batch_list))
|
|
||||||
# Mass update (rebuilt on refresh so checkboxes match current sequences)
|
|
||||||
_render_mass_update(batch_list, data, file_path, state, render_sequence_list)
|
|
||||||
|
|
||||||
with ui.row().classes('w-full items-center'):
|
|
||||||
ui.label(f'Batch contains {len(batch_list)} sequences.')
|
|
||||||
ui.button('Sort by Number', icon='sort', on_click=sort_by_number).props('flat')
|
|
||||||
|
|
||||||
for i, seq in enumerate(batch_list):
|
|
||||||
with ui.card().classes('w-full q-mb-sm'):
|
|
||||||
_render_sequence_card(
|
|
||||||
i, seq, batch_list, data, file_path, state,
|
|
||||||
_src_cache, src_seq_select,
|
|
||||||
standard_keys, render_sequence_list,
|
|
||||||
)
|
|
||||||
logger.info("render_sequence_list END (%.3fs)", time.perf_counter() - t1)
|
|
||||||
|
|
||||||
render_sequence_list()
|
|
||||||
logger.info("render_batch_processor END (%.3fs)", time.perf_counter() - t0)
|
|
||||||
|
|
||||||
# --- Save & Snap ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mt-lg'):
|
|
||||||
with ui.row().classes('w-full items-end q-gutter-md'):
|
|
||||||
commit_input = ui.input('Change Note (Optional)',
|
|
||||||
placeholder='e.g. Added sequence 3').classes('col')
|
|
||||||
|
|
||||||
async def save_and_snap():
|
|
||||||
t_ss = time.perf_counter()
|
|
||||||
logger.info("save_and_snap START")
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
|
||||||
timeline = SnapshotTimeline(tree_data)
|
|
||||||
note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path)
|
|
||||||
# Single serialization: json roundtrip gives us an isolated snapshot
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
|
||||||
if k != KEY_HISTORY_TREE})
|
|
||||||
snapshot_payload = json.loads(snapshot_json)
|
|
||||||
logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1)
|
|
||||||
try:
|
|
||||||
timeline.record(snapshot_payload, note=note)
|
|
||||||
except ValueError as e:
|
|
||||||
ui.notify(f'Save failed: {e}', type='negative')
|
|
||||||
return
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
full_tree = timeline.to_dict()
|
|
||||||
data[KEY_HISTORY_TREE] = full_tree
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
db_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
|
||||||
logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1)
|
|
||||||
timeline.strip_snapshots()
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
slim_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
|
||||||
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
|
||||||
else:
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
save_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
|
||||||
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
|
||||||
state.restored_indicator = None
|
|
||||||
commit_input.set_value('')
|
|
||||||
logger.info("save_and_snap END (%.3fs)", time.perf_counter() - t_ss)
|
|
||||||
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
|
||||||
|
|
||||||
ui.button('Save & Snap', icon='save', on_click=save_and_snap).props('color=primary')
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Single sequence card
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|
||||||
src_cache, src_seq_select, standard_keys,
|
|
||||||
refresh_list):
|
|
||||||
async def commit(message=None):
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
# Auto-snapshot with debounce
|
|
||||||
fp_key = str(file_path)
|
|
||||||
now = time.time()
|
|
||||||
did_snap = False
|
|
||||||
if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE:
|
|
||||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
|
||||||
snap_json = json.dumps({k: v for k, v in data.items()
|
|
||||||
if k != KEY_HISTORY_TREE})
|
|
||||||
snap_payload = json.loads(snap_json)
|
|
||||||
try:
|
|
||||||
timeline.record(snap_payload, note=message or "Auto-save", auto=True)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
db_snap = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap)
|
|
||||||
timeline.strip_snapshots()
|
|
||||||
did_snap = True
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
_last_auto_snap[fp_key] = now
|
|
||||||
except ValueError:
|
|
||||||
pass # Non-critical: skip auto-snapshot on ID collision
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db and not did_snap:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
if message:
|
|
||||||
ui.notify(message, type='positive')
|
|
||||||
refresh_list.refresh()
|
|
||||||
|
|
||||||
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i + 1)
|
|
||||||
seq_name = seq.get('name', '')
|
|
||||||
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
label = f'Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)} ({int(seq_num)})'
|
|
||||||
else:
|
|
||||||
label = f'Sequence #{seq_num}'
|
|
||||||
if seq_name:
|
|
||||||
label += f' — {seq_name}'
|
|
||||||
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
color_idx = (sub_index_of(seq_num) - 1) % SUB_SEGMENT_NUM_COLORS
|
|
||||||
exp_classes = f'w-full subsegment-color-{color_idx}'
|
|
||||||
else:
|
|
||||||
exp_classes = 'w-full'
|
|
||||||
with ui.expansion(label, icon='movie').classes(exp_classes) as expansion:
|
|
||||||
# --- Action row ---
|
|
||||||
with ui.row().classes('w-full q-gutter-sm action-row'):
|
|
||||||
# Rename
|
|
||||||
async def rename(s=seq):
|
|
||||||
result = await ui.run_javascript(
|
|
||||||
f'prompt("Rename sequence:", {json.dumps(s.get("name", ""))})',
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
if result is not None:
|
|
||||||
s['name'] = result
|
|
||||||
await commit('Renamed!')
|
|
||||||
|
|
||||||
ui.button('Rename', icon='edit', on_click=rename).props('outline')
|
|
||||||
# Copy from source
|
|
||||||
async def copy_source(idx=i, sn=seq_num):
|
|
||||||
item = copy.deepcopy(DEFAULTS)
|
|
||||||
src_batch = src_cache['batch']
|
|
||||||
sel_idx = src_seq_select.value
|
|
||||||
if src_batch and sel_idx is not None and int(sel_idx) < len(src_batch):
|
|
||||||
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
|
||||||
elif src_cache['data']:
|
|
||||||
item.update(copy.deepcopy(src_cache['data']))
|
|
||||||
item[KEY_SEQUENCE_NUMBER] = sn
|
|
||||||
item.pop(KEY_PROMPT_HISTORY, None)
|
|
||||||
item.pop(KEY_HISTORY_TREE, None)
|
|
||||||
batch_list[idx] = item
|
|
||||||
await commit('Copied!')
|
|
||||||
|
|
||||||
ui.button('Copy Src', icon='file_download', on_click=copy_source).props('outline')
|
|
||||||
|
|
||||||
# Clone Next
|
|
||||||
async def clone_next(idx=i, sn=seq_num, s=seq):
|
|
||||||
new_seq = copy.deepcopy(s)
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
|
|
||||||
if not is_subsegment(sn):
|
|
||||||
pos = find_insert_position(batch_list, idx, int(sn))
|
|
||||||
else:
|
|
||||||
pos = idx + 1
|
|
||||||
batch_list.insert(pos, new_seq)
|
|
||||||
await commit('Cloned to Next!')
|
|
||||||
|
|
||||||
ui.button('Clone Next', icon='content_copy', on_click=clone_next).props('outline')
|
|
||||||
|
|
||||||
# Clone End
|
|
||||||
async def clone_end(s=seq):
|
|
||||||
new_seq = copy.deepcopy(s)
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
|
|
||||||
batch_list.append(new_seq)
|
|
||||||
await commit('Cloned to End!')
|
|
||||||
|
|
||||||
ui.button('Clone End', icon='vertical_align_bottom', on_click=clone_end).props('outline')
|
|
||||||
|
|
||||||
# Clone Sub
|
|
||||||
async def clone_sub(idx=i, sn=seq_num, s=seq):
|
|
||||||
new_seq = copy.deepcopy(s)
|
|
||||||
p_seq = parent_of(sn)
|
|
||||||
p_idx = idx
|
|
||||||
if is_subsegment(sn):
|
|
||||||
for pi, ps in enumerate(batch_list):
|
|
||||||
if int(ps.get(KEY_SEQUENCE_NUMBER, 0)) == p_seq:
|
|
||||||
p_idx = pi
|
|
||||||
break
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = next_sub_segment_number(batch_list, p_seq)
|
|
||||||
pos = find_insert_position(batch_list, p_idx, p_seq)
|
|
||||||
batch_list.insert(pos, new_seq)
|
|
||||||
await commit(f'Created {format_seq_label(new_seq[KEY_SEQUENCE_NUMBER])}!')
|
|
||||||
|
|
||||||
ui.button('Clone Sub', icon='link', on_click=clone_sub).props('outline')
|
|
||||||
|
|
||||||
ui.element('div').classes('col')
|
|
||||||
|
|
||||||
# Delete
|
|
||||||
async def delete(idx=i):
|
|
||||||
if idx < len(batch_list):
|
|
||||||
batch_list.pop(idx)
|
|
||||||
await commit()
|
|
||||||
|
|
||||||
ui.button(icon='delete', on_click=delete).props('color=negative')
|
|
||||||
|
|
||||||
ui.separator()
|
|
||||||
|
|
||||||
# --- Prompts + Settings (2-column) ---
|
|
||||||
frame_switches = [] # populated below, used for bidirectional sync with logic index
|
|
||||||
with ui.splitter(value=66).classes('w-full') as splitter:
|
|
||||||
with splitter.before:
|
|
||||||
dict_textarea('General Prompt', seq, 'general_prompt').classes(
|
|
||||||
'w-full q-mt-sm').props('outlined rows=2')
|
|
||||||
dict_textarea('General Negative', seq, 'general_negative').classes(
|
|
||||||
'w-full q-mt-sm').props('outlined rows=2')
|
|
||||||
dict_textarea('Specific Prompt', seq, 'current_prompt').classes(
|
|
||||||
'w-full q-mt-sm').props('outlined rows=10')
|
|
||||||
dict_textarea('Specific Negative', seq, 'negative').classes(
|
|
||||||
'w-full q-mt-sm').props('outlined rows=2')
|
|
||||||
|
|
||||||
# --- Frame paths (start / middle / end) ---
|
|
||||||
logic_val = int(seq.get('logic index', 0))
|
|
||||||
for bit, img_label, img_key, hi_key, lo_key in [
|
|
||||||
(0, 'Start Frame', 'start frame path', 'start frame high strength', 'start frame low strength'),
|
|
||||||
(1, 'Middle Frame', 'middle frame path', 'middle frame high strength', 'middle frame low strength'),
|
|
||||||
(2, 'End Frame', 'end frame path', 'end frame high strength', 'end frame low strength'),
|
|
||||||
]:
|
|
||||||
ui.label(img_label).classes('text-caption text-weight-bold q-mt-sm')
|
|
||||||
is_on = bool((logic_val >> bit) & 1)
|
|
||||||
with ui.row().classes('w-full items-center no-wrap q-mt-xs'):
|
|
||||||
inp = dict_input(ui.input, 'Path', seq, img_key).classes(
|
|
||||||
'col').props('outlined dense input-style="text-align: right"')
|
|
||||||
thumb = None
|
|
||||||
img_path = Path(seq.get(img_key, '')) if seq.get(img_key) else None
|
|
||||||
if (img_path and img_path.exists() and
|
|
||||||
img_path.suffix.lower() in IMAGE_EXTENSIONS):
|
|
||||||
img_url = f'/api/image-preview?path={quote(str(img_path))}'
|
|
||||||
with ui.dialog() as img_dlg, ui.card().style('max-width:90vw; padding:0'):
|
|
||||||
ui.html(f'<img src="{img_url}" '
|
|
||||||
f'style="max-width:80vw;max-height:80vh;display:block">')
|
|
||||||
thumb = ui.html(
|
|
||||||
f'<img src="{img_url}" '
|
|
||||||
f'style="width:36px;height:36px;object-fit:cover;'
|
|
||||||
f'border-radius:4px;cursor:pointer;flex-shrink:0;'
|
|
||||||
f'opacity:{"1.0" if is_on else "0.25"}">'
|
|
||||||
).on('click', img_dlg.open)
|
|
||||||
sw = ui.switch(value=is_on)
|
|
||||||
frame_switches.append(sw)
|
|
||||||
if thumb is not None:
|
|
||||||
sw.on('update:model-value',
|
|
||||||
lambda e, t=thumb, s=sw: t.style(f'opacity: {"1.0" if s.value else "0.25"}'))
|
|
||||||
with ui.row().classes('w-full no-wrap q-mt-xs q-gutter-xs'):
|
|
||||||
dict_number('High', seq, hi_key, default=1.0,
|
|
||||||
step=0.05, format='%.2f').classes('col').props('outlined dense')
|
|
||||||
dict_number('Low', seq, lo_key, default=1.0,
|
|
||||||
step=0.05, format='%.2f').classes('col').props('outlined dense')
|
|
||||||
|
|
||||||
with splitter.after:
|
|
||||||
# Mode
|
|
||||||
dict_number('Mode', seq, 'mode').props('outlined').classes('w-full')
|
|
||||||
|
|
||||||
# Sequence number
|
|
||||||
sn_label = (
|
|
||||||
f'Seq Number (Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)})'
|
|
||||||
if is_subsegment(seq_num) else 'Sequence Number'
|
|
||||||
)
|
|
||||||
sn_input = dict_number(sn_label, seq, KEY_SEQUENCE_NUMBER)
|
|
||||||
sn_input.props('outlined').classes('w-full')
|
|
||||||
|
|
||||||
# Seed + randomize
|
|
||||||
with ui.row().classes('w-full items-end'):
|
|
||||||
seed_input = dict_number('Seed', seq, 'seed').classes('col').props('outlined')
|
|
||||||
|
|
||||||
def randomize_seed(si=seed_input, s=seq):
|
|
||||||
new_seed = random.randint(0, 999999999999)
|
|
||||||
si.set_value(new_seed)
|
|
||||||
s['seed'] = new_seed
|
|
||||||
|
|
||||||
ui.button(icon='casino', on_click=randomize_seed).props('flat')
|
|
||||||
|
|
||||||
dict_input(ui.input, 'Camera', seq, 'camera').props('outlined').classes('w-full')
|
|
||||||
seq.setdefault('logic index', 0)
|
|
||||||
li_input = dict_number('Logic Index', seq, 'logic index').props('outlined readonly').classes('w-full')
|
|
||||||
with li_input:
|
|
||||||
ui.tooltip(
|
|
||||||
'Binary flags — bit 0: start frame | bit 1: middle frame | bit 2: end frame\n'
|
|
||||||
'0: none 1: start 2: middle 3: start+middle\n'
|
|
||||||
'4: end 5: start+end 6: middle+end 7: all'
|
|
||||||
)
|
|
||||||
dict_input(ui.input, 'Video File Path', seq, 'video file path').props(
|
|
||||||
'outlined input-style="text-align: right"').classes('w-full')
|
|
||||||
|
|
||||||
# Switches → logic index (sole writer)
|
|
||||||
def _sync_switches_to_logic(li=li_input, switches=frame_switches, s=seq):
|
|
||||||
v = sum(int(sw.value) << b for b, sw in enumerate(switches))
|
|
||||||
s['logic index'] = v
|
|
||||||
li.set_value(v)
|
|
||||||
|
|
||||||
for frame_sw in frame_switches:
|
|
||||||
frame_sw.on('update:model-value', lambda _, s=_sync_switches_to_logic: s())
|
|
||||||
|
|
||||||
# --- Resolutions (8 fixed slots) ---
|
|
||||||
resolutions = seq.setdefault('resolutions', [])
|
|
||||||
while len(resolutions) < 8:
|
|
||||||
resolutions.append([512, 512, 0])
|
|
||||||
for r_i in range(len(resolutions)):
|
|
||||||
if len(resolutions[r_i]) < 3:
|
|
||||||
resolutions[r_i] = list(resolutions[r_i]) + [0]
|
|
||||||
with ui.expansion('Resolutions', icon='aspect_ratio').classes('w-full'):
|
|
||||||
for idx in range(8):
|
|
||||||
entry = resolutions[idx]
|
|
||||||
with ui.row().classes('items-center w-full q-mt-xs no-wrap'):
|
|
||||||
ui.label(str(idx)).classes('text-caption').style('min-width:16px')
|
|
||||||
w_inp = ui.number(value=int(entry[0]), min=1, step=1, label='W').style(
|
|
||||||
'width:70px').props('outlined dense hide-bottom-space')
|
|
||||||
h_inp = ui.number(value=int(entry[1]), min=1, step=1, label='H').style(
|
|
||||||
'width:70px').props('outlined dense hide-bottom-space')
|
|
||||||
seed_inp = ui.number(value=int(entry[2]), min=0, step=1, label='Seed').style(
|
|
||||||
'flex:1; min-width:60px').props('outlined dense hide-bottom-space')
|
|
||||||
|
|
||||||
async def _sync_entry(r=idx, wi=w_inp, hi=h_inp, si=seed_inp):
|
|
||||||
seq['resolutions'][r] = [
|
|
||||||
int(wi.value) if wi.value else 512,
|
|
||||||
int(hi.value) if hi.value else 512,
|
|
||||||
int(si.value) if si.value else 0,
|
|
||||||
]
|
|
||||||
await commit()
|
|
||||||
|
|
||||||
async def _randomize(si=seed_inp, r=idx):
|
|
||||||
si.value = random.randint(0, 2**32 - 1)
|
|
||||||
seq['resolutions'][r][2] = int(si.value)
|
|
||||||
await commit()
|
|
||||||
|
|
||||||
ui.button(icon='casino', on_click=_randomize).props(
|
|
||||||
'flat dense round').classes('q-ml-xs')
|
|
||||||
|
|
||||||
w_inp.on('blur', lambda _, s=_sync_entry: s())
|
|
||||||
w_inp.on('update:model-value', lambda _, s=_sync_entry: s())
|
|
||||||
h_inp.on('blur', lambda _, s=_sync_entry: s())
|
|
||||||
h_inp.on('update:model-value', lambda _, s=_sync_entry: s())
|
|
||||||
seed_inp.on('blur', lambda _, s=_sync_entry: s())
|
|
||||||
seed_inp.on('update:model-value', lambda _, s=_sync_entry: s())
|
|
||||||
|
|
||||||
# --- VACE Settings (full width) ---
|
|
||||||
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
|
||||||
_render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
|
|
||||||
|
|
||||||
# --- LoRA Settings ---
|
|
||||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
|
||||||
for lora_idx in range(1, 4):
|
|
||||||
for tier, tier_label in [('high', 'High'), ('low', 'Low')]:
|
|
||||||
lora_key = f'lora {lora_idx} {tier}'
|
|
||||||
|
|
||||||
lora_name = str(seq.get(lora_key, ''))
|
|
||||||
strength_key = f'lora {lora_idx} {tier} strength'
|
|
||||||
lora_strength = seq.get(strength_key, 1.0)
|
|
||||||
try:
|
|
||||||
lora_strength = float(lora_strength)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
lora_strength = 1.0
|
|
||||||
|
|
||||||
with ui.row().classes('w-full items-center q-gutter-sm'):
|
|
||||||
ui.label(f'L{lora_idx} {tier_label}').classes(
|
|
||||||
'text-caption').style('min-width: 55px')
|
|
||||||
name_input = ui.input(
|
|
||||||
'Name',
|
|
||||||
value=lora_name,
|
|
||||||
).classes('col').props('outlined dense')
|
|
||||||
strength_input = ui.number(
|
|
||||||
'Str',
|
|
||||||
value=lora_strength,
|
|
||||||
min=0, max=10, step=0.1,
|
|
||||||
format='%.1f',
|
|
||||||
).props('outlined dense').style('max-width: 80px')
|
|
||||||
|
|
||||||
def _lora_sync(k=lora_key, sk=strength_key, n_inp=name_input, s_inp=strength_input):
|
|
||||||
seq[k] = n_inp.value or ''
|
|
||||||
seq[sk] = float(s_inp.value) if s_inp.value is not None else 1.0
|
|
||||||
|
|
||||||
name_input.on('blur', lambda _, s=_lora_sync: s())
|
|
||||||
name_input.on('update:model-value', lambda _, s=_lora_sync: s())
|
|
||||||
strength_input.on('blur', lambda _, s=_lora_sync: s())
|
|
||||||
strength_input.on('update:model-value', lambda _, s=_lora_sync: s())
|
|
||||||
|
|
||||||
# --- Custom Parameters ---
|
|
||||||
ui.label('Custom Parameters').classes('section-header q-mt-md')
|
|
||||||
|
|
||||||
custom_keys = [k for k in seq.keys() if k not in standard_keys and k != 'resolutions']
|
|
||||||
if custom_keys:
|
|
||||||
for k in custom_keys:
|
|
||||||
with ui.row().classes('w-full items-center'):
|
|
||||||
ui.input('Key', value=k).props('readonly outlined dense').classes('w-32')
|
|
||||||
dict_input(ui.input, 'Value', seq, k).props('outlined dense').classes('col')
|
|
||||||
|
|
||||||
async def del_custom(key=k):
|
|
||||||
del seq[key]
|
|
||||||
await commit()
|
|
||||||
|
|
||||||
ui.button(icon='delete', on_click=del_custom).props('flat dense color=negative')
|
|
||||||
|
|
||||||
with ui.expansion('Add Parameter', icon='add').classes('w-full'):
|
|
||||||
new_k_input = ui.input('Key').props('outlined dense')
|
|
||||||
new_v_input = ui.input('Value').props('outlined dense')
|
|
||||||
|
|
||||||
async def add_param():
|
|
||||||
k = new_k_input.value
|
|
||||||
v = new_v_input.value
|
|
||||||
if k and k not in seq:
|
|
||||||
seq[k] = v
|
|
||||||
new_k_input.set_value('')
|
|
||||||
new_v_input.set_value('')
|
|
||||||
await commit()
|
|
||||||
|
|
||||||
ui.button('Add', on_click=add_param).props('flat')
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# VACE Settings sub-section
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
|
||||||
# VACE Schedule (needed early for both columns)
|
|
||||||
def _safe_int(val, default=0):
|
|
||||||
try:
|
|
||||||
return int(float(val))
|
|
||||||
except (ValueError, TypeError, OverflowError):
|
|
||||||
return default
|
|
||||||
|
|
||||||
sched_val = max(0, min(_safe_int(seq.get('vace schedule', 1), 1), len(VACE_MODES) - 1))
|
|
||||||
|
|
||||||
# Mode reference dialog
|
|
||||||
with ui.dialog() as ref_dlg, ui.card():
|
|
||||||
table_md = (
|
|
||||||
'| # | Mode | Formula |\n|:--|:-----|:--------|\n'
|
|
||||||
+ '\n'.join(
|
|
||||||
f'| **{j}** | {VACE_MODES[j]} | `{VACE_FORMULAS[j]}` |'
|
|
||||||
for j in range(len(VACE_MODES)))
|
|
||||||
+ '\n\n*All totals snapped to 4n+1 (1,5,9,...,49,...,81,...)*'
|
|
||||||
)
|
|
||||||
ui.markdown(table_md)
|
|
||||||
|
|
||||||
with ui.row().classes('w-full q-gutter-md'):
|
|
||||||
# --- Left column ---
|
|
||||||
with ui.column().classes('col'):
|
|
||||||
# Frame to Skip + shift
|
|
||||||
with ui.row().classes('w-full items-end'):
|
|
||||||
fts_input = dict_number('Frame to Skip', seq, 'frame_to_skip').classes(
|
|
||||||
'col').props('outlined')
|
|
||||||
|
|
||||||
_original_fts = _safe_int(seq.get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT)
|
|
||||||
|
|
||||||
async def shift_fts(idx=i, orig=_original_fts):
|
|
||||||
new_fts = _safe_int(fts_input.value, orig)
|
|
||||||
delta = new_fts - orig
|
|
||||||
if delta == 0:
|
|
||||||
ui.notify('No change to shift', type='info')
|
|
||||||
return
|
|
||||||
shifted = 0
|
|
||||||
for j in range(idx + 1, len(batch_list)):
|
|
||||||
batch_list[j]['frame_to_skip'] = _safe_int(
|
|
||||||
batch_list[j].get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT) + delta
|
|
||||||
shifted += 1
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
|
||||||
refresh_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Shift', icon='arrow_downward', on_click=shift_fts).props(
|
|
||||||
'outline').style('height: 40px')
|
|
||||||
|
|
||||||
dict_input(ui.input, 'Transition', seq, 'transition').props('outlined').classes(
|
|
||||||
'w-full q-mt-sm')
|
|
||||||
|
|
||||||
# VACE Schedule
|
|
||||||
with ui.row().classes('w-full items-center q-mt-sm'):
|
|
||||||
vs_input = dict_number('VACE Schedule', seq, 'vace schedule', default=1,
|
|
||||||
min=0, max=len(VACE_MODES) - 1).classes('col').props(
|
|
||||||
'outlined')
|
|
||||||
mode_label = ui.label(VACE_MODES[sched_val]).classes('text-caption')
|
|
||||||
ui.button(icon='help', on_click=ref_dlg.open).props('flat dense round')
|
|
||||||
|
|
||||||
def update_mode_label(e):
|
|
||||||
idx = _safe_int(e.sender.value, 0)
|
|
||||||
idx = max(0, min(idx, len(VACE_MODES) - 1))
|
|
||||||
mode_label.set_text(VACE_MODES[idx])
|
|
||||||
|
|
||||||
vs_input.on('update:model-value', update_mode_label)
|
|
||||||
|
|
||||||
# --- Right column ---
|
|
||||||
with ui.column().classes('col'):
|
|
||||||
ia_input = dict_number('Input A Frames', seq, 'input_a_frames').props(
|
|
||||||
'outlined').classes('w-full')
|
|
||||||
ib_input = dict_number('Input B Frames', seq, 'input_b_frames').props(
|
|
||||||
'outlined').classes('w-full q-mt-sm')
|
|
||||||
|
|
||||||
# VACE Length + output calculation
|
|
||||||
input_a = _safe_int(seq.get('input_a_frames', 16), 16)
|
|
||||||
input_b = _safe_int(seq.get('input_b_frames', 16), 16)
|
|
||||||
stored_total = _safe_int(seq.get('vace_length', 49), 49)
|
|
||||||
mode_idx = _safe_int(seq.get('vace schedule', 1), 1)
|
|
||||||
|
|
||||||
if mode_idx == 0:
|
|
||||||
base_length = max(stored_total - input_a, 1)
|
|
||||||
elif mode_idx == 1:
|
|
||||||
base_length = max(stored_total - input_b, 1)
|
|
||||||
else:
|
|
||||||
base_length = max(stored_total - input_a - input_b, 1)
|
|
||||||
|
|
||||||
with ui.row().classes('w-full items-center q-mt-sm'):
|
|
||||||
vl_input = ui.number('VACE Length', value=base_length, min=1).classes(
|
|
||||||
'col').props('outlined')
|
|
||||||
output_label = ui.label(f'Output: {stored_total}').classes('text-bold')
|
|
||||||
|
|
||||||
dict_number('Reference Switch', seq, 'reference switch').props(
|
|
||||||
'outlined').classes('w-full q-mt-sm')
|
|
||||||
|
|
||||||
# Recalculate VACE output when any input changes
|
|
||||||
def recalc_vace(*_args):
|
|
||||||
mi = _safe_int(vs_input.value, 0)
|
|
||||||
ia = _safe_int(ia_input.value, 16)
|
|
||||||
ib = _safe_int(ib_input.value, 16)
|
|
||||||
nb = _safe_int(vl_input.value, 1)
|
|
||||||
|
|
||||||
if mi == 0:
|
|
||||||
raw = nb + ia
|
|
||||||
elif mi == 1:
|
|
||||||
raw = nb + ib
|
|
||||||
else:
|
|
||||||
raw = nb + ia + ib
|
|
||||||
|
|
||||||
snapped = ((raw + 2) // 4) * 4 + 1
|
|
||||||
seq['vace_length'] = snapped
|
|
||||||
output_label.set_text(f'Output: {snapped}')
|
|
||||||
|
|
||||||
for inp in (vs_input, ia_input, ib_input, vl_input):
|
|
||||||
inp.on('update:model-value', recalc_vace)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Mass Update
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_list=None):
|
|
||||||
with ui.expansion('Mass Update', icon='sync').classes('w-full'):
|
|
||||||
if len(batch_list) < 2:
|
|
||||||
ui.label('Need at least 2 sequences for mass update.').classes('text-caption')
|
|
||||||
return
|
|
||||||
|
|
||||||
source_options = {i: format_seq_label(s.get(KEY_SEQUENCE_NUMBER, i+1))
|
|
||||||
for i, s in enumerate(batch_list)}
|
|
||||||
source_select = ui.select(source_options, value=0,
|
|
||||||
label='Copy from sequence:').classes('w-full')
|
|
||||||
|
|
||||||
field_select = ui.select([], multiple=True,
|
|
||||||
label='Fields to copy:').classes('w-full')
|
|
||||||
|
|
||||||
def update_fields(_=None):
|
|
||||||
idx = source_select.value
|
|
||||||
if idx is not None and 0 <= idx < len(batch_list):
|
|
||||||
src = batch_list[idx]
|
|
||||||
keys = [k for k in src.keys() if k != 'sequence_number']
|
|
||||||
field_select.set_options(keys)
|
|
||||||
|
|
||||||
source_select.on_value_change(update_fields)
|
|
||||||
update_fields()
|
|
||||||
|
|
||||||
ui.label('Apply to:').classes('subsection-header q-mt-md')
|
|
||||||
select_all_cb = ui.checkbox('Select All')
|
|
||||||
target_checks = {}
|
|
||||||
with ui.scroll_area().style('max-height: 250px'):
|
|
||||||
for idx, s in enumerate(batch_list):
|
|
||||||
sn = s.get(KEY_SEQUENCE_NUMBER, idx + 1)
|
|
||||||
cb = ui.checkbox(format_seq_label(sn))
|
|
||||||
target_checks[idx] = cb
|
|
||||||
|
|
||||||
def on_select_all(e):
|
|
||||||
for cb in target_checks.values():
|
|
||||||
cb.set_value(e.value)
|
|
||||||
|
|
||||||
select_all_cb.on_value_change(on_select_all)
|
|
||||||
|
|
||||||
async def apply_mass_update():
|
|
||||||
src_idx = source_select.value
|
|
||||||
if src_idx is None or src_idx >= len(batch_list):
|
|
||||||
ui.notify('Source sequence no longer exists', type='warning')
|
|
||||||
return
|
|
||||||
selected_keys = field_select.value or []
|
|
||||||
if not selected_keys:
|
|
||||||
ui.notify('No fields selected', type='warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
source_seq = batch_list[src_idx]
|
|
||||||
targets = [idx for idx, cb in target_checks.items()
|
|
||||||
if cb.value and idx != src_idx and idx < len(batch_list)]
|
|
||||||
if not targets:
|
|
||||||
ui.notify('No target sequences selected', type='warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
for idx in targets:
|
|
||||||
for key in selected_keys:
|
|
||||||
batch_list[idx][key] = copy.deepcopy(source_seq.get(key))
|
|
||||||
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
|
||||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
|
||||||
if k != KEY_HISTORY_TREE})
|
|
||||||
snapshot = json.loads(snapshot_json)
|
|
||||||
try:
|
|
||||||
timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
|
||||||
except ValueError as e:
|
|
||||||
ui.notify(f'Mass update failed: {e}', type='negative')
|
|
||||||
return
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
full_tree = timeline.to_dict()
|
|
||||||
data[KEY_HISTORY_TREE] = full_tree
|
|
||||||
db_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
|
||||||
timeline.strip_snapshots()
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
slim_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
|
||||||
else:
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
save_snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
|
||||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
|
||||||
if refresh_list:
|
|
||||||
refresh_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Apply Changes', icon='check', on_click=apply_mass_update).props(
|
|
||||||
'color=primary')
|
|
||||||
+165
@@ -0,0 +1,165 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from utils import save_config
|
||||||
|
|
||||||
|
def render_single_instance(instance_config, index, all_instances):
|
||||||
|
url = instance_config.get("url", "http://127.0.0.1:8188")
|
||||||
|
name = instance_config.get("name", f"Server {index+1}")
|
||||||
|
|
||||||
|
COMFY_URL = url.rstrip("/")
|
||||||
|
|
||||||
|
c_head, c_set = st.columns([3, 1])
|
||||||
|
c_head.markdown(f"### 🔌 {name}")
|
||||||
|
|
||||||
|
with c_set.popover("⚙️ Settings"):
|
||||||
|
st.caption("Press Update to apply changes!")
|
||||||
|
new_name = st.text_input("Name", value=name, key=f"name_{index}")
|
||||||
|
new_url = st.text_input("URL", value=url, key=f"url_{index}")
|
||||||
|
|
||||||
|
if new_url != url:
|
||||||
|
st.warning("⚠️ Unsaved URL! Click Update below.")
|
||||||
|
|
||||||
|
if st.button("💾 Update & Save", key=f"save_{index}", type="primary"):
|
||||||
|
all_instances[index]["name"] = new_name
|
||||||
|
all_instances[index]["url"] = new_url
|
||||||
|
st.session_state.config["comfy_instances"] = all_instances
|
||||||
|
|
||||||
|
save_config(
|
||||||
|
st.session_state.current_dir,
|
||||||
|
st.session_state.config['favorites'],
|
||||||
|
{"comfy_instances": all_instances}
|
||||||
|
)
|
||||||
|
st.toast("Server config saved!", icon="💾")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
if st.button("🗑️ Remove Server", key=f"del_{index}"):
|
||||||
|
all_instances.pop(index)
|
||||||
|
st.session_state.config["comfy_instances"] = all_instances
|
||||||
|
save_config(
|
||||||
|
st.session_state.current_dir,
|
||||||
|
st.session_state.config['favorites'],
|
||||||
|
{"comfy_instances": all_instances}
|
||||||
|
)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- 1. STATUS DASHBOARD ---
|
||||||
|
with st.expander("📊 Server Status", expanded=True):
|
||||||
|
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
|
||||||
|
try:
|
||||||
|
res = requests.get(f"{COMFY_URL}/queue", timeout=1.5)
|
||||||
|
queue_data = res.json()
|
||||||
|
running_cnt = len(queue_data.get("queue_running", []))
|
||||||
|
pending_cnt = len(queue_data.get("queue_pending", []))
|
||||||
|
|
||||||
|
col1.metric("Status", "🟢 Online" if running_cnt > 0 else "💤 Idle")
|
||||||
|
col2.metric("Pending", pending_cnt)
|
||||||
|
col3.metric("Running", running_cnt)
|
||||||
|
|
||||||
|
if col4.button("🔄 Check Img", key=f"refresh_{index}", use_container_width=True):
|
||||||
|
st.session_state[f"force_img_refresh_{index}"] = True
|
||||||
|
except Exception:
|
||||||
|
col1.metric("Status", "🔴 Offline")
|
||||||
|
col2.metric("Pending", "-")
|
||||||
|
col3.metric("Running", "-")
|
||||||
|
st.error(f"Could not connect to {COMFY_URL}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 2. LIVE VIEW (WITH TOGGLE) ---
|
||||||
|
st.write("")
|
||||||
|
c_label, c_ctrl = st.columns([1, 2])
|
||||||
|
c_label.subheader("📺 Live View")
|
||||||
|
|
||||||
|
# LIVE PREVIEW TOGGLE
|
||||||
|
enable_preview = c_ctrl.checkbox("Enable Live Preview", value=True, key=f"live_toggle_{index}")
|
||||||
|
|
||||||
|
if enable_preview:
|
||||||
|
# Height Slider
|
||||||
|
iframe_h = st.slider(
|
||||||
|
"Height (px)",
|
||||||
|
min_value=600, max_value=2500, value=1000, step=50,
|
||||||
|
key=f"h_slider_{index}"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
f"""
|
||||||
|
<iframe src="{COMFY_URL}" width="100%" height="{iframe_h}px"
|
||||||
|
style="border: 1px solid #444; border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.3);">
|
||||||
|
</iframe>
|
||||||
|
""",
|
||||||
|
unsafe_allow_html=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
st.info("Live Preview is disabled. Enable it above to see the interface.")
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- 3. LATEST OUTPUT ---
|
||||||
|
if st.session_state.get(f"force_img_refresh_{index}", False):
|
||||||
|
st.caption("🖼️ Most Recent Output")
|
||||||
|
try:
|
||||||
|
hist_res = requests.get(f"{COMFY_URL}/history", timeout=2)
|
||||||
|
history = hist_res.json()
|
||||||
|
if history:
|
||||||
|
last_prompt_id = list(history.keys())[-1]
|
||||||
|
outputs = history[last_prompt_id].get("outputs", {})
|
||||||
|
found_img = None
|
||||||
|
for node_id, node_output in outputs.items():
|
||||||
|
if "images" in node_output:
|
||||||
|
for img_info in node_output["images"]:
|
||||||
|
if img_info["type"] == "output":
|
||||||
|
found_img = img_info
|
||||||
|
break
|
||||||
|
if found_img: break
|
||||||
|
|
||||||
|
if found_img:
|
||||||
|
img_name = found_img['filename']
|
||||||
|
folder = found_img['subfolder']
|
||||||
|
img_type = found_img['type']
|
||||||
|
img_url = f"{COMFY_URL}/view?filename={img_name}&subfolder={folder}&type={img_type}"
|
||||||
|
img_res = requests.get(img_url)
|
||||||
|
image = Image.open(BytesIO(img_res.content))
|
||||||
|
st.image(image, caption=f"Last Output: {img_name}")
|
||||||
|
else:
|
||||||
|
st.warning("Last run had no image output.")
|
||||||
|
else:
|
||||||
|
st.info("No history found.")
|
||||||
|
st.session_state[f"force_img_refresh_{index}"] = False
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Error fetching image: {e}")
|
||||||
|
|
||||||
|
def render_comfy_monitor():
|
||||||
|
if "comfy_instances" not in st.session_state.config:
|
||||||
|
st.session_state.config["comfy_instances"] = [
|
||||||
|
{"name": "Main Server", "url": "http://192.168.1.100:8188"}
|
||||||
|
]
|
||||||
|
|
||||||
|
instances = st.session_state.config["comfy_instances"]
|
||||||
|
tab_names = [i["name"] for i in instances] + ["➕ Add Server"]
|
||||||
|
tabs = st.tabs(tab_names)
|
||||||
|
|
||||||
|
for i, tab in enumerate(tabs[:-1]):
|
||||||
|
with tab:
|
||||||
|
render_single_instance(instances[i], i, instances)
|
||||||
|
|
||||||
|
with tabs[-1]:
|
||||||
|
st.header("Add New ComfyUI Instance")
|
||||||
|
with st.form("add_server_form"):
|
||||||
|
new_name = st.text_input("Server Name", placeholder="e.g. Render Node 2")
|
||||||
|
new_url = st.text_input("URL", placeholder="http://192.168.1.50:8188")
|
||||||
|
if st.form_submit_button("Add Instance"):
|
||||||
|
if new_name and new_url:
|
||||||
|
instances.append({"name": new_name, "url": new_url})
|
||||||
|
st.session_state.config["comfy_instances"] = instances
|
||||||
|
|
||||||
|
save_config(
|
||||||
|
st.session_state.current_dir,
|
||||||
|
st.session_state.config['favorites'],
|
||||||
|
{"comfy_instances": instances}
|
||||||
|
)
|
||||||
|
st.success("Server Added!")
|
||||||
|
st.rerun()
|
||||||
|
else:
|
||||||
|
st.error("Please fill in both Name and URL.")
|
||||||
-281
@@ -1,281 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import html
|
|
||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from utils import save_config
|
|
||||||
|
|
||||||
|
|
||||||
def render_comfy_monitor(state: AppState):
|
|
||||||
config = state.config
|
|
||||||
|
|
||||||
# --- Global Monitor Settings ---
|
|
||||||
with ui.expansion('Monitor Settings', icon='settings').classes('w-full'):
|
|
||||||
with ui.row().classes('w-full items-end'):
|
|
||||||
viewer_input = ui.input(
|
|
||||||
'Remote Browser URL',
|
|
||||||
value=config.get('viewer_url', ''),
|
|
||||||
placeholder='e.g., http://localhost:5800',
|
|
||||||
).classes('col')
|
|
||||||
timeout_slider = ui.slider(
|
|
||||||
min=0, max=60, step=1,
|
|
||||||
value=config.get('monitor_timeout', 0),
|
|
||||||
).classes('col')
|
|
||||||
ui.label().bind_text_from(timeout_slider, 'value',
|
|
||||||
backward=lambda v: f'Timeout: {v} min')
|
|
||||||
|
|
||||||
def save_monitor_settings():
|
|
||||||
config['viewer_url'] = viewer_input.value
|
|
||||||
config['monitor_timeout'] = int(timeout_slider.value)
|
|
||||||
save_config(state.current_dir, config['favorites'], config)
|
|
||||||
ui.notify('Monitor settings saved!', type='positive')
|
|
||||||
|
|
||||||
ui.button('Save Monitor Settings', icon='save', on_click=save_monitor_settings)
|
|
||||||
|
|
||||||
# --- Instance Management ---
|
|
||||||
if 'comfy_instances' not in config:
|
|
||||||
config['comfy_instances'] = [
|
|
||||||
{'name': 'Main Server', 'url': 'http://192.168.1.100:8188'}
|
|
||||||
]
|
|
||||||
|
|
||||||
instances = config['comfy_instances']
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_instance_tabs():
|
|
||||||
if not instances:
|
|
||||||
ui.label('No servers configured. Add one below.')
|
|
||||||
|
|
||||||
for idx, inst in enumerate(instances):
|
|
||||||
with ui.expansion(inst.get('name', f'Server {idx+1}'), icon='dns').classes('w-full'):
|
|
||||||
_render_single_instance(state, inst, idx, instances, render_instance_tabs)
|
|
||||||
|
|
||||||
# Add server section
|
|
||||||
ui.separator()
|
|
||||||
ui.label('Add New Server').classes('section-header')
|
|
||||||
with ui.row().classes('w-full items-end'):
|
|
||||||
new_name = ui.input('Server Name', placeholder='e.g. Render Node 2').classes('col')
|
|
||||||
new_url = ui.input('URL', placeholder='http://192.168.1.50:8188').classes('col')
|
|
||||||
|
|
||||||
def add_instance():
|
|
||||||
if new_name.value and new_url.value:
|
|
||||||
instances.append({'name': new_name.value, 'url': new_url.value})
|
|
||||||
config['comfy_instances'] = instances
|
|
||||||
save_config(state.current_dir, config['favorites'], config)
|
|
||||||
ui.notify('Server Added!', type='positive')
|
|
||||||
new_name.set_value('')
|
|
||||||
new_url.set_value('')
|
|
||||||
render_instance_tabs.refresh()
|
|
||||||
else:
|
|
||||||
ui.notify('Please fill in both Name and URL.', type='warning')
|
|
||||||
|
|
||||||
ui.button('Add Instance', icon='add', on_click=add_instance)
|
|
||||||
|
|
||||||
render_instance_tabs()
|
|
||||||
|
|
||||||
# --- Auto-poll timer (every 300s) ---
|
|
||||||
# Store live_checkbox references so the timer can update them
|
|
||||||
_live_checkboxes = state._live_checkboxes
|
|
||||||
_live_refreshables = state._live_refreshables
|
|
||||||
|
|
||||||
def poll_all():
|
|
||||||
try:
|
|
||||||
timeout_val = config.get('monitor_timeout', 0)
|
|
||||||
if timeout_val > 0:
|
|
||||||
for key, start_time in list(state.live_toggles.items()):
|
|
||||||
if start_time and (time.time() - start_time) > (timeout_val * 60):
|
|
||||||
state.live_toggles[key] = None
|
|
||||||
if key in _live_checkboxes:
|
|
||||||
_live_checkboxes[key].set_value(False)
|
|
||||||
if key in _live_refreshables:
|
|
||||||
_live_refreshables[key].refresh()
|
|
||||||
except RuntimeError:
|
|
||||||
pass # Parent slot deleted during refresh
|
|
||||||
|
|
||||||
ui.timer(300, poll_all)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_blocking(url, timeout=1.5):
|
|
||||||
"""Run a blocking GET request; returns (response, error)."""
|
|
||||||
try:
|
|
||||||
res = requests.get(url, timeout=timeout)
|
|
||||||
return res, None
|
|
||||||
except Exception as e:
|
|
||||||
return None, e
|
|
||||||
|
|
||||||
|
|
||||||
def _render_single_instance(state: AppState, instance_config: dict, index: int,
|
|
||||||
all_instances: list, refresh_fn):
|
|
||||||
config = state.config
|
|
||||||
url = instance_config.get('url', 'http://127.0.0.1:8188')
|
|
||||||
name = instance_config.get('name', f'Server {index+1}')
|
|
||||||
comfy_url = url.rstrip('/')
|
|
||||||
|
|
||||||
# --- Settings popover ---
|
|
||||||
with ui.expansion('Settings', icon='settings'):
|
|
||||||
name_input = ui.input('Name', value=name).classes('w-full')
|
|
||||||
url_input = ui.input('URL', value=url).classes('w-full')
|
|
||||||
|
|
||||||
def update_server():
|
|
||||||
all_instances[index]['name'] = name_input.value
|
|
||||||
all_instances[index]['url'] = url_input.value
|
|
||||||
config['comfy_instances'] = all_instances
|
|
||||||
save_config(state.current_dir, config['favorites'], config)
|
|
||||||
ui.notify('Server config saved!', type='positive')
|
|
||||||
refresh_fn.refresh()
|
|
||||||
|
|
||||||
def remove_server():
|
|
||||||
all_instances.pop(index)
|
|
||||||
config['comfy_instances'] = all_instances
|
|
||||||
save_config(state.current_dir, config['favorites'], config)
|
|
||||||
ui.notify('Server removed', type='info')
|
|
||||||
refresh_fn.refresh()
|
|
||||||
|
|
||||||
ui.button('Update & Save', icon='save', on_click=update_server).props('color=primary')
|
|
||||||
ui.button('Remove Server', icon='delete', on_click=remove_server).props('color=negative')
|
|
||||||
|
|
||||||
# --- Status Dashboard ---
|
|
||||||
status_container = ui.row().classes('w-full items-center q-gutter-md')
|
|
||||||
|
|
||||||
async def refresh_status():
|
|
||||||
status_container.clear()
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
res, err = await loop.run_in_executor(
|
|
||||||
None, lambda: _fetch_blocking(f'{comfy_url}/queue'))
|
|
||||||
with status_container:
|
|
||||||
if res is not None:
|
|
||||||
try:
|
|
||||||
queue_data = res.json()
|
|
||||||
except (ValueError, Exception):
|
|
||||||
ui.label('Invalid response from server').classes('text-negative')
|
|
||||||
return
|
|
||||||
running_cnt = len(queue_data.get('queue_running', []))
|
|
||||||
pending_cnt = len(queue_data.get('queue_pending', []))
|
|
||||||
|
|
||||||
with ui.card().classes('q-pa-md text-center').style('min-width: 100px'):
|
|
||||||
ui.label('Status')
|
|
||||||
ui.label('Online' if running_cnt > 0 else 'Idle').classes(
|
|
||||||
'text-positive' if running_cnt > 0 else 'text-grey')
|
|
||||||
with ui.card().classes('q-pa-md text-center').style('min-width: 100px'):
|
|
||||||
ui.label('Pending')
|
|
||||||
ui.label(str(pending_cnt))
|
|
||||||
with ui.card().classes('q-pa-md text-center').style('min-width: 100px'):
|
|
||||||
ui.label('Running')
|
|
||||||
ui.label(str(running_cnt))
|
|
||||||
else:
|
|
||||||
with ui.card().classes('q-pa-md text-center').style('min-width: 100px'):
|
|
||||||
ui.label('Status')
|
|
||||||
ui.label('Offline').classes('text-negative')
|
|
||||||
ui.label(f'Could not connect to {comfy_url}').classes('text-negative')
|
|
||||||
|
|
||||||
# Initial status fetch (non-blocking via button click handler pattern)
|
|
||||||
ui.timer(0.1, refresh_status, once=True)
|
|
||||||
ui.button('Refresh Status', icon='refresh', on_click=refresh_status).props('flat dense')
|
|
||||||
|
|
||||||
# --- Live View ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mt-md'):
|
|
||||||
ui.label('Live View').classes('section-header')
|
|
||||||
toggle_key = f'live_toggle_{index}'
|
|
||||||
|
|
||||||
live_checkbox = ui.checkbox('Enable Live Preview', value=False)
|
|
||||||
# Store reference so poll_all timer can disable it on timeout
|
|
||||||
state._live_checkboxes[toggle_key] = live_checkbox
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_live_view():
|
|
||||||
if not live_checkbox.value:
|
|
||||||
ui.label('Live Preview is disabled.').classes('text-caption')
|
|
||||||
return
|
|
||||||
|
|
||||||
# Record start time
|
|
||||||
if toggle_key not in state.live_toggles or state.live_toggles.get(toggle_key) is None:
|
|
||||||
state.live_toggles[toggle_key] = time.time()
|
|
||||||
|
|
||||||
timeout_val = config.get('monitor_timeout', 0)
|
|
||||||
if timeout_val > 0:
|
|
||||||
start = state.live_toggles.get(toggle_key, time.time())
|
|
||||||
remaining = (timeout_val * 60) - (time.time() - start)
|
|
||||||
if remaining <= 0:
|
|
||||||
live_checkbox.set_value(False)
|
|
||||||
state.live_toggles[toggle_key] = None
|
|
||||||
ui.label('Preview timed out.').classes('text-caption')
|
|
||||||
return
|
|
||||||
ui.label(f'Auto-off in: {int(remaining)}s').classes('text-caption')
|
|
||||||
|
|
||||||
iframe_h = ui.slider(min=600, max=2500, step=50, value=1000).classes('w-full')
|
|
||||||
ui.label().bind_text_from(iframe_h, 'value', backward=lambda v: f'Height: {v}px')
|
|
||||||
|
|
||||||
viewer_base = config.get('viewer_url', '').strip()
|
|
||||||
parsed = urllib.parse.urlparse(viewer_base)
|
|
||||||
if viewer_base and parsed.scheme in ('http', 'https'):
|
|
||||||
safe_src = html.escape(viewer_base, quote=True)
|
|
||||||
ui.label(f'Viewing: {viewer_base}').classes('text-caption')
|
|
||||||
|
|
||||||
iframe_container = ui.column().classes('w-full')
|
|
||||||
|
|
||||||
def update_iframe():
|
|
||||||
iframe_container.clear()
|
|
||||||
with iframe_container:
|
|
||||||
ui.html(
|
|
||||||
f'<iframe src="{safe_src}" width="100%" height="{int(iframe_h.value)}px"'
|
|
||||||
f' style="border: 2px solid #666; border-radius: 8px;"></iframe>'
|
|
||||||
)
|
|
||||||
|
|
||||||
iframe_h.on_value_change(lambda _: update_iframe())
|
|
||||||
update_iframe()
|
|
||||||
else:
|
|
||||||
ui.label('No valid viewer URL configured.').classes('text-warning')
|
|
||||||
|
|
||||||
state._live_refreshables[toggle_key] = render_live_view
|
|
||||||
live_checkbox.on_value_change(lambda _: render_live_view.refresh())
|
|
||||||
render_live_view()
|
|
||||||
|
|
||||||
# --- Latest Output ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mt-md'):
|
|
||||||
ui.label('Latest Output').classes('section-header')
|
|
||||||
img_container = ui.column().classes('w-full')
|
|
||||||
|
|
||||||
async def check_image():
|
|
||||||
img_container.clear()
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
res, err = await loop.run_in_executor(
|
|
||||||
None, lambda: _fetch_blocking(f'{comfy_url}/history', timeout=2))
|
|
||||||
with img_container:
|
|
||||||
if err is not None:
|
|
||||||
ui.label(f'Error fetching image: {err}').classes('text-negative')
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
history = res.json()
|
|
||||||
except (ValueError, Exception):
|
|
||||||
ui.label('Invalid response from server').classes('text-negative')
|
|
||||||
return
|
|
||||||
if not history:
|
|
||||||
ui.label('No history found.').classes('text-caption')
|
|
||||||
return
|
|
||||||
last_prompt_id = list(history.keys())[-1]
|
|
||||||
outputs = history[last_prompt_id].get('outputs', {})
|
|
||||||
found_img = None
|
|
||||||
for node_output in outputs.values():
|
|
||||||
if 'images' in node_output:
|
|
||||||
for img_info in node_output['images']:
|
|
||||||
if img_info['type'] == 'output':
|
|
||||||
found_img = img_info
|
|
||||||
break
|
|
||||||
if found_img:
|
|
||||||
break
|
|
||||||
if found_img:
|
|
||||||
params = urllib.parse.urlencode({
|
|
||||||
'filename': found_img['filename'],
|
|
||||||
'subfolder': found_img['subfolder'],
|
|
||||||
'type': found_img['type'],
|
|
||||||
})
|
|
||||||
img_url = f'{comfy_url}/view?{params}'
|
|
||||||
ui.image(img_url).classes('w-full').style('max-width: 600px')
|
|
||||||
ui.label(f'Last Output: {found_img["filename"]}').classes('text-caption')
|
|
||||||
else:
|
|
||||||
ui.label('Last run had no image output.').classes('text-caption')
|
|
||||||
|
|
||||||
ui.button('Check Latest Image', icon='image', on_click=check_image).props('flat')
|
|
||||||
@@ -1,294 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import sqlite3
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from db import ProjectDB
|
|
||||||
from utils import save_config, sync_to_db
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def render_projects_tab(state: AppState):
|
|
||||||
"""Render the Projects management tab."""
|
|
||||||
|
|
||||||
# --- DB toggle ---
|
|
||||||
def on_db_toggle(e):
|
|
||||||
state.db_enabled = e.value
|
|
||||||
state.config['db_enabled'] = e.value
|
|
||||||
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
|
||||||
render_project_content.refresh()
|
|
||||||
|
|
||||||
ui.switch('Enable Project Database', value=state.db_enabled,
|
|
||||||
on_change=on_db_toggle).classes('q-mb-md')
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_project_content():
|
|
||||||
if not state.db_enabled:
|
|
||||||
ui.label('Project database is disabled. Enable it above to manage projects.').classes(
|
|
||||||
'text-caption q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
if not state.db:
|
|
||||||
ui.label('Database not initialized.').classes('text-warning q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
# --- Create project form ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
ui.label('Create New Project').classes('section-header')
|
|
||||||
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
|
|
||||||
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
|
|
||||||
|
|
||||||
async def create_project():
|
|
||||||
name = name_input.value.strip()
|
|
||||||
if not name:
|
|
||||||
ui.notify('Please enter a project name', type='warning')
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(state.db.create_project, name, str(state.current_dir), desc_input.value.strip())
|
|
||||||
name_input.set_value('')
|
|
||||||
desc_input.set_value('')
|
|
||||||
ui.notify(f'Created project "{name}"', type='positive')
|
|
||||||
render_project_list.refresh()
|
|
||||||
except Exception as e:
|
|
||||||
ui.notify(f'Error: {e}', type='negative')
|
|
||||||
|
|
||||||
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
|
|
||||||
|
|
||||||
# --- Path replacements (for ComfyUI Docker path differences) ---
|
|
||||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
|
||||||
ui.label('ComfyUI Path Replacements').classes('section-header')
|
|
||||||
ui.label('Applied to project_path output — use to fix Docker mount casing differences.'
|
|
||||||
).classes('text-caption q-mb-sm')
|
|
||||||
|
|
||||||
replacements: list[dict] = state.config.get('path_replacements', [])
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_replacements():
|
|
||||||
for idx, rep in enumerate(replacements):
|
|
||||||
with ui.row().classes('w-full items-center no-wrap q-gutter-xs'):
|
|
||||||
ui.input('From', value=rep.get('from', '')).classes('col').props(
|
|
||||||
'outlined dense').on('update:model-value',
|
|
||||||
lambda e, i=idx: _update_replacement(i, 'from', e.args))
|
|
||||||
ui.label('→').classes('text-caption')
|
|
||||||
ui.input('To', value=rep.get('to', '')).classes('col').props(
|
|
||||||
'outlined dense').on('update:model-value',
|
|
||||||
lambda e, i=idx: _update_replacement(i, 'to', e.args))
|
|
||||||
ui.button(icon='delete', on_click=lambda i=idx: _remove_replacement(i)
|
|
||||||
).props('flat dense color=negative')
|
|
||||||
|
|
||||||
def _update_replacement(idx, field, value):
|
|
||||||
replacements[idx][field] = value
|
|
||||||
state.config['path_replacements'] = replacements
|
|
||||||
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
|
||||||
|
|
||||||
def _remove_replacement(idx):
|
|
||||||
replacements.pop(idx)
|
|
||||||
state.config['path_replacements'] = replacements
|
|
||||||
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
|
||||||
render_replacements.refresh()
|
|
||||||
|
|
||||||
def _add_replacement():
|
|
||||||
replacements.append({'from': '', 'to': ''})
|
|
||||||
state.config['path_replacements'] = replacements
|
|
||||||
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
|
||||||
render_replacements.refresh()
|
|
||||||
|
|
||||||
render_replacements()
|
|
||||||
ui.button('Add Replacement', icon='add', on_click=_add_replacement).props('flat dense')
|
|
||||||
|
|
||||||
# --- Active project indicator ---
|
|
||||||
# Fetch once with file counts and reuse in render_project_list
|
|
||||||
_cached_projects = state.db.list_projects_with_file_counts()
|
|
||||||
|
|
||||||
if state.current_project:
|
|
||||||
# Check if active project actually exists in the database
|
|
||||||
project_exists = any(p['name'] == state.current_project for p in _cached_projects)
|
|
||||||
if project_exists:
|
|
||||||
ui.label(f'Active Project: {state.current_project}').classes(
|
|
||||||
'text-bold text-primary q-pa-sm')
|
|
||||||
else:
|
|
||||||
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(
|
|
||||||
'border-left: 3px solid orange;'):
|
|
||||||
ui.label(f'Stale project reference: "{state.current_project}" '
|
|
||||||
'(not found in database)').classes('text-warning')
|
|
||||||
with ui.row().classes('q-gutter-sm'):
|
|
||||||
def clear_stale():
|
|
||||||
state.current_project = ''
|
|
||||||
state.config['current_project'] = ''
|
|
||||||
save_config(state.current_dir,
|
|
||||||
state.config.get('favorites', []),
|
|
||||||
state.config)
|
|
||||||
ui.notify('Cleared stale project reference', type='info')
|
|
||||||
render_project_content.refresh()
|
|
||||||
|
|
||||||
def recreate_project():
|
|
||||||
name = state.current_project
|
|
||||||
try:
|
|
||||||
state.db.create_project(name, str(state.current_dir))
|
|
||||||
ui.notify(f'Recreated project "{name}"', type='positive')
|
|
||||||
render_project_content.refresh()
|
|
||||||
except Exception as e:
|
|
||||||
ui.notify(f'Error: {e}', type='negative')
|
|
||||||
|
|
||||||
ui.button('Clear Reference', icon='clear',
|
|
||||||
on_click=clear_stale).props('flat dense')
|
|
||||||
ui.button('Recreate Project', icon='add_circle',
|
|
||||||
on_click=recreate_project).props('flat dense color=primary')
|
|
||||||
|
|
||||||
# --- Project list ---
|
|
||||||
@ui.refreshable
|
|
||||||
def render_project_list():
|
|
||||||
nonlocal _cached_projects
|
|
||||||
projects = state.db.list_projects_with_file_counts()
|
|
||||||
_cached_projects = projects
|
|
||||||
if not projects:
|
|
||||||
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
for proj in projects:
|
|
||||||
is_active = proj['name'] == state.current_project
|
|
||||||
card_style = 'border-left: 3px solid var(--accent);' if is_active else ''
|
|
||||||
|
|
||||||
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style):
|
|
||||||
with ui.row().classes('w-full items-center'):
|
|
||||||
with ui.column().classes('col'):
|
|
||||||
ui.label(proj['name']).classes('text-bold')
|
|
||||||
if proj['description']:
|
|
||||||
ui.label(proj['description']).classes('text-caption')
|
|
||||||
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
|
|
||||||
ui.label(f'{proj["file_count"]} data file(s)').classes('text-caption')
|
|
||||||
|
|
||||||
with ui.row().classes('q-gutter-xs'):
|
|
||||||
if not is_active:
|
|
||||||
def activate(name=proj['name']):
|
|
||||||
state.current_project = name
|
|
||||||
state.config['current_project'] = name
|
|
||||||
save_config(state.current_dir,
|
|
||||||
state.config.get('favorites', []),
|
|
||||||
state.config)
|
|
||||||
ui.notify(f'Activated project "{name}"', type='positive')
|
|
||||||
render_project_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Activate', icon='check_circle',
|
|
||||||
on_click=activate).props('flat dense color=primary')
|
|
||||||
else:
|
|
||||||
def deactivate():
|
|
||||||
state.current_project = ''
|
|
||||||
state.config['current_project'] = ''
|
|
||||||
save_config(state.current_dir,
|
|
||||||
state.config.get('favorites', []),
|
|
||||||
state.config)
|
|
||||||
ui.notify('Deactivated project', type='info')
|
|
||||||
render_project_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Deactivate', icon='cancel',
|
|
||||||
on_click=deactivate).props('flat dense')
|
|
||||||
|
|
||||||
async def rename_proj(name=proj['name']):
|
|
||||||
new_name = await ui.run_javascript(
|
|
||||||
f'prompt("Rename project:", {json.dumps(name)})',
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
if new_name and new_name.strip() and new_name.strip() != name:
|
|
||||||
new_name = new_name.strip()
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(state.db.rename_project, name, new_name)
|
|
||||||
if state.current_project == name:
|
|
||||||
state.current_project = new_name
|
|
||||||
state.config['current_project'] = new_name
|
|
||||||
save_config(state.current_dir,
|
|
||||||
state.config.get('favorites', []),
|
|
||||||
state.config)
|
|
||||||
ui.notify(f'Renamed to "{new_name}"', type='positive')
|
|
||||||
render_project_list.refresh()
|
|
||||||
except sqlite3.IntegrityError:
|
|
||||||
ui.notify(f'A project named "{new_name}" already exists',
|
|
||||||
type='warning')
|
|
||||||
except Exception as e:
|
|
||||||
ui.notify(f'Error: {e}', type='negative')
|
|
||||||
|
|
||||||
ui.button('Rename', icon='edit',
|
|
||||||
on_click=rename_proj).props('flat dense')
|
|
||||||
|
|
||||||
async def change_path(name=proj['name'], path=proj['folder_path']):
|
|
||||||
new_path = await ui.run_javascript(
|
|
||||||
f'prompt("New path for project:", {json.dumps(path)})',
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
if new_path and new_path.strip() and new_path.strip() != path:
|
|
||||||
new_path = new_path.strip()
|
|
||||||
if not Path(new_path).is_dir():
|
|
||||||
ui.notify(f'Warning: "{new_path}" does not exist',
|
|
||||||
type='warning')
|
|
||||||
await asyncio.to_thread(state.db.update_project_path, name, new_path)
|
|
||||||
ui.notify(f'Path updated to "{new_path}"', type='positive')
|
|
||||||
render_project_list.refresh()
|
|
||||||
|
|
||||||
ui.button('Path', icon='folder',
|
|
||||||
on_click=change_path).props('flat dense')
|
|
||||||
|
|
||||||
def import_folder(pid=proj['id'], pname=proj['name']):
|
|
||||||
_import_folder(state, pid, pname, render_project_list)
|
|
||||||
|
|
||||||
ui.button('Import Folder', icon='folder_open',
|
|
||||||
on_click=import_folder).props('flat dense')
|
|
||||||
|
|
||||||
async def delete_proj(name=proj['name']):
|
|
||||||
await asyncio.to_thread(state.db.delete_project, name)
|
|
||||||
if state.current_project == name:
|
|
||||||
state.current_project = ''
|
|
||||||
state.config['current_project'] = ''
|
|
||||||
save_config(state.current_dir,
|
|
||||||
state.config.get('favorites', []),
|
|
||||||
state.config)
|
|
||||||
ui.notify(f'Deleted project "{name}"', type='positive')
|
|
||||||
render_project_list.refresh()
|
|
||||||
|
|
||||||
ui.button(icon='delete',
|
|
||||||
on_click=delete_proj).props('flat dense color=negative')
|
|
||||||
|
|
||||||
render_project_list()
|
|
||||||
|
|
||||||
render_project_content()
|
|
||||||
|
|
||||||
|
|
||||||
async def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
|
|
||||||
"""Bulk import all .json files from the project's folder_path into a project."""
|
|
||||||
proj = state.db.get_project(project_name)
|
|
||||||
scan_dir = Path(proj['folder_path']) if proj else state.current_dir
|
|
||||||
json_files = sorted(scan_dir.glob('*.json'))
|
|
||||||
json_files = [f for f in json_files if f.name not in (
|
|
||||||
'.editor_config.json', '.editor_snippets.json')]
|
|
||||||
|
|
||||||
if not json_files:
|
|
||||||
ui.notify('No JSON files in current directory', type='warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
def _do_import():
|
|
||||||
imported = 0
|
|
||||||
skipped = 0
|
|
||||||
for jf in json_files:
|
|
||||||
file_name = jf.stem
|
|
||||||
existing = state.db.get_data_file(project_id, file_name)
|
|
||||||
if existing:
|
|
||||||
skipped += 1
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
state.db.import_json_file(project_id, jf)
|
|
||||||
imported += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to import {jf}: {e}")
|
|
||||||
return imported, skipped
|
|
||||||
|
|
||||||
imported, skipped = await asyncio.to_thread(_do_import)
|
|
||||||
|
|
||||||
msg = f'Imported {imported} file(s)'
|
|
||||||
if skipped:
|
|
||||||
msg += f', skipped {skipped} existing'
|
|
||||||
ui.notify(msg, type='positive')
|
|
||||||
refresh_fn.refresh()
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
|
||||||
|
|
||||||
|
|
||||||
def render_raw_editor(state: AppState):
|
|
||||||
data = state.data_cache
|
|
||||||
file_path = state.file_path
|
|
||||||
|
|
||||||
with ui.card().classes('w-full q-pa-md'):
|
|
||||||
ui.label(f'Raw Editor: {file_path.name}').classes('text-h6 q-mb-md')
|
|
||||||
|
|
||||||
hide_history = ui.checkbox(
|
|
||||||
'Hide History (Safe Mode)',
|
|
||||||
value=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_editor():
|
|
||||||
# Prepare display data — shallow copy, just pop keys
|
|
||||||
if hide_history.value:
|
|
||||||
display_data = {k: v for k, v in data.items()
|
|
||||||
if k not in (KEY_HISTORY_TREE, KEY_PROMPT_HISTORY)}
|
|
||||||
else:
|
|
||||||
display_data = data
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_str = json.dumps(display_data, indent=4, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
ui.notify(f'Error serializing JSON: {e}', type='negative')
|
|
||||||
json_str = '{}'
|
|
||||||
|
|
||||||
text_area = ui.textarea(
|
|
||||||
'JSON Content',
|
|
||||||
value=json_str,
|
|
||||||
).classes('w-full font-mono').props('outlined rows=30')
|
|
||||||
|
|
||||||
async def do_save():
|
|
||||||
try:
|
|
||||||
input_data = json.loads(text_area.value)
|
|
||||||
|
|
||||||
# Merge hidden history back in if safe mode
|
|
||||||
if hide_history.value:
|
|
||||||
if KEY_HISTORY_TREE in data:
|
|
||||||
input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
|
|
||||||
if KEY_PROMPT_HISTORY in data:
|
|
||||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
|
||||||
|
|
||||||
await asyncio.to_thread(save_json, file_path, input_data)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, input_data)
|
|
||||||
|
|
||||||
data.clear()
|
|
||||||
data.update(input_data)
|
|
||||||
state.last_mtime = get_file_mtime(file_path)
|
|
||||||
|
|
||||||
ui.notify('Raw JSON Saved Successfully!', type='positive')
|
|
||||||
render_editor.refresh()
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
ui.notify(f'Invalid JSON Syntax: {e}', type='negative')
|
|
||||||
except Exception as e:
|
|
||||||
ui.notify(f'Unexpected Error: {e}', type='negative')
|
|
||||||
|
|
||||||
ui.button('Save Raw Changes', icon='save', on_click=do_save).props(
|
|
||||||
'color=primary'
|
|
||||||
).classes('w-full q-mt-md')
|
|
||||||
|
|
||||||
hide_history.on_value_change(lambda _: render_editor.refresh())
|
|
||||||
render_editor()
|
|
||||||
+250
@@ -0,0 +1,250 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import random
|
||||||
|
from utils import DEFAULTS, save_json, get_file_mtime
|
||||||
|
|
||||||
|
def render_single_editor(data, file_path):
|
||||||
|
is_batch_file = "batch_data" in data or isinstance(data, list)
|
||||||
|
|
||||||
|
if is_batch_file:
|
||||||
|
st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
|
||||||
|
return
|
||||||
|
|
||||||
|
col1, col2 = st.columns([2, 1])
|
||||||
|
|
||||||
|
# Unique prefix for this file's widgets + Version Token (Fixes Restore bug)
|
||||||
|
fk = f"{file_path.name}_v{st.session_state.ui_reset_token}"
|
||||||
|
|
||||||
|
# --- FORM ---
|
||||||
|
with col1:
|
||||||
|
with st.expander("🌍 General Prompts (Global Layer)", expanded=False):
|
||||||
|
gen_prompt = st.text_area("General Prompt", value=data.get("general_prompt", ""), height=100, key=f"{fk}_gp")
|
||||||
|
gen_negative = st.text_area("General Negative", value=data.get("general_negative", DEFAULTS["general_negative"]), height=100, key=f"{fk}_gn")
|
||||||
|
|
||||||
|
st.write("📝 **Specific Prompts**")
|
||||||
|
current_prompt_val = data.get("current_prompt", "")
|
||||||
|
if 'append_prompt' in st.session_state:
|
||||||
|
current_prompt_val = (current_prompt_val.strip() + ", " + st.session_state.append_prompt).strip(', ')
|
||||||
|
del st.session_state.append_prompt
|
||||||
|
|
||||||
|
new_prompt = st.text_area("Specific Prompt", value=current_prompt_val, height=150, key=f"{fk}_sp")
|
||||||
|
new_negative = st.text_area("Specific Negative", value=data.get("negative", ""), height=100, key=f"{fk}_sn")
|
||||||
|
|
||||||
|
# Seed
|
||||||
|
col_seed_val, col_seed_btn = st.columns([4, 1])
|
||||||
|
seed_key = f"{fk}_seed"
|
||||||
|
|
||||||
|
with col_seed_btn:
|
||||||
|
st.write("")
|
||||||
|
st.write("")
|
||||||
|
if st.button("🎲 Randomize", key=f"{fk}_rand"):
|
||||||
|
st.session_state[seed_key] = random.randint(0, 999999999999)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
with col_seed_val:
|
||||||
|
seed_val = st.session_state.get('rand_seed', int(data.get("seed", 0)))
|
||||||
|
new_seed = st.number_input("Seed", value=seed_val, step=1, min_value=0, format="%d", key=seed_key)
|
||||||
|
data["seed"] = new_seed
|
||||||
|
|
||||||
|
# LoRAs
|
||||||
|
st.subheader("LoRAs")
|
||||||
|
l_col1, l_col2 = st.columns(2)
|
||||||
|
loras = {}
|
||||||
|
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
||||||
|
for i, k in enumerate(lora_keys):
|
||||||
|
with (l_col1 if i % 2 == 0 else l_col2):
|
||||||
|
loras[k] = st.text_input(k.title(), value=data.get(k, ""), key=f"{fk}_{k}")
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
st.subheader("Settings")
|
||||||
|
spec_fields = {}
|
||||||
|
spec_fields["camera"] = st.text_input("Camera", value=str(data.get("camera", DEFAULTS["camera"])), key=f"{fk}_cam")
|
||||||
|
spec_fields["flf"] = st.text_input("FLF", value=str(data.get("flf", DEFAULTS["flf"])), key=f"{fk}_flf")
|
||||||
|
|
||||||
|
# Explicitly track standard setting keys to exclude them from custom list
|
||||||
|
standard_keys = {
|
||||||
|
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||||
|
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token",
|
||||||
|
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
|
||||||
|
}
|
||||||
|
standard_keys.update(lora_keys)
|
||||||
|
|
||||||
|
if "vace" in file_path.name:
|
||||||
|
vace_keys = ["frame_to_skip", "input_a_frames", "input_b_frames", "reference switch", "vace schedule", "reference path", "video file path", "reference image path"]
|
||||||
|
standard_keys.update(vace_keys)
|
||||||
|
|
||||||
|
spec_fields["frame_to_skip"] = st.number_input("Frame to Skip", value=int(data.get("frame_to_skip", 81)), key=f"{fk}_fts")
|
||||||
|
spec_fields["input_a_frames"] = st.number_input("Input A Frames", value=int(data.get("input_a_frames", 0)), key=f"{fk}_ia")
|
||||||
|
spec_fields["input_b_frames"] = st.number_input("Input B Frames", value=int(data.get("input_b_frames", 0)), key=f"{fk}_ib")
|
||||||
|
spec_fields["reference switch"] = st.number_input("Reference Switch", value=int(data.get("reference switch", 1)), key=f"{fk}_rsw")
|
||||||
|
spec_fields["vace schedule"] = st.number_input("VACE Schedule", value=int(data.get("vace schedule", 1)), key=f"{fk}_vsc")
|
||||||
|
for f in ["reference path", "video file path", "reference image path"]:
|
||||||
|
spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}")
|
||||||
|
elif "i2v" in file_path.name:
|
||||||
|
i2v_keys = ["reference image path", "flf image path", "video file path"]
|
||||||
|
standard_keys.update(i2v_keys)
|
||||||
|
|
||||||
|
for f in i2v_keys:
|
||||||
|
spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}")
|
||||||
|
|
||||||
|
# --- CUSTOM PARAMETERS LOGIC ---
|
||||||
|
st.markdown("---")
|
||||||
|
st.subheader("🔧 Custom Parameters")
|
||||||
|
|
||||||
|
# Filter keys: Only those NOT in the standard set
|
||||||
|
custom_keys = [k for k in data.keys() if k not in standard_keys]
|
||||||
|
|
||||||
|
keys_to_remove = []
|
||||||
|
|
||||||
|
if custom_keys:
|
||||||
|
for k in custom_keys:
|
||||||
|
c1, c2, c3 = st.columns([1, 2, 0.5])
|
||||||
|
c1.text_input("Key", value=k, disabled=True, key=f"{fk}_ck_lbl_{k}", label_visibility="collapsed")
|
||||||
|
val = c2.text_input("Value", value=str(data[k]), key=f"{fk}_cv_{k}", label_visibility="collapsed")
|
||||||
|
data[k] = val
|
||||||
|
|
||||||
|
if c3.button("🗑️", key=f"{fk}_cdel_{k}"):
|
||||||
|
keys_to_remove.append(k)
|
||||||
|
else:
|
||||||
|
st.caption("No custom keys added.")
|
||||||
|
|
||||||
|
# Add New Key Interface
|
||||||
|
with st.expander("➕ Add New Parameter"):
|
||||||
|
nk_col, nv_col = st.columns(2)
|
||||||
|
new_k = nk_col.text_input("Key Name", key=f"{fk}_new_k")
|
||||||
|
new_v = nv_col.text_input("Value", key=f"{fk}_new_v")
|
||||||
|
|
||||||
|
if st.button("Add Parameter", key=f"{fk}_add_cust"):
|
||||||
|
if new_k and new_k not in data:
|
||||||
|
data[new_k] = new_v
|
||||||
|
st.rerun()
|
||||||
|
elif new_k in data:
|
||||||
|
st.error(f"Key '{new_k}' already exists!")
|
||||||
|
|
||||||
|
# Apply Removals
|
||||||
|
if keys_to_remove:
|
||||||
|
for k in keys_to_remove:
|
||||||
|
del data[k]
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- ACTIONS & HISTORY ---
|
||||||
|
with col2:
|
||||||
|
current_state = {
|
||||||
|
"general_prompt": gen_prompt, "general_negative": gen_negative,
|
||||||
|
"current_prompt": new_prompt, "negative": new_negative,
|
||||||
|
"seed": new_seed, **loras, **spec_fields
|
||||||
|
}
|
||||||
|
|
||||||
|
# MERGE CUSTOM KEYS
|
||||||
|
for k in custom_keys:
|
||||||
|
if k not in keys_to_remove:
|
||||||
|
current_state[k] = data[k]
|
||||||
|
|
||||||
|
st.session_state.single_editor_cache = current_state
|
||||||
|
|
||||||
|
st.subheader("Actions")
|
||||||
|
current_disk_mtime = get_file_mtime(file_path)
|
||||||
|
is_conflict = current_disk_mtime > st.session_state.last_mtime
|
||||||
|
|
||||||
|
if is_conflict:
|
||||||
|
st.error("⚠️ CONFLICT: Disk changed!")
|
||||||
|
if st.button("Force Save"):
|
||||||
|
data.update(current_state)
|
||||||
|
save_json(file_path, data) # No return val in new utils
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path) # Manual Update
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.toast("Saved!", icon="⚠️")
|
||||||
|
st.rerun()
|
||||||
|
if st.button("Reload File"):
|
||||||
|
st.session_state.loaded_file = None
|
||||||
|
st.rerun()
|
||||||
|
else:
|
||||||
|
if st.button("💾 Update File", use_container_width=True):
|
||||||
|
data.update(current_state)
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.toast("Updated!", icon="✅")
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
archive_note = st.text_input("Archive Note")
|
||||||
|
if st.button("📦 Snapshot to History", use_container_width=True):
|
||||||
|
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
|
||||||
|
if "prompt_history" not in data: data["prompt_history"] = []
|
||||||
|
data["prompt_history"].insert(0, entry)
|
||||||
|
data.update(entry)
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.toast("Archived!", icon="📦")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- FULL HISTORY PANEL ---
|
||||||
|
st.markdown("---")
|
||||||
|
st.subheader("History")
|
||||||
|
history = data.get("prompt_history", [])
|
||||||
|
|
||||||
|
if not history:
|
||||||
|
st.caption("No history yet.")
|
||||||
|
|
||||||
|
for idx, h in enumerate(history):
|
||||||
|
note = h.get('note', 'No Note')
|
||||||
|
|
||||||
|
with st.container():
|
||||||
|
if st.session_state.edit_history_idx == idx:
|
||||||
|
with st.expander(f"📝 Editing: {note}", expanded=True):
|
||||||
|
edit_note = st.text_input("Note", value=note, key=f"h_en_{idx}")
|
||||||
|
edit_seed = st.number_input("Seed", value=int(h.get('seed', 0)), key=f"h_es_{idx}")
|
||||||
|
edit_gp = st.text_area("General P", value=h.get('general_prompt', ''), height=60, key=f"h_egp_{idx}")
|
||||||
|
edit_gn = st.text_area("General N", value=h.get('general_negative', ''), height=60, key=f"h_egn_{idx}")
|
||||||
|
edit_sp = st.text_area("Specific P", value=h.get('prompt', ''), height=100, key=f"h_esp_{idx}")
|
||||||
|
edit_sn = st.text_area("Specific N", value=h.get('negative', ''), height=60, key=f"h_esn_{idx}")
|
||||||
|
|
||||||
|
hc1, hc2 = st.columns([1, 4])
|
||||||
|
if hc1.button("💾 Save", key=f"h_save_{idx}"):
|
||||||
|
h.update({
|
||||||
|
'note': edit_note, 'seed': edit_seed,
|
||||||
|
'general_prompt': edit_gp, 'general_negative': edit_gn,
|
||||||
|
'prompt': edit_sp, 'negative': edit_sn
|
||||||
|
})
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.session_state.edit_history_idx = None
|
||||||
|
st.rerun()
|
||||||
|
if hc2.button("Cancel", key=f"h_can_{idx}"):
|
||||||
|
st.session_state.edit_history_idx = None
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
else:
|
||||||
|
with st.expander(f"#{idx+1}: {note}"):
|
||||||
|
st.caption(f"Seed: {h.get('seed', 0)}")
|
||||||
|
st.text(f"SPEC: {h.get('prompt', '')[:40]}...")
|
||||||
|
|
||||||
|
view_data = {k:v for k,v in h.items() if k not in ['prompt', 'negative', 'general_prompt', 'general_negative', 'note']}
|
||||||
|
st.json(view_data, expanded=False)
|
||||||
|
|
||||||
|
bh1, bh2, bh3 = st.columns([2, 1, 1])
|
||||||
|
|
||||||
|
if bh1.button("Restore", key=f"h_rest_{idx}", use_container_width=True):
|
||||||
|
data.update(h)
|
||||||
|
if 'prompt' in h: data['current_prompt'] = h['prompt']
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
|
||||||
|
# Refresh UI
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
|
||||||
|
st.toast("Restored!", icon="⏪")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if bh2.button("✏️", key=f"h_edit_{idx}"):
|
||||||
|
st.session_state.edit_history_idx = idx
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if bh3.button("🗑️", key=f"h_del_{idx}"):
|
||||||
|
history.pop(idx)
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
|
st.session_state.data_cache = data
|
||||||
|
st.rerun()
|
||||||
+143
@@ -0,0 +1,143 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import json
|
||||||
|
import graphviz
|
||||||
|
import time
|
||||||
|
from history_tree import HistoryTree
|
||||||
|
from utils import save_json
|
||||||
|
|
||||||
|
def render_timeline_tab(data, file_path):
|
||||||
|
tree_data = data.get("history_tree", {})
|
||||||
|
if not tree_data:
|
||||||
|
st.info("No history timeline exists. Make some changes in the Editor first!")
|
||||||
|
return
|
||||||
|
|
||||||
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
|
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
||||||
|
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
||||||
|
|
||||||
|
# --- VIEW SWITCHER ---
|
||||||
|
c_title, c_view = st.columns([2, 1])
|
||||||
|
c_title.subheader("🕰️ Version History")
|
||||||
|
|
||||||
|
view_mode = c_view.radio(
|
||||||
|
"View Mode",
|
||||||
|
["🌳 Horizontal", "🌲 Vertical", "📜 Linear Log"],
|
||||||
|
horizontal=True,
|
||||||
|
label_visibility="collapsed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- RENDER GRAPH VIEWS ---
|
||||||
|
if view_mode in ["🌳 Horizontal", "🌲 Vertical"]:
|
||||||
|
direction = "LR" if view_mode == "🌳 Horizontal" else "TB"
|
||||||
|
try:
|
||||||
|
graph_dot = htree.generate_graph(direction=direction)
|
||||||
|
st.graphviz_chart(graph_dot, use_container_width=True)
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Graph Error: {e}")
|
||||||
|
|
||||||
|
# --- RENDER LINEAR LOG VIEW ---
|
||||||
|
elif view_mode == "📜 Linear Log":
|
||||||
|
st.caption("A simple chronological list of all snapshots.")
|
||||||
|
all_nodes = list(htree.nodes.values())
|
||||||
|
all_nodes.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||||
|
|
||||||
|
for n in all_nodes:
|
||||||
|
is_head = (n["id"] == htree.head_id)
|
||||||
|
with st.container():
|
||||||
|
c1, c2, c3 = st.columns([0.5, 4, 1])
|
||||||
|
with c1:
|
||||||
|
st.markdown("### 📍" if is_head else "### ⚫")
|
||||||
|
with c2:
|
||||||
|
note_txt = n.get('note', 'Step')
|
||||||
|
ts = time.strftime('%H:%M:%S', time.localtime(n['timestamp']))
|
||||||
|
if is_head:
|
||||||
|
st.markdown(f"**{note_txt}** (Current)")
|
||||||
|
else:
|
||||||
|
st.write(f"**{note_txt}**")
|
||||||
|
st.caption(f"ID: {n['id'][:6]} • Time: {ts}")
|
||||||
|
with c3:
|
||||||
|
if not is_head:
|
||||||
|
if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"):
|
||||||
|
data.update(n["data"])
|
||||||
|
htree.head_id = n['id']
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
label = f"{n.get('note')} ({n['id'][:4]})"
|
||||||
|
st.session_state.restored_indicator = label
|
||||||
|
st.toast(f"Restored!", icon="🔄")
|
||||||
|
st.rerun()
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- ACTIONS & SELECTION ---
|
||||||
|
col_sel, col_act = st.columns([3, 1])
|
||||||
|
|
||||||
|
all_nodes = list(htree.nodes.values())
|
||||||
|
all_nodes.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||||
|
|
||||||
|
def fmt_node(n):
|
||||||
|
return f"{n.get('note', 'Step')} ({n['id']})"
|
||||||
|
|
||||||
|
with col_sel:
|
||||||
|
current_idx = 0
|
||||||
|
for i, n in enumerate(all_nodes):
|
||||||
|
if n["id"] == htree.head_id:
|
||||||
|
current_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
selected_node = st.selectbox(
|
||||||
|
"Select Version to Manage:",
|
||||||
|
all_nodes,
|
||||||
|
format_func=fmt_node,
|
||||||
|
index=current_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_node:
|
||||||
|
node_data = selected_node["data"]
|
||||||
|
|
||||||
|
# --- ACTIONS ---
|
||||||
|
with col_act:
|
||||||
|
st.write(""); st.write("")
|
||||||
|
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
|
||||||
|
data.update(node_data)
|
||||||
|
htree.head_id = selected_node['id']
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
|
||||||
|
st.session_state.restored_indicator = label
|
||||||
|
st.toast(f"Restored!", icon="🔄")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- RENAME ---
|
||||||
|
rn_col1, rn_col2 = st.columns([3, 1])
|
||||||
|
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
|
||||||
|
if rn_col2.button("Update Label"):
|
||||||
|
selected_node["note"] = new_label
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- DANGER ZONE ---
|
||||||
|
st.markdown("---")
|
||||||
|
with st.expander("⚠️ Danger Zone (Delete)"):
|
||||||
|
st.warning("Deleting a node cannot be undone.")
|
||||||
|
if st.button("🗑️ Delete This Node", type="primary"):
|
||||||
|
if selected_node['id'] in htree.nodes:
|
||||||
|
del htree.nodes[selected_node['id']]
|
||||||
|
for b, tip in list(htree.branches.items()):
|
||||||
|
if tip == selected_node['id']:
|
||||||
|
del htree.branches[b]
|
||||||
|
if htree.head_id == selected_node['id']:
|
||||||
|
if htree.nodes:
|
||||||
|
fallback = sorted(htree.nodes.values(), key=lambda x: x["timestamp"])[-1]
|
||||||
|
htree.head_id = fallback["id"]
|
||||||
|
else:
|
||||||
|
htree.head_id = None
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.toast("Node Deleted", icon="🗑️")
|
||||||
|
st.rerun()
|
||||||
@@ -1,638 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
from nicegui import ui
|
|
||||||
|
|
||||||
from state import AppState
|
|
||||||
from snapshot_timeline import SnapshotTimeline, diff_snapshots
|
|
||||||
from utils import save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Main entry point
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def render_timeline_tab(state: AppState):
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
logger.info("render_timeline_tab START")
|
|
||||||
data = state.data_cache
|
|
||||||
file_path = state.file_path
|
|
||||||
|
|
||||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
|
||||||
if not tree_data:
|
|
||||||
ui.label('No version history exists. Make some changes in the Editor first!').classes(
|
|
||||||
'text-subtitle1 q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
timeline = SnapshotTimeline(tree_data)
|
|
||||||
if not timeline.snapshots:
|
|
||||||
ui.label('No snapshots found in history.').classes('text-subtitle1 q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
# Local UI state
|
|
||||||
ui_state = {
|
|
||||||
'selected_id': state.timeline_selected_id or timeline.current_id,
|
|
||||||
'search': '',
|
|
||||||
'filter': 'All', # All | Pinned | Auto
|
|
||||||
}
|
|
||||||
|
|
||||||
if state.restored_indicator:
|
|
||||||
ui.label(f'Editing Restored Version: {state.restored_indicator}').classes(
|
|
||||||
'text-info q-pa-sm')
|
|
||||||
|
|
||||||
ui.label('Version History').classes('text-h6 q-mb-sm')
|
|
||||||
|
|
||||||
# Mutable container so left/right panels can cross-reference each other's refreshables
|
|
||||||
panels: dict = {}
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Splitter layout: 35% left (list) / 65% right (detail)
|
|
||||||
# ======================================================================
|
|
||||||
with ui.splitter(value=35).classes('w-full').style('height: calc(100vh - 200px); min-height: 600px') as splitter:
|
|
||||||
|
|
||||||
# ==============================================================
|
|
||||||
# LEFT PANEL — Snapshot list
|
|
||||||
# ==============================================================
|
|
||||||
with splitter.before:
|
|
||||||
with ui.column().classes('w-full q-pa-sm').style('height: 100%'):
|
|
||||||
# Search + filter
|
|
||||||
search_input = ui.input(
|
|
||||||
placeholder='Search notes...',
|
|
||||||
).classes('w-full').props('dense outlined clearable')
|
|
||||||
|
|
||||||
with ui.row().classes('w-full q-gutter-xs'):
|
|
||||||
filter_toggle = ui.toggle(
|
|
||||||
['All', 'Pinned', 'Auto'], value='All',
|
|
||||||
).props('dense no-caps')
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_snapshot_list():
|
|
||||||
_render_snapshot_list(
|
|
||||||
timeline, ui_state, data, file_path, state,
|
|
||||||
render_snapshot_list, panels)
|
|
||||||
|
|
||||||
panels['list'] = render_snapshot_list
|
|
||||||
|
|
||||||
def _on_search(e):
|
|
||||||
ui_state['search'] = search_input.value or ''
|
|
||||||
render_snapshot_list.refresh()
|
|
||||||
|
|
||||||
def _on_filter(e):
|
|
||||||
ui_state['filter'] = e.value
|
|
||||||
render_snapshot_list.refresh()
|
|
||||||
|
|
||||||
search_input.on('update:model-value', _on_search)
|
|
||||||
filter_toggle.on_value_change(_on_filter)
|
|
||||||
|
|
||||||
render_snapshot_list()
|
|
||||||
|
|
||||||
# ==============================================================
|
|
||||||
# RIGHT PANEL — Detail tabs
|
|
||||||
# ==============================================================
|
|
||||||
with splitter.after:
|
|
||||||
@ui.refreshable
|
|
||||||
def render_detail_panel():
|
|
||||||
_render_detail_panel(timeline, ui_state, data, file_path, state,
|
|
||||||
panels)
|
|
||||||
|
|
||||||
panels['detail'] = render_detail_panel
|
|
||||||
render_detail_panel()
|
|
||||||
|
|
||||||
logger.info("render_timeline_tab END (%.3fs)", time.perf_counter() - t0)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Left panel: snapshot list
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_snapshot_list(timeline, ui_state, data, file_path, state,
|
|
||||||
refresh_list, panels):
|
|
||||||
snapshots = sorted(timeline.snapshots.values(),
|
|
||||||
key=lambda s: s['timestamp'], reverse=True)
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
search_term = ui_state.get('search', '').lower()
|
|
||||||
filter_mode = ui_state.get('filter', 'All')
|
|
||||||
|
|
||||||
if search_term:
|
|
||||||
snapshots = [s for s in snapshots
|
|
||||||
if search_term in s.get('note', '').lower()]
|
|
||||||
if filter_mode == 'Pinned':
|
|
||||||
snapshots = [s for s in snapshots if s.get('pinned')]
|
|
||||||
elif filter_mode == 'Auto':
|
|
||||||
snapshots = [s for s in snapshots if s.get('auto')]
|
|
||||||
|
|
||||||
if not snapshots:
|
|
||||||
ui.label('No snapshots match your filter.').classes('text-caption q-pa-md')
|
|
||||||
return
|
|
||||||
|
|
||||||
with ui.scroll_area().classes('w-full').style('flex: 1; min-height: 0'):
|
|
||||||
for snap in snapshots:
|
|
||||||
sid = snap['id']
|
|
||||||
is_current = sid == timeline.current_id
|
|
||||||
is_selected = sid == ui_state.get('selected_id')
|
|
||||||
is_pinned = snap.get('pinned', False)
|
|
||||||
is_auto = snap.get('auto', False)
|
|
||||||
|
|
||||||
# Card styling
|
|
||||||
border = ''
|
|
||||||
if is_current:
|
|
||||||
border = 'border-left: 4px solid #eebb00;'
|
|
||||||
if is_selected:
|
|
||||||
border = 'border-left: 4px solid #4caf50;'
|
|
||||||
bg = 'background: rgba(76,175,80,0.08) !important;' if is_selected else ''
|
|
||||||
|
|
||||||
def select_snap(snap_id=sid):
|
|
||||||
ui_state['selected_id'] = snap_id
|
|
||||||
state.timeline_selected_id = snap_id
|
|
||||||
refresh_list.refresh()
|
|
||||||
detail = panels.get('detail')
|
|
||||||
if detail is not None:
|
|
||||||
detail.refresh()
|
|
||||||
|
|
||||||
with ui.card().classes('w-full q-mb-xs q-pa-xs cursor-pointer').style(
|
|
||||||
f'{border} {bg}').on('click', select_snap):
|
|
||||||
with ui.row().classes('w-full items-center no-wrap'):
|
|
||||||
# Icon
|
|
||||||
if is_pinned:
|
|
||||||
icon_name = 'push_pin'
|
|
||||||
icon_cls = 'text-amber'
|
|
||||||
elif is_auto:
|
|
||||||
icon_name = 'bolt'
|
|
||||||
icon_cls = 'text-grey'
|
|
||||||
else:
|
|
||||||
icon_name = 'save'
|
|
||||||
icon_cls = 'text-primary'
|
|
||||||
ui.icon(icon_name, size='sm').classes(icon_cls)
|
|
||||||
|
|
||||||
# Text
|
|
||||||
with ui.column().classes('col q-ml-xs').style('min-width: 0'):
|
|
||||||
note = snap.get('note', 'Snapshot')
|
|
||||||
lbl = ui.label(note).classes('text-body2 ellipsis')
|
|
||||||
if is_current:
|
|
||||||
lbl.classes('text-bold')
|
|
||||||
ts = time.strftime('%b %d %H:%M',
|
|
||||||
time.localtime(snap['timestamp']))
|
|
||||||
seq_count = snap.get('seq_count', '?')
|
|
||||||
ui.label(f'{ts} \u00b7 {seq_count} seqs').classes(
|
|
||||||
'text-caption text-grey')
|
|
||||||
|
|
||||||
# Badges
|
|
||||||
if is_current:
|
|
||||||
ui.badge('current', color='amber').props('dense')
|
|
||||||
|
|
||||||
# Pin toggle
|
|
||||||
async def toggle_pin(snap_id=sid):
|
|
||||||
timeline.toggle_pin(snap_id)
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
refresh_list.refresh()
|
|
||||||
|
|
||||||
pin_icon = 'push_pin' if is_pinned else 'o_push_pin'
|
|
||||||
ui.button(icon=pin_icon, on_click=toggle_pin).props(
|
|
||||||
'flat dense round size=xs').on('click.stop', lambda: None)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Right panel: detail tabs (Preview / Compare / Cherry-pick)
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_detail_panel(timeline, ui_state, data, file_path, state,
|
|
||||||
panels):
|
|
||||||
sel_id = ui_state.get('selected_id')
|
|
||||||
if not sel_id or sel_id not in timeline.snapshots:
|
|
||||||
ui.label('Select a snapshot from the list.').classes('text-caption q-pa-lg')
|
|
||||||
return
|
|
||||||
|
|
||||||
def _refresh_both():
|
|
||||||
"""Refresh both list and detail panels."""
|
|
||||||
lp = panels.get('list')
|
|
||||||
dp = panels.get('detail')
|
|
||||||
if lp:
|
|
||||||
lp.refresh()
|
|
||||||
if dp:
|
|
||||||
dp.refresh()
|
|
||||||
|
|
||||||
snap = timeline.snapshots[sel_id]
|
|
||||||
note = snap.get('note', 'Snapshot')
|
|
||||||
ts = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(snap['timestamp']))
|
|
||||||
ui.label(f'{note}').classes('text-subtitle1 text-bold')
|
|
||||||
ui.label(f'{ts} \u2022 ID: {sel_id}').classes('text-caption text-grey q-mb-sm')
|
|
||||||
|
|
||||||
# Action buttons
|
|
||||||
with ui.row().classes('q-gutter-sm q-mb-sm'):
|
|
||||||
is_current = sel_id == timeline.current_id
|
|
||||||
|
|
||||||
if not is_current:
|
|
||||||
async def restore_full():
|
|
||||||
await _restore_snapshot(data, sel_id, timeline, file_path, state)
|
|
||||||
state._render_main.refresh()
|
|
||||||
|
|
||||||
ui.button('Restore Full', icon='restore',
|
|
||||||
on_click=restore_full).props('color=primary dense')
|
|
||||||
|
|
||||||
# Rename
|
|
||||||
rename_input = ui.input(placeholder='New note...').props('dense outlined').classes('w-48')
|
|
||||||
|
|
||||||
async def rename():
|
|
||||||
if rename_input.value and sel_id in timeline.snapshots:
|
|
||||||
timeline.snapshots[sel_id]['note'] = rename_input.value
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
ui.notify('Note updated', type='positive')
|
|
||||||
_refresh_both()
|
|
||||||
|
|
||||||
ui.button('Rename', on_click=rename).props('flat dense')
|
|
||||||
|
|
||||||
# Delete
|
|
||||||
async def delete_snap():
|
|
||||||
timeline.delete(sel_id)
|
|
||||||
# Clean up DB snapshots
|
|
||||||
if state.db_enabled and state.db and state.current_project:
|
|
||||||
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
|
||||||
if df:
|
|
||||||
state.db.delete_node_snapshots(df['id'], {sel_id})
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
ui_state['selected_id'] = timeline.current_id
|
|
||||||
state.timeline_selected_id = timeline.current_id
|
|
||||||
ui.notify('Snapshot deleted', type='positive')
|
|
||||||
_refresh_both()
|
|
||||||
|
|
||||||
ui.button(icon='delete', on_click=delete_snap).props('flat dense color=negative')
|
|
||||||
|
|
||||||
# Sub-tabs
|
|
||||||
with ui.tabs().classes('w-full') as tabs:
|
|
||||||
preview_tab = ui.tab('Preview', icon='visibility')
|
|
||||||
compare_tab = ui.tab('Compare', icon='compare')
|
|
||||||
cherry_tab = ui.tab('Cherry-pick', icon='content_paste')
|
|
||||||
|
|
||||||
with ui.tab_panels(tabs, value=preview_tab).classes('w-full'):
|
|
||||||
with ui.tab_panel(preview_tab):
|
|
||||||
_render_preview_tab(sel_id, timeline, state, file_path)
|
|
||||||
|
|
||||||
with ui.tab_panel(compare_tab):
|
|
||||||
_render_compare_tab(sel_id, timeline, data, state, file_path)
|
|
||||||
|
|
||||||
with ui.tab_panel(cherry_tab):
|
|
||||||
_render_cherry_pick_tab(sel_id, timeline, data, file_path, state,
|
|
||||||
panels)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Tab 1: Preview
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_preview_tab(sel_id, timeline, state, file_path):
|
|
||||||
snap_data = _load_snapshot_data(sel_id, timeline, state, file_path)
|
|
||||||
if not snap_data:
|
|
||||||
ui.label('Snapshot data not available.').classes('text-caption text-warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
batch_list = snap_data.get(KEY_BATCH_DATA, [])
|
|
||||||
if batch_list and isinstance(batch_list, list):
|
|
||||||
ui.label(f'{len(batch_list)} sequences in this snapshot.').classes('text-caption')
|
|
||||||
for i, seq_data in enumerate(batch_list):
|
|
||||||
seq_num = seq_data.get('sequence_number', i + 1)
|
|
||||||
with ui.expansion(f'Sequence #{seq_num}', value=(i == 0)):
|
|
||||||
_render_preview_fields(seq_data)
|
|
||||||
else:
|
|
||||||
_render_preview_fields(snap_data)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Tab 2: Compare
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_compare_tab(sel_id, timeline, data, state, file_path):
|
|
||||||
snap_data = _load_snapshot_data(sel_id, timeline, state, file_path)
|
|
||||||
if not snap_data:
|
|
||||||
ui.label('Snapshot data not available.').classes('text-caption text-warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
old_batch = snap_data.get(KEY_BATCH_DATA, [])
|
|
||||||
new_batch = data.get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
if not old_batch and not new_batch:
|
|
||||||
ui.label('No batch data to compare.').classes('text-caption')
|
|
||||||
return
|
|
||||||
|
|
||||||
diffs = diff_snapshots(old_batch, new_batch)
|
|
||||||
|
|
||||||
show_all = ui.switch('Show unchanged', value=False)
|
|
||||||
|
|
||||||
@ui.refreshable
|
|
||||||
def render_diff():
|
|
||||||
any_diff = False
|
|
||||||
for d in diffs:
|
|
||||||
if d['status'] == 'unchanged' and not show_all.value:
|
|
||||||
continue
|
|
||||||
any_diff = True
|
|
||||||
seq_num = d['seq_num']
|
|
||||||
status = d['status']
|
|
||||||
|
|
||||||
# Header styling
|
|
||||||
if status == 'added':
|
|
||||||
icon = 'add_circle'
|
|
||||||
color = 'text-positive'
|
|
||||||
label = f'Sequence #{seq_num} \u2014 ADDED (not in snapshot)'
|
|
||||||
elif status == 'removed':
|
|
||||||
icon = 'remove_circle'
|
|
||||||
color = 'text-negative'
|
|
||||||
label = f'Sequence #{seq_num} \u2014 REMOVED (not in current)'
|
|
||||||
elif status == 'changed':
|
|
||||||
icon = 'change_circle'
|
|
||||||
color = 'text-warning'
|
|
||||||
label = f'Sequence #{seq_num} \u2014 {len(d["changes"])} field{"s" if len(d["changes"]) != 1 else ""} changed'
|
|
||||||
else:
|
|
||||||
icon = 'check_circle'
|
|
||||||
color = 'text-grey'
|
|
||||||
label = f'Sequence #{seq_num} \u2014 No changes'
|
|
||||||
|
|
||||||
with ui.expansion(label, icon=icon).classes(f'w-full {color}'):
|
|
||||||
if status == 'changed' and d['changes']:
|
|
||||||
# Table of field changes
|
|
||||||
columns = [
|
|
||||||
{'name': 'field', 'label': 'Field', 'field': 'field', 'align': 'left'},
|
|
||||||
{'name': 'old', 'label': 'Snapshot', 'field': 'old', 'align': 'left'},
|
|
||||||
{'name': 'new', 'label': 'Current', 'field': 'new', 'align': 'left'},
|
|
||||||
]
|
|
||||||
rows = []
|
|
||||||
for c in d['changes']:
|
|
||||||
rows.append({
|
|
||||||
'field': c['field'],
|
|
||||||
'old': _truncate(c['old']),
|
|
||||||
'new': _truncate(c['new']),
|
|
||||||
})
|
|
||||||
ui.table(columns=columns, rows=rows, row_key='field').classes(
|
|
||||||
'w-full').props('dense flat bordered')
|
|
||||||
elif status in ('added', 'removed'):
|
|
||||||
ui.label('Entire sequence differs.').classes('text-caption')
|
|
||||||
|
|
||||||
if not any_diff:
|
|
||||||
ui.label('All sequences are identical.').classes('text-caption q-pa-md')
|
|
||||||
|
|
||||||
show_all.on_value_change(lambda _: render_diff.refresh())
|
|
||||||
render_diff()
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Tab 3: Cherry-pick Restore
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _render_cherry_pick_tab(sel_id, timeline, data, file_path, state,
|
|
||||||
panels):
|
|
||||||
snap_data = _load_snapshot_data(sel_id, timeline, state, file_path)
|
|
||||||
if not snap_data:
|
|
||||||
ui.label('Snapshot data not available.').classes('text-caption text-warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
old_batch = snap_data.get(KEY_BATCH_DATA, [])
|
|
||||||
if not old_batch:
|
|
||||||
ui.label('No sequences in this snapshot.').classes('text-caption')
|
|
||||||
return
|
|
||||||
|
|
||||||
ui.label('Select sequences and fields to restore from this snapshot.').classes(
|
|
||||||
'text-caption q-mb-sm')
|
|
||||||
|
|
||||||
mode = ui.toggle(['Whole sequences', 'Selected fields'], value='Whole sequences').props(
|
|
||||||
'dense no-caps')
|
|
||||||
|
|
||||||
# Build checkboxes per sequence
|
|
||||||
seq_checks: dict[int, ui.checkbox] = {}
|
|
||||||
field_checks: dict[int, dict[str, ui.checkbox]] = {}
|
|
||||||
|
|
||||||
for seq_item in old_batch:
|
|
||||||
seq_num = int(seq_item.get('sequence_number', 0))
|
|
||||||
seq_cb = ui.checkbox(f'Sequence #{seq_num}')
|
|
||||||
seq_checks[seq_num] = seq_cb
|
|
||||||
|
|
||||||
with ui.expansion(f'Fields for #{seq_num}').classes('w-full q-ml-lg'):
|
|
||||||
field_checks[seq_num] = {}
|
|
||||||
for k in sorted(seq_item.keys()):
|
|
||||||
if k == 'sequence_number':
|
|
||||||
continue
|
|
||||||
val_str = _truncate(seq_item.get(k))
|
|
||||||
fcb = ui.checkbox(f'{k}: {val_str}')
|
|
||||||
field_checks[seq_num][k] = fcb
|
|
||||||
|
|
||||||
async def apply_cherry_pick():
|
|
||||||
current_batch = data.get(KEY_BATCH_DATA, [])
|
|
||||||
curr_by_seq = {int(s.get('sequence_number', 0)): s for s in current_batch}
|
|
||||||
old_by_seq = {int(s.get('sequence_number', 0)): s for s in old_batch}
|
|
||||||
|
|
||||||
applied = 0
|
|
||||||
for seq_num, cb in seq_checks.items():
|
|
||||||
if not cb.value:
|
|
||||||
continue
|
|
||||||
if seq_num not in old_by_seq:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if mode.value == 'Whole sequences':
|
|
||||||
# Replace or add entire sequence
|
|
||||||
restored = copy.deepcopy(old_by_seq[seq_num])
|
|
||||||
if seq_num in curr_by_seq:
|
|
||||||
# Find and replace in-place
|
|
||||||
for i, s in enumerate(current_batch):
|
|
||||||
if int(s.get('sequence_number', 0)) == seq_num:
|
|
||||||
current_batch[i] = restored
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
current_batch.append(restored)
|
|
||||||
applied += 1
|
|
||||||
else:
|
|
||||||
# Selected fields only
|
|
||||||
if seq_num not in curr_by_seq:
|
|
||||||
continue
|
|
||||||
target = curr_by_seq[seq_num]
|
|
||||||
fields = field_checks.get(seq_num, {})
|
|
||||||
for field_name, fcb in fields.items():
|
|
||||||
if fcb.value and field_name in old_by_seq[seq_num]:
|
|
||||||
target[field_name] = copy.deepcopy(old_by_seq[seq_num][field_name])
|
|
||||||
applied += 1
|
|
||||||
|
|
||||||
if applied == 0:
|
|
||||||
ui.notify('Nothing selected to restore.', type='warning')
|
|
||||||
return
|
|
||||||
|
|
||||||
data[KEY_BATCH_DATA] = current_batch
|
|
||||||
|
|
||||||
# Auto-snapshot noting the cherry-pick
|
|
||||||
snap_note = timeline.snapshots.get(sel_id, {}).get('note', 'unknown')
|
|
||||||
snap_json = json.dumps({k: v for k, v in data.items()
|
|
||||||
if k != KEY_HISTORY_TREE})
|
|
||||||
snap_payload = json.loads(snap_json)
|
|
||||||
timeline.record(snap_payload, note=f'Cherry-pick from "{snap_note}"')
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
db_snap = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap)
|
|
||||||
timeline.strip_snapshots()
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
save_snap = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, save_snap)
|
|
||||||
ui.notify(f'Applied {applied} item{"s" if applied != 1 else ""}!', type='positive')
|
|
||||||
for p in ('list', 'detail'):
|
|
||||||
ref = panels.get(p)
|
|
||||||
if ref:
|
|
||||||
ref.refresh()
|
|
||||||
|
|
||||||
ui.button('Apply Selected', icon='check', on_click=apply_cherry_pick).props(
|
|
||||||
'color=primary q-mt-md')
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
|
||||||
# Shared helpers
|
|
||||||
# ======================================================================
|
|
||||||
|
|
||||||
def _load_snapshot_data(snap_id, timeline, state, file_path):
|
|
||||||
"""Load snapshot data from inline, DB, or disk fallback."""
|
|
||||||
snap_data = timeline.get_snapshot_data(snap_id)
|
|
||||||
if snap_data:
|
|
||||||
return snap_data
|
|
||||||
|
|
||||||
# Try DB
|
|
||||||
if state and state.db_enabled and state.db and state.current_project and file_path:
|
|
||||||
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
|
||||||
if df:
|
|
||||||
snap_data = state.db.get_node_snapshot(df['id'], snap_id)
|
|
||||||
if snap_data:
|
|
||||||
return snap_data
|
|
||||||
|
|
||||||
# Disk fallback
|
|
||||||
if file_path:
|
|
||||||
try:
|
|
||||||
raw_data, _ = load_json(file_path)
|
|
||||||
tree_on_disk = raw_data.get(KEY_HISTORY_TREE, {})
|
|
||||||
# New format
|
|
||||||
entry = tree_on_disk.get('snapshots', {}).get(snap_id)
|
|
||||||
if entry and 'data' in entry:
|
|
||||||
return entry['data']
|
|
||||||
# Old format
|
|
||||||
entry = tree_on_disk.get('nodes', {}).get(snap_id)
|
|
||||||
if entry and 'data' in entry:
|
|
||||||
return entry['data']
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to load snapshot %s from disk: %s", snap_id, e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _restore_snapshot(data, snap_id, timeline, file_path, state):
|
|
||||||
"""Restore a snapshot as the current version (full replace)."""
|
|
||||||
snap_data = _load_snapshot_data(snap_id, timeline, state, file_path)
|
|
||||||
if not snap_data:
|
|
||||||
ui.notify('Snapshot data not available', type='negative')
|
|
||||||
return
|
|
||||||
|
|
||||||
node_data = json.loads(json.dumps(snap_data))
|
|
||||||
|
|
||||||
# Preserve history tree
|
|
||||||
preserved_tree = data.get(KEY_HISTORY_TREE)
|
|
||||||
preserved_backup = data.get('history_tree_backup')
|
|
||||||
data.clear()
|
|
||||||
data.update(node_data)
|
|
||||||
if preserved_tree is not None:
|
|
||||||
data[KEY_HISTORY_TREE] = preserved_tree
|
|
||||||
if preserved_backup is not None:
|
|
||||||
data['history_tree_backup'] = preserved_backup
|
|
||||||
|
|
||||||
timeline.current_id = snap_id
|
|
||||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
|
||||||
|
|
||||||
snapshot = json.loads(json.dumps(data))
|
|
||||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
|
||||||
if state.db_enabled and state.current_project and state.db:
|
|
||||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
|
||||||
|
|
||||||
note = timeline.snapshots.get(snap_id, {}).get('note', 'Snapshot')
|
|
||||||
label = f"{note} ({snap_id[:4]})"
|
|
||||||
state.restored_indicator = label
|
|
||||||
ui.notify('Restored!', type='positive')
|
|
||||||
|
|
||||||
|
|
||||||
def _render_preview_fields(item_data: dict):
|
|
||||||
"""Render read-only preview of prompts, settings, LoRAs."""
|
|
||||||
with ui.grid(columns=2).classes('w-full'):
|
|
||||||
ui.textarea('General Positive',
|
|
||||||
value=item_data.get('general_prompt', '')).props('readonly outlined rows=3')
|
|
||||||
ui.textarea('General Negative',
|
|
||||||
value=item_data.get('general_negative', '')).props('readonly outlined rows=3')
|
|
||||||
val_sp = item_data.get('current_prompt', '') or item_data.get('prompt', '')
|
|
||||||
ui.textarea('Specific Positive',
|
|
||||||
value=val_sp).props('readonly outlined rows=3')
|
|
||||||
ui.textarea('Specific Negative',
|
|
||||||
value=item_data.get('negative', '')).props('readonly outlined rows=3')
|
|
||||||
|
|
||||||
with ui.row().classes('w-full q-gutter-md'):
|
|
||||||
ui.input('Camera', value=str(item_data.get('camera', 'static'))).props('readonly outlined')
|
|
||||||
ui.input('Seed', value=str(item_data.get('seed', '-1'))).props('readonly outlined')
|
|
||||||
|
|
||||||
with ui.expansion('LoRA Configuration'):
|
|
||||||
with ui.row().classes('w-full q-gutter-md'):
|
|
||||||
for lora_idx in range(1, 4):
|
|
||||||
for tier, tier_label in [('high', 'High'), ('low', 'Low')]:
|
|
||||||
lora_name = item_data.get(f'lora {lora_idx} {tier}', '')
|
|
||||||
lora_str = item_data.get(f'lora {lora_idx} {tier} strength', 1.0)
|
|
||||||
ui.input(f'L{lora_idx} {tier_label}',
|
|
||||||
value=str(lora_name)).props('readonly outlined dense')
|
|
||||||
ui.number(f'L{lora_idx} {tier_label} Str',
|
|
||||||
value=float(lora_str)).props('readonly outlined dense').style('max-width: 80px')
|
|
||||||
|
|
||||||
vace_keys = ['frame_to_skip', 'vace schedule', 'video file path']
|
|
||||||
if any(k in item_data for k in vace_keys):
|
|
||||||
with ui.expansion('VACE / I2V Settings'):
|
|
||||||
with ui.row().classes('w-full q-gutter-md'):
|
|
||||||
ui.input('Skip Frames',
|
|
||||||
value=str(item_data.get('frame_to_skip', ''))).props('readonly outlined')
|
|
||||||
ui.input('Schedule',
|
|
||||||
value=str(item_data.get('vace schedule', ''))).props('readonly outlined')
|
|
||||||
ui.input('Video Path',
|
|
||||||
value=str(item_data.get('video file path', ''))).props('readonly outlined')
|
|
||||||
|
|
||||||
resolutions = item_data.get('resolutions')
|
|
||||||
if isinstance(resolutions, list) and resolutions:
|
|
||||||
with ui.expansion('Resolutions'):
|
|
||||||
with ui.grid(columns=4).classes('w-full'):
|
|
||||||
for i, entry in enumerate(resolutions):
|
|
||||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
|
||||||
w, h = entry[0], entry[1]
|
|
||||||
seed = entry[2] if len(entry) >= 3 else 0
|
|
||||||
ui.input(f'#{i} W', value=str(w)).props('readonly outlined dense')
|
|
||||||
ui.input(f'#{i} H', value=str(h)).props('readonly outlined dense')
|
|
||||||
ui.input(f'#{i} Seed', value=str(seed)).props('readonly outlined dense')
|
|
||||||
ui.label('') # grid spacer for 4th column
|
|
||||||
|
|
||||||
known_keys = {
|
|
||||||
'sequence_number', 'general_prompt', 'general_negative', 'current_prompt', 'prompt',
|
|
||||||
'negative', 'camera', 'seed', 'resolutions',
|
|
||||||
'frame_to_skip', 'vace schedule', 'video file path', 'middle frame path', 'end frame path', 'start frame path',
|
|
||||||
'logic index',
|
|
||||||
}
|
|
||||||
# also skip lora keys
|
|
||||||
custom_keys = [
|
|
||||||
k for k in item_data
|
|
||||||
if k not in known_keys and not k.startswith('lora ')
|
|
||||||
]
|
|
||||||
if custom_keys:
|
|
||||||
with ui.expansion('Custom Fields'):
|
|
||||||
with ui.grid(columns=2).classes('w-full'):
|
|
||||||
for k in custom_keys:
|
|
||||||
ui.input(k, value=str(item_data[k])).props('readonly outlined dense')
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate(val, max_len=60):
|
|
||||||
"""Truncate a value for display."""
|
|
||||||
s = str(val) if val is not None else ''
|
|
||||||
return (s[:max_len] + '...') if len(s) > max_len else s
|
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import json
|
||||||
|
from history_tree import HistoryTree
|
||||||
|
from utils import save_json
|
||||||
|
from streamlit_agraph import agraph, Node, Edge, Config
|
||||||
|
|
||||||
|
def render_timeline_wip(data, file_path):
|
||||||
|
tree_data = data.get("history_tree", {})
|
||||||
|
if not tree_data:
|
||||||
|
st.info("No history timeline exists.")
|
||||||
|
return
|
||||||
|
|
||||||
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
|
# --- 1. BUILD GRAPH ---
|
||||||
|
nodes = []
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
sorted_nodes = sorted(htree.nodes.values(), key=lambda x: x["timestamp"])
|
||||||
|
|
||||||
|
for n in sorted_nodes:
|
||||||
|
nid = n["id"]
|
||||||
|
note = n.get('note', 'Step')
|
||||||
|
short_note = (note[:15] + '..') if len(note) > 15 else note
|
||||||
|
|
||||||
|
color = "#ffffff"
|
||||||
|
border = "#666666"
|
||||||
|
|
||||||
|
if nid == htree.head_id:
|
||||||
|
color = "#fff6cd"
|
||||||
|
border = "#eebb00"
|
||||||
|
|
||||||
|
if nid in htree.branches.values():
|
||||||
|
if color == "#ffffff":
|
||||||
|
color = "#e6ffe6"
|
||||||
|
border = "#44aa44"
|
||||||
|
|
||||||
|
nodes.append(Node(
|
||||||
|
id=nid,
|
||||||
|
label=f"{short_note}\n({nid[:4]})",
|
||||||
|
size=25,
|
||||||
|
shape="box",
|
||||||
|
color=color,
|
||||||
|
borderWidth=1,
|
||||||
|
borderColor=border,
|
||||||
|
font={'color': 'black', 'face': 'Arial', 'size': 14}
|
||||||
|
))
|
||||||
|
|
||||||
|
if n["parent"] and n["parent"] in htree.nodes:
|
||||||
|
edges.append(Edge(
|
||||||
|
source=n["parent"],
|
||||||
|
target=nid,
|
||||||
|
color="#aaaaaa",
|
||||||
|
type="STRAIGHT"
|
||||||
|
))
|
||||||
|
|
||||||
|
config = Config(
|
||||||
|
width="100%",
|
||||||
|
height="400px",
|
||||||
|
directed=True,
|
||||||
|
physics=False,
|
||||||
|
hierarchical=True,
|
||||||
|
layout={
|
||||||
|
"hierarchical": {
|
||||||
|
"enabled": True,
|
||||||
|
"levelSeparation": 150,
|
||||||
|
"nodeSpacing": 100,
|
||||||
|
"treeSpacing": 100,
|
||||||
|
"direction": "LR",
|
||||||
|
"sortMethod": "directed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
st.subheader("✨ Interactive Timeline")
|
||||||
|
st.caption("Click a node to view its settings below.")
|
||||||
|
|
||||||
|
# --- FIX: REMOVED 'key' ARGUMENT ---
|
||||||
|
selected_id = agraph(nodes=nodes, edges=edges, config=config)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# --- 2. DETERMINE TARGET ---
|
||||||
|
target_node_id = selected_id if selected_id else htree.head_id
|
||||||
|
|
||||||
|
if target_node_id and target_node_id in htree.nodes:
|
||||||
|
selected_node = htree.nodes[target_node_id]
|
||||||
|
node_data = selected_node["data"]
|
||||||
|
|
||||||
|
# Header
|
||||||
|
c_h1, c_h2 = st.columns([3, 1])
|
||||||
|
c_h1.markdown(f"### 📄 Previewing: {selected_node.get('note', 'Step')}")
|
||||||
|
c_h1.caption(f"ID: {target_node_id}")
|
||||||
|
|
||||||
|
# Restore Button
|
||||||
|
with c_h2:
|
||||||
|
st.write(""); st.write("")
|
||||||
|
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
|
||||||
|
data.update(node_data)
|
||||||
|
htree.head_id = target_node_id
|
||||||
|
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
save_json(file_path, data)
|
||||||
|
|
||||||
|
st.session_state.ui_reset_token += 1
|
||||||
|
label = f"{selected_node.get('note')} ({target_node_id[:4]})"
|
||||||
|
st.session_state.restored_indicator = label
|
||||||
|
|
||||||
|
st.toast(f"Restored {target_node_id}!", icon="🔄")
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# --- 3. PREVIEW LOGIC (BATCH VS SINGLE) ---
|
||||||
|
|
||||||
|
# Helper to render one set of inputs
|
||||||
|
def render_preview_fields(item_data, prefix):
|
||||||
|
# A. Prompts
|
||||||
|
p_col1, p_col2 = st.columns(2)
|
||||||
|
with p_col1:
|
||||||
|
val_gp = item_data.get("general_prompt", "")
|
||||||
|
st.text_area("General Positive", value=val_gp, height=80, disabled=True, key=f"{prefix}_gp")
|
||||||
|
|
||||||
|
val_sp = item_data.get("current_prompt", "") or item_data.get("prompt", "")
|
||||||
|
st.text_area("Specific Positive", value=val_sp, height=80, disabled=True, key=f"{prefix}_sp")
|
||||||
|
with p_col2:
|
||||||
|
val_gn = item_data.get("general_negative", "")
|
||||||
|
st.text_area("General Negative", value=val_gn, height=80, disabled=True, key=f"{prefix}_gn")
|
||||||
|
|
||||||
|
val_sn = item_data.get("negative", "")
|
||||||
|
st.text_area("Specific Negative", value=val_sn, height=80, disabled=True, key=f"{prefix}_sn")
|
||||||
|
|
||||||
|
# B. Settings
|
||||||
|
s_col1, s_col2, s_col3 = st.columns(3)
|
||||||
|
s_col1.text_input("Camera", value=str(item_data.get("camera", "static")), disabled=True, key=f"{prefix}_cam")
|
||||||
|
s_col2.text_input("FLF", value=str(item_data.get("flf", "0.0")), disabled=True, key=f"{prefix}_flf")
|
||||||
|
s_col3.text_input("Seed", value=str(item_data.get("seed", "-1")), disabled=True, key=f"{prefix}_seed")
|
||||||
|
|
||||||
|
# C. LoRAs
|
||||||
|
with st.expander("💊 LoRA Configuration", expanded=False):
|
||||||
|
l1, l2, l3 = st.columns(3)
|
||||||
|
with l1:
|
||||||
|
st.text_input("L1 Name", value=item_data.get("lora 1 high", ""), disabled=True, key=f"{prefix}_l1h")
|
||||||
|
st.text_input("L1 Str", value=str(item_data.get("lora 1 low", "")), disabled=True, key=f"{prefix}_l1l")
|
||||||
|
with l2:
|
||||||
|
st.text_input("L2 Name", value=item_data.get("lora 2 high", ""), disabled=True, key=f"{prefix}_l2h")
|
||||||
|
st.text_input("L2 Str", value=str(item_data.get("lora 2 low", "")), disabled=True, key=f"{prefix}_l2l")
|
||||||
|
with l3:
|
||||||
|
st.text_input("L3 Name", value=item_data.get("lora 3 high", ""), disabled=True, key=f"{prefix}_l3h")
|
||||||
|
st.text_input("L3 Str", value=str(item_data.get("lora 3 low", "")), disabled=True, key=f"{prefix}_l3l")
|
||||||
|
|
||||||
|
# D. VACE
|
||||||
|
vace_keys = ["frame_to_skip", "vace schedule", "video file path"]
|
||||||
|
has_vace = any(k in item_data for k in vace_keys)
|
||||||
|
if has_vace:
|
||||||
|
with st.expander("🎞️ VACE / I2V Settings", expanded=False):
|
||||||
|
v1, v2, v3 = st.columns(3)
|
||||||
|
v1.text_input("Skip Frames", value=str(item_data.get("frame_to_skip", "")), disabled=True, key=f"{prefix}_fts")
|
||||||
|
v2.text_input("Schedule", value=str(item_data.get("vace schedule", "")), disabled=True, key=f"{prefix}_vsc")
|
||||||
|
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
|
||||||
|
|
||||||
|
# --- DETECT BATCH VS SINGLE ---
|
||||||
|
batch_list = node_data.get("batch_data", [])
|
||||||
|
|
||||||
|
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
|
||||||
|
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")
|
||||||
|
|
||||||
|
for i, seq_data in enumerate(batch_list):
|
||||||
|
seq_num = seq_data.get("sequence_number", i+1)
|
||||||
|
with st.expander(f"🎬 Sequence #{seq_num}", expanded=(i==0)):
|
||||||
|
# Unique prefix for every sequence in every node
|
||||||
|
prefix = f"p_{target_node_id}_s{i}"
|
||||||
|
render_preview_fields(seq_data, prefix)
|
||||||
|
else:
|
||||||
|
# Single File Preview
|
||||||
|
prefix = f"p_{target_node_id}_single"
|
||||||
|
render_preview_fields(node_data, prefix)
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to sys.path so tests can import project modules
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
[pytest]
|
|
||||||
@@ -1,369 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from db import ProjectDB
|
|
||||||
from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db(tmp_path):
|
|
||||||
"""Create a fresh ProjectDB in a temp directory."""
|
|
||||||
db_path = tmp_path / "test.db"
|
|
||||||
pdb = ProjectDB(db_path)
|
|
||||||
yield pdb
|
|
||||||
pdb.close()
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Projects CRUD
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestProjects:
|
|
||||||
def test_create_and_get(self, db):
|
|
||||||
pid = db.create_project("proj1", "/some/path", "A test project")
|
|
||||||
assert pid > 0
|
|
||||||
proj = db.get_project("proj1")
|
|
||||||
assert proj is not None
|
|
||||||
assert proj["name"] == "proj1"
|
|
||||||
assert proj["folder_path"] == "/some/path"
|
|
||||||
assert proj["description"] == "A test project"
|
|
||||||
|
|
||||||
def test_list_projects(self, db):
|
|
||||||
db.create_project("beta", "/b")
|
|
||||||
db.create_project("alpha", "/a")
|
|
||||||
projects = db.list_projects()
|
|
||||||
assert len(projects) == 2
|
|
||||||
assert projects[0]["name"] == "alpha"
|
|
||||||
assert projects[1]["name"] == "beta"
|
|
||||||
|
|
||||||
def test_get_nonexistent(self, db):
|
|
||||||
assert db.get_project("nope") is None
|
|
||||||
|
|
||||||
def test_delete_project(self, db):
|
|
||||||
db.create_project("to_delete", "/x")
|
|
||||||
assert db.delete_project("to_delete") is True
|
|
||||||
assert db.get_project("to_delete") is None
|
|
||||||
|
|
||||||
def test_delete_nonexistent(self, db):
|
|
||||||
assert db.delete_project("nope") is False
|
|
||||||
|
|
||||||
def test_unique_name_constraint(self, db):
|
|
||||||
db.create_project("dup", "/a")
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
db.create_project("dup", "/b")
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Data files
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestDataFiles:
|
|
||||||
def test_create_and_list(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"})
|
|
||||||
assert df_id > 0
|
|
||||||
files = db.list_data_files(pid)
|
|
||||||
assert len(files) == 1
|
|
||||||
assert files[0]["name"] == "batch_i2v"
|
|
||||||
assert files[0]["data_type"] == "i2v"
|
|
||||||
|
|
||||||
def test_get_data_file(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"})
|
|
||||||
df = db.get_data_file(pid, "batch_i2v")
|
|
||||||
assert df is not None
|
|
||||||
assert df["top_level"] == {"key": "value"}
|
|
||||||
|
|
||||||
def test_get_data_file_by_names(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
db.create_data_file(pid, "batch_i2v", "i2v")
|
|
||||||
df = db.get_data_file_by_names("p1", "batch_i2v")
|
|
||||||
assert df is not None
|
|
||||||
assert df["name"] == "batch_i2v"
|
|
||||||
|
|
||||||
def test_get_nonexistent_data_file(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
assert db.get_data_file(pid, "nope") is None
|
|
||||||
|
|
||||||
def test_unique_constraint(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
db.create_data_file(pid, "batch_i2v", "i2v")
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
db.create_data_file(pid, "batch_i2v", "vace")
|
|
||||||
|
|
||||||
def test_cascade_delete(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "hello"})
|
|
||||||
db.save_history_tree(df_id, {"nodes": {}})
|
|
||||||
db.delete_project("p1")
|
|
||||||
assert db.get_data_file(pid, "batch_i2v") is None
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Sequences
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestSequences:
|
|
||||||
def test_upsert_and_get(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42})
|
|
||||||
data = db.get_sequence(df_id, 1)
|
|
||||||
assert data == {"prompt": "hello", "seed": 42}
|
|
||||||
|
|
||||||
def test_upsert_updates_existing(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "v1"})
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "v2"})
|
|
||||||
data = db.get_sequence(df_id, 1)
|
|
||||||
assert data["prompt"] == "v2"
|
|
||||||
|
|
||||||
def test_list_sequences(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 3, {"a": 1})
|
|
||||||
db.upsert_sequence(df_id, 1, {"b": 2})
|
|
||||||
db.upsert_sequence(df_id, 2, {"c": 3})
|
|
||||||
seqs = db.list_sequences(df_id)
|
|
||||||
assert seqs == [1, 2, 3]
|
|
||||||
|
|
||||||
def test_get_nonexistent_sequence(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
assert db.get_sequence(df_id, 99) is None
|
|
||||||
|
|
||||||
def test_get_sequence_keys(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {
|
|
||||||
"prompt": "hello",
|
|
||||||
"seed": 42,
|
|
||||||
"cfg": 1.5,
|
|
||||||
"flag": True,
|
|
||||||
})
|
|
||||||
keys, types = db.get_sequence_keys(df_id, 1)
|
|
||||||
assert "prompt" in keys
|
|
||||||
assert "seed" in keys
|
|
||||||
idx_prompt = keys.index("prompt")
|
|
||||||
idx_seed = keys.index("seed")
|
|
||||||
idx_cfg = keys.index("cfg")
|
|
||||||
idx_flag = keys.index("flag")
|
|
||||||
assert types[idx_prompt] == "STRING"
|
|
||||||
assert types[idx_seed] == "INT"
|
|
||||||
assert types[idx_cfg] == "FLOAT"
|
|
||||||
assert types[idx_flag] == "STRING" # bools -> STRING
|
|
||||||
|
|
||||||
def test_get_sequence_keys_nonexistent(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
keys, types = db.get_sequence_keys(df_id, 99)
|
|
||||||
assert keys == []
|
|
||||||
assert types == []
|
|
||||||
|
|
||||||
def test_delete_sequences_for_file(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
|
||||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
|
||||||
db.delete_sequences_for_file(df_id)
|
|
||||||
assert db.list_sequences(df_id) == []
|
|
||||||
|
|
||||||
def test_count_sequences(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
assert db.count_sequences(df_id) == 0
|
|
||||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
|
||||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
|
||||||
db.upsert_sequence(df_id, 3, {"c": 3})
|
|
||||||
assert db.count_sequences(df_id) == 3
|
|
||||||
|
|
||||||
def test_query_total_sequences(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
|
||||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
|
||||||
assert db.query_total_sequences("p1", "batch") == 2
|
|
||||||
|
|
||||||
def test_query_total_sequences_nonexistent(self, db):
|
|
||||||
assert db.query_total_sequences("nope", "nope") == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# History trees
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestHistoryTrees:
|
|
||||||
def test_save_and_get(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"}
|
|
||||||
db.save_history_tree(df_id, tree)
|
|
||||||
result = db.get_history_tree(df_id)
|
|
||||||
assert result == tree
|
|
||||||
|
|
||||||
def test_upsert_updates(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 1})
|
|
||||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 2})
|
|
||||||
result = db.get_history_tree(df_id)
|
|
||||||
assert result == {"snapshots": {}, "v": 2}
|
|
||||||
|
|
||||||
def test_get_nonexistent(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
assert db.get_history_tree(df_id) is None
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Import
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestImport:
|
|
||||||
def test_import_json_file(self, db, tmp_path):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "batch_prompt_i2v.json"
|
|
||||||
data = {
|
|
||||||
KEY_BATCH_DATA: [
|
|
||||||
{"sequence_number": 1, "prompt": "hello", "seed": 42},
|
|
||||||
{"sequence_number": 2, "prompt": "world", "seed": 99},
|
|
||||||
],
|
|
||||||
KEY_HISTORY_TREE: {"nodes": {}, "head_id": None},
|
|
||||||
}
|
|
||||||
json_path.write_text(json.dumps(data))
|
|
||||||
|
|
||||||
df_id = db.import_json_file(pid, json_path, "i2v")
|
|
||||||
assert df_id > 0
|
|
||||||
|
|
||||||
seqs = db.list_sequences(df_id)
|
|
||||||
assert seqs == [1, 2]
|
|
||||||
|
|
||||||
s1 = db.get_sequence(df_id, 1)
|
|
||||||
assert s1["prompt"] == "hello"
|
|
||||||
assert s1["seed"] == 42
|
|
||||||
|
|
||||||
tree = db.get_history_tree(df_id)
|
|
||||||
assert tree == {"nodes": {}, "head_id": None}
|
|
||||||
|
|
||||||
def test_import_file_name_from_stem(self, db, tmp_path):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "my_batch.json"
|
|
||||||
json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]}))
|
|
||||||
db.import_json_file(pid, json_path)
|
|
||||||
df = db.get_data_file(pid, "my_batch")
|
|
||||||
assert df is not None
|
|
||||||
|
|
||||||
def test_import_no_batch_data(self, db, tmp_path):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "simple.json"
|
|
||||||
json_path.write_text(json.dumps({"prompt": "flat file"}))
|
|
||||||
df_id = db.import_json_file(pid, json_path)
|
|
||||||
seqs = db.list_sequences(df_id)
|
|
||||||
assert seqs == []
|
|
||||||
|
|
||||||
def test_reimport_updates_existing(self, db, tmp_path):
|
|
||||||
"""Re-importing the same file should update data, not crash."""
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "batch.json"
|
|
||||||
|
|
||||||
# First import
|
|
||||||
data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]}
|
|
||||||
json_path.write_text(json.dumps(data_v1))
|
|
||||||
df_id_1 = db.import_json_file(pid, json_path, "i2v")
|
|
||||||
|
|
||||||
# Second import (same file, updated data)
|
|
||||||
data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]}
|
|
||||||
json_path.write_text(json.dumps(data_v2))
|
|
||||||
df_id_2 = db.import_json_file(pid, json_path, "vace")
|
|
||||||
|
|
||||||
# Should reuse the same data_file row
|
|
||||||
assert df_id_1 == df_id_2
|
|
||||||
# Data type should be updated
|
|
||||||
df = db.get_data_file(pid, "batch")
|
|
||||||
assert df["data_type"] == "vace"
|
|
||||||
# Sequences should reflect v2
|
|
||||||
seqs = db.list_sequences(df_id_2)
|
|
||||||
assert seqs == [1, 2]
|
|
||||||
s1 = db.get_sequence(df_id_2, 1)
|
|
||||||
assert s1["prompt"] == "v2"
|
|
||||||
|
|
||||||
def test_import_skips_non_dict_batch_items(self, db, tmp_path):
|
|
||||||
"""Non-dict elements in batch_data should be silently skipped, not crash."""
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "mixed.json"
|
|
||||||
data = {KEY_BATCH_DATA: [
|
|
||||||
{"sequence_number": 1, "prompt": "valid"},
|
|
||||||
"not a dict",
|
|
||||||
42,
|
|
||||||
None,
|
|
||||||
{"sequence_number": 3, "prompt": "also valid"},
|
|
||||||
]}
|
|
||||||
json_path.write_text(json.dumps(data))
|
|
||||||
df_id = db.import_json_file(pid, json_path)
|
|
||||||
|
|
||||||
seqs = db.list_sequences(df_id)
|
|
||||||
assert seqs == [1, 3]
|
|
||||||
|
|
||||||
def test_import_atomic_on_error(self, db, tmp_path):
|
|
||||||
"""If import fails partway, no partial data should be committed."""
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
json_path = tmp_path / "batch.json"
|
|
||||||
data = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]}
|
|
||||||
json_path.write_text(json.dumps(data))
|
|
||||||
db.import_json_file(pid, json_path)
|
|
||||||
|
|
||||||
# Now try to import with bad data that will cause an error
|
|
||||||
# (overwrite the file with invalid sequence_number that causes int() to fail)
|
|
||||||
bad_data = {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]}
|
|
||||||
json_path.write_text(json.dumps(bad_data))
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
db.import_json_file(pid, json_path)
|
|
||||||
|
|
||||||
# Original data should still be intact (rollback worked)
|
|
||||||
df = db.get_data_file(pid, "batch")
|
|
||||||
assert df is not None
|
|
||||||
s1 = db.get_sequence(df["id"], 1)
|
|
||||||
assert s1["prompt"] == "hello"
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Query helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
class TestQueryHelpers:
|
|
||||||
def test_query_sequence_data(self, db):
|
|
||||||
pid = db.create_project("myproject", "/mp")
|
|
||||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
|
|
||||||
result = db.query_sequence_data("myproject", "batch_i2v", 1)
|
|
||||||
assert result == {"prompt": "test", "seed": 7}
|
|
||||||
|
|
||||||
def test_query_sequence_data_not_found(self, db):
|
|
||||||
assert db.query_sequence_data("nope", "nope", 1) is None
|
|
||||||
|
|
||||||
def test_query_sequence_keys(self, db):
|
|
||||||
pid = db.create_project("myproject", "/mp")
|
|
||||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
|
||||||
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
|
|
||||||
keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1)
|
|
||||||
assert "prompt" in keys
|
|
||||||
assert "seed" in keys
|
|
||||||
|
|
||||||
def test_list_project_files(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
db.create_data_file(pid, "file_a", "i2v")
|
|
||||||
db.create_data_file(pid, "file_b", "vace")
|
|
||||||
files = db.list_project_files("p1")
|
|
||||||
assert len(files) == 2
|
|
||||||
|
|
||||||
def test_list_project_sequences(self, db):
|
|
||||||
pid = db.create_project("p1", "/p1")
|
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
|
||||||
db.upsert_sequence(df_id, 1, {})
|
|
||||||
db.upsert_sequence(df_id, 2, {})
|
|
||||||
seqs = db.list_project_sequences("p1", "batch")
|
|
||||||
assert seqs == [1, 2]
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from history_tree import HistoryTree
|
|
||||||
|
|
||||||
|
|
||||||
def test_commit_creates_node_with_correct_parent():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
id1 = tree.commit({"a": 1}, note="first")
|
|
||||||
id2 = tree.commit({"b": 2}, note="second")
|
|
||||||
|
|
||||||
assert tree.nodes[id1]["parent"] is None
|
|
||||||
assert tree.nodes[id2]["parent"] == id1
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkout_returns_correct_data():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
id1 = tree.commit({"val": 42}, note="snap")
|
|
||||||
result = tree.checkout(id1)
|
|
||||||
assert result == {"val": 42}
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkout_nonexistent_returns_none():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
assert tree.checkout("nonexistent") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_cycle_detection_raises():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
id1 = tree.commit({"a": 1})
|
|
||||||
# Manually introduce a cycle
|
|
||||||
tree.nodes[id1]["parent"] = id1
|
|
||||||
with pytest.raises(ValueError, match="Cycle detected"):
|
|
||||||
tree.commit({"b": 2})
|
|
||||||
|
|
||||||
|
|
||||||
def test_branch_creation_on_detached_head():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
id1 = tree.commit({"a": 1})
|
|
||||||
id2 = tree.commit({"b": 2})
|
|
||||||
# Detach head by checking out a non-tip node
|
|
||||||
tree.checkout(id1)
|
|
||||||
# head_id is now id1, which is no longer a branch tip (main points to id2)
|
|
||||||
id3 = tree.commit({"c": 3})
|
|
||||||
# A new branch should have been created
|
|
||||||
assert len(tree.branches) == 2
|
|
||||||
assert tree.nodes[id3]["parent"] == id1
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_migration():
|
|
||||||
legacy = {
|
|
||||||
"prompt_history": [
|
|
||||||
{"note": "Entry A", "seed": 1},
|
|
||||||
{"note": "Entry B", "seed": 2},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
tree = HistoryTree(legacy)
|
|
||||||
assert len(tree.nodes) == 2
|
|
||||||
assert tree.head_id is not None
|
|
||||||
assert tree.branches["main"] == tree.head_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_dict_roundtrip():
|
|
||||||
tree = HistoryTree({})
|
|
||||||
tree.commit({"x": 1}, note="test")
|
|
||||||
d = tree.to_dict()
|
|
||||||
tree2 = HistoryTree(d)
|
|
||||||
assert tree2.head_id == tree.head_id
|
|
||||||
assert tree2.nodes == tree.nodes
|
|
||||||
@@ -1,508 +0,0 @@
|
|||||||
import json
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from project_loader import (
|
|
||||||
ProjectLoaderDynamic,
|
|
||||||
_fetch_json,
|
|
||||||
_fetch_data,
|
|
||||||
_fetch_keys,
|
|
||||||
MAX_DYNAMIC_OUTPUTS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_urlopen(data: dict):
|
|
||||||
"""Create a mock context manager for urllib.request.urlopen."""
|
|
||||||
response = MagicMock()
|
|
||||||
response.read.return_value = json.dumps(data).encode()
|
|
||||||
response.__enter__ = lambda s: s
|
|
||||||
response.__exit__ = MagicMock(return_value=False)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class TestFetchHelpers:
|
|
||||||
def test_fetch_json_success(self):
|
|
||||||
data = {"key": "value"}
|
|
||||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)):
|
|
||||||
result = _fetch_json("http://example.com/api")
|
|
||||||
assert result == data
|
|
||||||
|
|
||||||
def test_fetch_json_network_error(self):
|
|
||||||
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
|
||||||
result = _fetch_json("http://example.com/api")
|
|
||||||
assert result["error"] == "network_error"
|
|
||||||
assert "connection refused" in result["message"]
|
|
||||||
|
|
||||||
def test_fetch_json_http_error(self):
|
|
||||||
import urllib.error
|
|
||||||
err = urllib.error.HTTPError(
|
|
||||||
"http://example.com/api", 404, "Not Found", {},
|
|
||||||
BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
|
|
||||||
)
|
|
||||||
with patch("project_loader.urllib.request.urlopen", side_effect=err):
|
|
||||||
result = _fetch_json("http://example.com/api")
|
|
||||||
assert result["error"] == "http_error"
|
|
||||||
assert result["status"] == 404
|
|
||||||
assert "not found" in result["message"].lower()
|
|
||||||
|
|
||||||
def test_fetch_data_builds_url(self):
|
|
||||||
data = {"prompt": "hello"}
|
|
||||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
|
||||||
result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1)
|
|
||||||
assert result == data
|
|
||||||
called_url = mock.call_args[0][0]
|
|
||||||
assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url
|
|
||||||
|
|
||||||
def test_fetch_keys_builds_url(self):
|
|
||||||
data = {"keys": ["prompt"], "types": ["STRING"]}
|
|
||||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
|
||||||
result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1)
|
|
||||||
assert result == data
|
|
||||||
called_url = mock.call_args[0][0]
|
|
||||||
assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url
|
|
||||||
|
|
||||||
def test_fetch_data_strips_trailing_slash(self):
|
|
||||||
data = {"prompt": "hello"}
|
|
||||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
|
||||||
_fetch_data("http://localhost:8080/", "proj1", "file1", 1)
|
|
||||||
called_url = mock.call_args[0][0]
|
|
||||||
assert "//api" not in called_url
|
|
||||||
|
|
||||||
def test_fetch_data_encodes_special_chars(self):
|
|
||||||
"""Project/file names with spaces or special chars should be percent-encoded."""
|
|
||||||
data = {"prompt": "hello"}
|
|
||||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
|
||||||
_fetch_data("http://localhost:8080", "my project", "batch file", 1)
|
|
||||||
called_url = mock.call_args[0][0]
|
|
||||||
assert "my%20project" in called_url
|
|
||||||
assert "batch%20file" in called_url
|
|
||||||
assert " " not in called_url.split("?")[0] # no raw spaces in path
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderDynamic:
|
|
||||||
def _keys_meta(self, total=5):
|
|
||||||
return {"keys": [], "types": [], "total_sequences": total}
|
|
||||||
|
|
||||||
def test_load_dynamic_with_keys(self):
|
|
||||||
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys="prompt,seed,cfg"
|
|
||||||
)
|
|
||||||
assert result[0] == 5 # total_sequences
|
|
||||||
assert result[1] == "hello"
|
|
||||||
assert result[2] == 42
|
|
||||||
assert result[3] == 1.5
|
|
||||||
assert len(result) == MAX_DYNAMIC_OUTPUTS + 1
|
|
||||||
|
|
||||||
def test_load_dynamic_with_json_encoded_keys(self):
|
|
||||||
"""JSON-encoded output_keys should be parsed correctly."""
|
|
||||||
import json as _json
|
|
||||||
data = {"my,key": "comma_val", "normal": "ok"}
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
keys_json = _json.dumps(["my,key", "normal"])
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys=keys_json
|
|
||||||
)
|
|
||||||
assert result[1] == "comma_val"
|
|
||||||
assert result[2] == "ok"
|
|
||||||
|
|
||||||
def test_load_dynamic_type_coercion(self):
|
|
||||||
"""output_types should coerce values to declared types."""
|
|
||||||
import json as _json
|
|
||||||
data = {"seed": "42", "cfg": "1.5", "prompt": "hello"}
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
keys_json = _json.dumps(["seed", "cfg", "prompt"])
|
|
||||||
types_json = _json.dumps(["INT", "FLOAT", "STRING"])
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys=keys_json, output_types=types_json
|
|
||||||
)
|
|
||||||
assert result[1] == 42 # string "42" coerced to int
|
|
||||||
assert result[2] == 1.5 # string "1.5" coerced to float
|
|
||||||
assert result[3] == "hello" # string stays string
|
|
||||||
|
|
||||||
def test_load_dynamic_empty_keys(self):
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys=""
|
|
||||||
)
|
|
||||||
# Slot 0 is total_sequences (INT), rest are empty strings
|
|
||||||
assert result[0] == 5
|
|
||||||
assert all(v == "" for v in result[1:])
|
|
||||||
|
|
||||||
def test_load_dynamic_missing_key(self):
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys="nonexistent"
|
|
||||||
)
|
|
||||||
assert result[1] == ""
|
|
||||||
|
|
||||||
def test_load_dynamic_bool_becomes_string(self):
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys="flag"
|
|
||||||
)
|
|
||||||
assert result[1] == "true"
|
|
||||||
|
|
||||||
def test_load_dynamic_returns_total_sequences(self):
|
|
||||||
"""total_sequences should be the first output from keys metadata."""
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}):
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.load_dynamic(
|
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
|
||||||
output_keys=""
|
|
||||||
)
|
|
||||||
assert result[0] == 42
|
|
||||||
|
|
||||||
def test_load_dynamic_raises_on_network_error(self):
|
|
||||||
"""Network errors from _fetch_keys should raise RuntimeError."""
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
|
||||||
with patch("project_loader._fetch_keys", return_value=error_resp):
|
|
||||||
with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
|
|
||||||
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
|
|
||||||
def test_load_dynamic_raises_on_data_fetch_error(self):
|
|
||||||
"""Network errors from _fetch_data should raise RuntimeError."""
|
|
||||||
node = ProjectLoaderDynamic()
|
|
||||||
error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"}
|
|
||||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
|
||||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
|
||||||
with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
|
|
||||||
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
|
|
||||||
def test_input_types_has_manager_url(self):
|
|
||||||
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
|
||||||
assert "manager_url" in inputs["required"]
|
|
||||||
assert "project_name" in inputs["required"]
|
|
||||||
assert "file_name" in inputs["required"]
|
|
||||||
assert "sequence_number" in inputs["required"]
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
assert ProjectLoaderDynamic.CATEGORY == "JSON Manager/project"
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectSource:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
inputs = ProjectSource.INPUT_TYPES()
|
|
||||||
assert "manager_url" in inputs["required"]
|
|
||||||
assert "project_name" in inputs["required"]
|
|
||||||
assert "file_name" in inputs["required"]
|
|
||||||
assert "sequence_number" in inputs["required"]
|
|
||||||
assert "label" in inputs["required"]
|
|
||||||
|
|
||||||
def test_outputs_sequence_number(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
assert ProjectSource.RETURN_TYPES == ("INT", "STRING",)
|
|
||||||
assert ProjectSource.RETURN_NAMES == ("sequence_number", "file_name",)
|
|
||||||
|
|
||||||
def test_hold_config_returns_sequence_number(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
node = ProjectSource()
|
|
||||||
result = node.hold_config(
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=42,
|
|
||||||
label="my_source"
|
|
||||||
)
|
|
||||||
assert result == (42, "batch_i2v")
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import ProjectSource
|
|
||||||
assert ProjectSource.CATEGORY == "JSON Manager/project"
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectKey:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
inputs = ProjectKey.INPUT_TYPES()
|
|
||||||
assert "source_label" in inputs["required"]
|
|
||||||
assert "key_name" in inputs["required"]
|
|
||||||
assert "key_type" in inputs["required"]
|
|
||||||
|
|
||||||
def test_single_output(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
assert len(ProjectKey.RETURN_TYPES) == 1
|
|
||||||
assert len(ProjectKey.RETURN_NAMES) == 1
|
|
||||||
|
|
||||||
def test_fetch_key_string(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
data = {"prompt": "hello", "seed": 42}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="my_source",
|
|
||||||
key_name="prompt",
|
|
||||||
key_type="STRING",
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == ("hello",)
|
|
||||||
|
|
||||||
def test_fetch_key_int_coercion(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
data = {"seed": "42"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="my_source",
|
|
||||||
key_name="seed",
|
|
||||||
key_type="INT",
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (42,)
|
|
||||||
|
|
||||||
def test_fetch_key_float_coercion(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
data = {"cfg": "1.5"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="my_source",
|
|
||||||
key_name="cfg",
|
|
||||||
key_type="FLOAT",
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (1.5,)
|
|
||||||
|
|
||||||
def test_fetch_key_missing_key(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="my_source",
|
|
||||||
key_name="nonexistent",
|
|
||||||
key_type="STRING",
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == ("",)
|
|
||||||
|
|
||||||
def test_fetch_key_network_error_returns_default(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="my_source",
|
|
||||||
key_name="prompt",
|
|
||||||
key_type="STRING",
|
|
||||||
manager_url="http://localhost:8080",
|
|
||||||
project_name="proj1",
|
|
||||||
file_name="batch_i2v",
|
|
||||||
sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == ("",)
|
|
||||||
|
|
||||||
def test_fetch_key_error_returns_int_default(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
node = ProjectKey()
|
|
||||||
error_resp = {"error": "http_error", "status": 404, "message": "Not found"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
|
||||||
result = node.fetch_key(
|
|
||||||
source_label="s", key_name="seed", key_type="INT",
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (0,)
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import ProjectKey
|
|
||||||
assert ProjectKey.CATEGORY == "JSON Manager/project"
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectResolution:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
inputs = ProjectResolution.INPUT_TYPES()
|
|
||||||
assert "source_label" in inputs["required"]
|
|
||||||
assert "key_name" in inputs["required"]
|
|
||||||
assert "index" in inputs["required"]
|
|
||||||
assert inputs["required"]["index"][0] == "INT"
|
|
||||||
|
|
||||||
def test_three_outputs(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
assert ProjectResolution.RETURN_TYPES == ("INT", "INT", "INT")
|
|
||||||
assert ProjectResolution.RETURN_NAMES == ("width", "height", "seed")
|
|
||||||
|
|
||||||
def test_fetch_resolution_basic(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512, 0], [768, 1344, 12345], [1344, 768, 99]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=1,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (768, 1344, 12345)
|
|
||||||
|
|
||||||
def test_fetch_resolution_index_zero(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512, 42], [1024, 1024, 0]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512, 42)
|
|
||||||
|
|
||||||
def test_fetch_resolution_clamps_on_out_of_bounds(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512, 512, 0], [1024, 1024, 7]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=99,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (1024, 1024, 7) # last entry
|
|
||||||
|
|
||||||
def test_fetch_resolution_old_format_no_seed(self):
|
|
||||||
"""Old [w, h] entries without seed should return seed=0."""
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[576, 384], [960, 640]]}
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (576, 384, 0)
|
|
||||||
|
|
||||||
def test_fetch_resolution_missing_key_returns_defaults(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="nonexistent", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512, 0)
|
|
||||||
|
|
||||||
def test_fetch_resolution_network_error_returns_defaults(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
|
||||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512, 0)
|
|
||||||
|
|
||||||
def test_fetch_resolution_malformed_entry_returns_defaults(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
node = ProjectResolution()
|
|
||||||
data = {"resolutions": [[512]]} # single-element, not a valid pair
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.fetch_resolution(
|
|
||||||
source_label="src", key_name="resolutions", index=0,
|
|
||||||
manager_url="http://localhost:8080", project_name="p",
|
|
||||||
file_name="f", sequence_number=1,
|
|
||||||
)
|
|
||||||
assert result == (512, 512, 0)
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import ProjectResolution
|
|
||||||
assert ProjectResolution.CATEGORY == "JSON Manager/project"
|
|
||||||
|
|
||||||
|
|
||||||
class TestBinaryIndexDecoder:
|
|
||||||
def test_input_types(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
inputs = BinaryIndexDecoder.INPUT_TYPES()
|
|
||||||
assert "index" in inputs["required"]
|
|
||||||
assert inputs["required"]["index"][0] == "INT"
|
|
||||||
|
|
||||||
def test_three_boolean_outputs(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder.RETURN_TYPES == ("BOOLEAN", "BOOLEAN", "BOOLEAN")
|
|
||||||
assert BinaryIndexDecoder.RETURN_NAMES == ("flag_0", "flag_1", "flag_2")
|
|
||||||
|
|
||||||
def test_category(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder.CATEGORY == "JSON Manager/utils"
|
|
||||||
|
|
||||||
def test_index_0(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(0) == (False, False, False)
|
|
||||||
|
|
||||||
def test_index_1(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(1) == (True, False, False)
|
|
||||||
|
|
||||||
def test_index_2(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(2) == (False, True, False)
|
|
||||||
|
|
||||||
def test_index_3(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(3) == (True, True, False)
|
|
||||||
|
|
||||||
def test_index_4(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(4) == (False, False, True)
|
|
||||||
|
|
||||||
def test_index_5(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(5) == (True, False, True)
|
|
||||||
|
|
||||||
def test_index_6(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(6) == (False, True, True)
|
|
||||||
|
|
||||||
def test_index_7(self):
|
|
||||||
from project_loader import BinaryIndexDecoder
|
|
||||||
assert BinaryIndexDecoder().decode(7) == (True, True, True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeMappings:
|
|
||||||
def test_mappings_exist(self):
|
|
||||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
|
||||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert "BinaryIndexDecoder" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 5
|
|
||||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 5
|
|
||||||
@@ -1,159 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from snapshot_timeline import SnapshotTimeline, diff_snapshots
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_creates_snapshot():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
sid = tl.record({"batch_data": [{"seed": 42}]}, note="first")
|
|
||||||
assert sid in tl.snapshots
|
|
||||||
assert tl.current_id == sid
|
|
||||||
assert tl.snapshots[sid]["note"] == "first"
|
|
||||||
assert tl.snapshots[sid]["auto"] is False
|
|
||||||
assert tl.snapshots[sid]["seq_count"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_auto_flag():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
sid = tl.record({"batch_data": []}, note="auto save", auto=True)
|
|
||||||
assert tl.snapshots[sid]["auto"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_records():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
id1 = tl.record({"batch_data": [{"a": 1}]}, note="one")
|
|
||||||
id2 = tl.record({"batch_data": [{"b": 2}]}, note="two")
|
|
||||||
assert len(tl.snapshots) == 2
|
|
||||||
assert tl.current_id == id2
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_dict_roundtrip():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
tl.record({"batch_data": [{"x": 1}]}, note="test")
|
|
||||||
d = tl.to_dict()
|
|
||||||
tl2 = SnapshotTimeline(d)
|
|
||||||
assert tl2.current_id == tl.current_id
|
|
||||||
assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys())
|
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_from_history_tree():
|
|
||||||
"""Old HistoryTree format should be flattened into snapshots."""
|
|
||||||
old_data = {
|
|
||||||
"nodes": {
|
|
||||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}},
|
|
||||||
"bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}},
|
|
||||||
},
|
|
||||||
"branches": {"main": "bbb"},
|
|
||||||
"head_id": "bbb",
|
|
||||||
}
|
|
||||||
tl = SnapshotTimeline(old_data)
|
|
||||||
assert len(tl.snapshots) == 2
|
|
||||||
assert tl.current_id == "bbb"
|
|
||||||
assert tl.snapshots["aaa"]["note"] == "First"
|
|
||||||
assert tl.snapshots["bbb"]["note"] == "Second"
|
|
||||||
# Data should be preserved
|
|
||||||
assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}]
|
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_from_history_tree_no_data():
|
|
||||||
"""Slim tree nodes (no inline data) should still migrate."""
|
|
||||||
old_data = {
|
|
||||||
"nodes": {
|
|
||||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"},
|
|
||||||
},
|
|
||||||
"branches": {"main": "aaa"},
|
|
||||||
"head_id": "aaa",
|
|
||||||
}
|
|
||||||
tl = SnapshotTimeline(old_data)
|
|
||||||
assert len(tl.snapshots) == 1
|
|
||||||
assert tl.snapshots["aaa"]["seq_count"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_legacy_prompt_history():
|
|
||||||
legacy = {
|
|
||||||
"prompt_history": [
|
|
||||||
{"note": "A", "seed": 1},
|
|
||||||
{"note": "B", "seed": 2},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
tl = SnapshotTimeline(legacy)
|
|
||||||
assert len(tl.snapshots) == 2
|
|
||||||
assert tl.current_id is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_toggle_pin():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
sid = tl.record({"batch_data": []}, note="test")
|
|
||||||
assert tl.snapshots[sid]["pinned"] is False
|
|
||||||
result = tl.toggle_pin(sid)
|
|
||||||
assert result is True
|
|
||||||
assert tl.snapshots[sid]["pinned"] is True
|
|
||||||
result = tl.toggle_pin(sid)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_snapshot():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
id1 = tl.record({"batch_data": []}, note="one")
|
|
||||||
id2 = tl.record({"batch_data": []}, note="two")
|
|
||||||
tl.delete(id2)
|
|
||||||
assert id2 not in tl.snapshots
|
|
||||||
assert tl.current_id == id1
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_all_snapshots():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
sid = tl.record({"batch_data": []}, note="only")
|
|
||||||
tl.delete(sid)
|
|
||||||
assert len(tl.snapshots) == 0
|
|
||||||
assert tl.current_id is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_snapshots():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
tl.record({"batch_data": [{"a": 1}]}, note="test")
|
|
||||||
tl.strip_snapshots()
|
|
||||||
for snap in tl.snapshots.values():
|
|
||||||
assert "data" not in snap
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_snapshot_data():
|
|
||||||
tl = SnapshotTimeline({})
|
|
||||||
sid = tl.record({"batch_data": [{"x": 1}]}, note="test")
|
|
||||||
data = tl.get_snapshot_data(sid)
|
|
||||||
assert data == {"batch_data": [{"x": 1}]}
|
|
||||||
assert tl.get_snapshot_data("nonexistent") is None
|
|
||||||
|
|
||||||
|
|
||||||
# --- diff_snapshots tests ---
|
|
||||||
|
|
||||||
def test_diff_unchanged():
|
|
||||||
batch = [{"sequence_number": 1, "seed": 42}]
|
|
||||||
result = diff_snapshots(batch, batch)
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["status"] == "unchanged"
|
|
||||||
assert result[0]["changes"] == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_diff_changed():
|
|
||||||
old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}]
|
|
||||||
new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}]
|
|
||||||
result = diff_snapshots(old, new)
|
|
||||||
assert result[0]["status"] == "changed"
|
|
||||||
assert len(result[0]["changes"]) == 1
|
|
||||||
assert result[0]["changes"][0]["field"] == "seed"
|
|
||||||
assert result[0]["changes"][0]["old"] == 42
|
|
||||||
assert result[0]["changes"][0]["new"] == 99
|
|
||||||
|
|
||||||
|
|
||||||
def test_diff_added_and_removed():
|
|
||||||
old = [{"sequence_number": 1, "seed": 1}]
|
|
||||||
new = [{"sequence_number": 2, "seed": 2}]
|
|
||||||
result = diff_snapshots(old, new)
|
|
||||||
assert len(result) == 2
|
|
||||||
statuses = {r["seq_num"]: r["status"] for r in result}
|
|
||||||
assert statuses[1] == "removed"
|
|
||||||
assert statuses[2] == "added"
|
|
||||||
|
|
||||||
|
|
||||||
def test_diff_empty():
|
|
||||||
assert diff_snapshots([], []) == []
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS, resolve_path_case_insensitive
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_valid(tmp_path):
|
|
||||||
p = tmp_path / "test.json"
|
|
||||||
data = {"key": "value"}
|
|
||||||
p.write_text(json.dumps(data))
|
|
||||||
result, mtime = load_json(p)
|
|
||||||
assert result == data
|
|
||||||
assert mtime > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_missing(tmp_path):
|
|
||||||
p = tmp_path / "nope.json"
|
|
||||||
result, mtime = load_json(p)
|
|
||||||
assert result == DEFAULTS.copy()
|
|
||||||
assert mtime == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_invalid(tmp_path):
|
|
||||||
p = tmp_path / "bad.json"
|
|
||||||
p.write_text("{not valid json")
|
|
||||||
result, mtime = load_json(p)
|
|
||||||
assert result == DEFAULTS.copy()
|
|
||||||
assert mtime == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_json_atomic(tmp_path):
|
|
||||||
p = tmp_path / "out.json"
|
|
||||||
data = {"hello": "world"}
|
|
||||||
save_json(p, data)
|
|
||||||
assert p.exists()
|
|
||||||
assert not p.with_suffix(".json.tmp").exists()
|
|
||||||
assert json.loads(p.read_text()) == data
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_json_overwrites(tmp_path):
|
|
||||||
p = tmp_path / "out.json"
|
|
||||||
save_json(p, {"a": 1})
|
|
||||||
save_json(p, {"b": 2})
|
|
||||||
assert json.loads(p.read_text()) == {"b": 2}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_file_mtime_existing(tmp_path):
|
|
||||||
p = tmp_path / "f.txt"
|
|
||||||
p.write_text("x")
|
|
||||||
assert get_file_mtime(p) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_file_mtime_missing(tmp_path):
|
|
||||||
assert get_file_mtime(tmp_path / "missing.txt") == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_allowed_base_dir_is_set():
|
|
||||||
assert ALLOWED_BASE_DIR is not None
|
|
||||||
assert isinstance(ALLOWED_BASE_DIR, Path)
|
|
||||||
|
|
||||||
|
|
||||||
class TestResolvePathCaseInsensitive:
|
|
||||||
def test_exact_match(self, tmp_path):
|
|
||||||
d = tmp_path / "MyFolder"
|
|
||||||
d.mkdir()
|
|
||||||
result = resolve_path_case_insensitive(str(d))
|
|
||||||
assert result == d.resolve()
|
|
||||||
|
|
||||||
def test_wrong_case_single_component(self, tmp_path):
|
|
||||||
d = tmp_path / "MyFolder"
|
|
||||||
d.mkdir()
|
|
||||||
wrong = tmp_path / "myfolder"
|
|
||||||
result = resolve_path_case_insensitive(str(wrong))
|
|
||||||
assert result == d.resolve()
|
|
||||||
|
|
||||||
def test_wrong_case_nested(self, tmp_path):
|
|
||||||
d = tmp_path / "Parent" / "Child"
|
|
||||||
d.mkdir(parents=True)
|
|
||||||
wrong = tmp_path / "parent" / "CHILD"
|
|
||||||
result = resolve_path_case_insensitive(str(wrong))
|
|
||||||
assert result == d.resolve()
|
|
||||||
|
|
||||||
def test_no_match_returns_none(self, tmp_path):
|
|
||||||
result = resolve_path_case_insensitive(str(tmp_path / "nonexistent"))
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_file_path(self, tmp_path):
|
|
||||||
f = tmp_path / "Data.json"
|
|
||||||
f.write_text("{}")
|
|
||||||
wrong = tmp_path / "data.JSON"
|
|
||||||
result = resolve_path_case_insensitive(str(wrong))
|
|
||||||
assert result == f.resolve()
|
|
||||||
@@ -1,121 +1,56 @@
|
|||||||
import copy
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
import streamlit as st
|
||||||
|
|
||||||
# --- Magic String Keys ---
|
|
||||||
KEY_BATCH_DATA = "batch_data"
|
|
||||||
KEY_HISTORY_TREE = "history_tree"
|
|
||||||
KEY_PROMPT_HISTORY = "prompt_history"
|
|
||||||
KEY_SEQUENCE_NUMBER = "sequence_number"
|
|
||||||
|
|
||||||
# Configure logging for the application
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
|
||||||
datefmt="%H:%M:%S",
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Default structure for new files
|
# Default structure for new files
|
||||||
DEFAULTS = {
|
DEFAULTS = {
|
||||||
# --- Prompts ---
|
# --- Standard Keys for your Restored Single Tab ---
|
||||||
"general_prompt": "",
|
"general_prompt": "", # Global positive
|
||||||
"general_negative": "Vivid tones, overexposed, static, blurry details, subtitles, style, artwork, painting, picture, still image, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, distorted, disfigured, malformed limbs, fused fingers, unmoving frame, cluttered background, three legs",
|
"general_negative": "", # Global negative
|
||||||
"current_prompt": "",
|
"current_prompt": "", # Specific positive
|
||||||
"negative": "",
|
"negative": "", # Specific negative
|
||||||
"seed": -1,
|
"seed": -1,
|
||||||
|
|
||||||
# --- Settings ---
|
# --- Settings ---
|
||||||
"mode": 0,
|
|
||||||
"camera": "static",
|
"camera": "static",
|
||||||
|
"flf": 0.0,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 7.0,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1.0,
|
||||||
|
"model_name": "v1-5-pruned-emaonly.ckpt",
|
||||||
|
"vae_name": "vae-ft-mse-840000-ema-pruned.ckpt",
|
||||||
|
|
||||||
# --- I2V / VACE Specifics ---
|
# --- I2V / VACE Specifics ---
|
||||||
"frame_to_skip": 81,
|
"frame_to_skip": 81,
|
||||||
"logic index": 0,
|
|
||||||
"transition": "1-2",
|
|
||||||
"vace_length": 49,
|
|
||||||
"vace schedule": 1,
|
"vace schedule": 1,
|
||||||
"input_a_frames": 16,
|
"input_a_frames": 0,
|
||||||
"input_b_frames": 16,
|
"input_b_frames": 0,
|
||||||
"reference switch": 1,
|
"reference switch": 1,
|
||||||
"video file path": "",
|
"video file path": "",
|
||||||
"start frame path": "",
|
"reference image path": "",
|
||||||
"start frame high strength": 1.0,
|
"reference path": "",
|
||||||
"start frame low strength": 1.0,
|
"flf image path": "",
|
||||||
"middle frame path": "",
|
|
||||||
"middle frame high strength": 1.0,
|
|
||||||
"middle frame low strength": 1.0,
|
|
||||||
"end frame path": "",
|
|
||||||
"end frame high strength": 1.0,
|
|
||||||
"end frame low strength": 1.0,
|
|
||||||
|
|
||||||
# --- LoRAs (name as STRING, strength as FLOAT) ---
|
# --- LoRAs ---
|
||||||
"lora 1 high": "",
|
"lora 1 high": "", "lora 1 low": "",
|
||||||
"lora 1 high strength": 1.0,
|
"lora 2 high": "", "lora 2 low": "",
|
||||||
"lora 1 low": "",
|
"lora 3 high": "", "lora 3 low": ""
|
||||||
"lora 1 low strength": 1.0,
|
|
||||||
"lora 2 high": "",
|
|
||||||
"lora 2 high strength": 1.0,
|
|
||||||
"lora 2 low": "",
|
|
||||||
"lora 2 low strength": 1.0,
|
|
||||||
"lora 3 high": "",
|
|
||||||
"lora 3 high strength": 1.0,
|
|
||||||
"lora 3 low": "",
|
|
||||||
"lora 3 low strength": 1.0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CONFIG_FILE = Path(".editor_config.json")
|
CONFIG_FILE = Path(".editor_config.json")
|
||||||
SNIPPETS_FILE = Path(".editor_snippets.json")
|
SNIPPETS_FILE = Path(".editor_snippets.json")
|
||||||
|
|
||||||
# No restriction on directory navigation
|
|
||||||
ALLOWED_BASE_DIR = Path("/").resolve()
|
|
||||||
|
|
||||||
def resolve_path_case_insensitive(path: str | Path) -> Path | None:
|
|
||||||
"""Resolve a path with case-insensitive component matching on Linux.
|
|
||||||
|
|
||||||
Walks each component of the path and matches against actual directory
|
|
||||||
entries when an exact match fails. Returns the corrected Path, or None
|
|
||||||
if no match is found.
|
|
||||||
"""
|
|
||||||
p = Path(path)
|
|
||||||
if p.exists():
|
|
||||||
return p.resolve()
|
|
||||||
|
|
||||||
# Start from the root / anchor
|
|
||||||
parts = p.resolve().parts # resolve to get absolute parts
|
|
||||||
built = Path(parts[0]) # root "/"
|
|
||||||
for component in parts[1:]:
|
|
||||||
candidate = built / component
|
|
||||||
if candidate.exists():
|
|
||||||
built = candidate
|
|
||||||
continue
|
|
||||||
# Case-insensitive scan of the parent directory
|
|
||||||
try:
|
|
||||||
lower = component.lower()
|
|
||||||
match = next(
|
|
||||||
(entry for entry in built.iterdir() if entry.name.lower() == lower),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
except PermissionError:
|
|
||||||
return None
|
|
||||||
if match is None:
|
|
||||||
return None
|
|
||||||
built = match
|
|
||||||
return built.resolve()
|
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
|
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
|
||||||
if CONFIG_FILE.exists():
|
if CONFIG_FILE.exists():
|
||||||
try:
|
try:
|
||||||
with open(CONFIG_FILE, 'r') as f:
|
with open(CONFIG_FILE, 'r') as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except (json.JSONDecodeError, IOError) as e:
|
except:
|
||||||
logger.warning(f"Failed to load config: {e}")
|
pass
|
||||||
return {"favorites": [], "last_dir": str(Path.cwd()), "comfy_instances": []}
|
return {"favorites": [], "last_dir": str(Path.cwd()), "comfy_instances": []}
|
||||||
|
|
||||||
def save_config(current_dir, favorites, extra_data=None):
|
def save_config(current_dir, favorites, extra_data=None):
|
||||||
@@ -127,290 +62,54 @@ def save_config(current_dir, favorites, extra_data=None):
|
|||||||
existing = load_config()
|
existing = load_config()
|
||||||
data.update(existing)
|
data.update(existing)
|
||||||
|
|
||||||
if extra_data:
|
|
||||||
data.update(extra_data)
|
|
||||||
|
|
||||||
# Force-set explicit params last so extra_data can't override them
|
|
||||||
data["last_dir"] = str(current_dir)
|
data["last_dir"] = str(current_dir)
|
||||||
data["favorites"] = favorites
|
data["favorites"] = favorites
|
||||||
|
|
||||||
tmp = CONFIG_FILE.with_suffix('.json.tmp')
|
if extra_data:
|
||||||
with open(tmp, 'w') as f:
|
data.update(extra_data)
|
||||||
|
|
||||||
|
with open(CONFIG_FILE, 'w') as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
os.replace(tmp, CONFIG_FILE)
|
|
||||||
|
|
||||||
def load_snippets():
|
def load_snippets():
|
||||||
if SNIPPETS_FILE.exists():
|
if SNIPPETS_FILE.exists():
|
||||||
try:
|
try:
|
||||||
with open(SNIPPETS_FILE, 'r') as f:
|
with open(SNIPPETS_FILE, 'r') as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except (json.JSONDecodeError, IOError) as e:
|
except:
|
||||||
logger.warning(f"Failed to load snippets: {e}")
|
pass
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def save_snippets(snippets):
|
def save_snippets(snippets):
|
||||||
tmp = SNIPPETS_FILE.with_suffix('.json.tmp')
|
with open(SNIPPETS_FILE, 'w') as f:
|
||||||
with open(tmp, 'w') as f:
|
|
||||||
json.dump(snippets, f, indent=4)
|
json.dump(snippets, f, indent=4)
|
||||||
os.replace(tmp, SNIPPETS_FILE)
|
|
||||||
|
|
||||||
_REMOVED_KEYS = {"cfg", "flf", "end_frame"}
|
def load_json(path):
|
||||||
|
|
||||||
def _migrate_remove_keys(data: dict) -> None:
|
|
||||||
"""Drop keys that have been removed from the schema."""
|
|
||||||
for item in data.get(KEY_BATCH_DATA, []):
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
for k in _REMOVED_KEYS:
|
|
||||||
item.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _migrate_key_renames(data: dict) -> None:
|
|
||||||
"""Rename legacy keys to their current names."""
|
|
||||||
for item in data.get(KEY_BATCH_DATA, []):
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
if 'reference path' in item and 'middle frame path' not in item:
|
|
||||||
item['middle frame path'] = item.pop('reference path')
|
|
||||||
if 'flf image path' in item and 'end frame path' not in item:
|
|
||||||
item['end frame path'] = item.pop('flf image path')
|
|
||||||
if 'reference image path' in item and 'start frame path' not in item:
|
|
||||||
item['start frame path'] = item.pop('reference image path')
|
|
||||||
# Split old single strength into high+low
|
|
||||||
for prefix in ('start frame', 'middle frame', 'end frame'):
|
|
||||||
old_key = f'{prefix} strength'
|
|
||||||
if old_key in item:
|
|
||||||
val = item.pop(old_key)
|
|
||||||
item.setdefault(f'{prefix} high strength', val)
|
|
||||||
item.setdefault(f'{prefix} low strength', val)
|
|
||||||
|
|
||||||
|
|
||||||
def _migrate_lora_keys(data: dict) -> None:
|
|
||||||
"""Split combined lora 'name:strength' into separate name and strength keys.
|
|
||||||
|
|
||||||
Handles legacy formats:
|
|
||||||
1. <lora:Name:0.5> → name_key='Name', str_key=0.5
|
|
||||||
2. 'Name:0.5' (merged) → name_key='Name', str_key=0.5
|
|
||||||
3. Already split (name_key + str_key exist) → no change
|
|
||||||
"""
|
|
||||||
for item in data.get(KEY_BATCH_DATA, []):
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
for idx in range(1, 4):
|
|
||||||
for tier in ('high', 'low'):
|
|
||||||
name_key = f'lora {idx} {tier}'
|
|
||||||
str_key = f'lora {idx} {tier} strength'
|
|
||||||
raw = str(item.get(name_key, ''))
|
|
||||||
|
|
||||||
if raw.startswith('<lora:'):
|
|
||||||
# Legacy <lora:Name:0.5> format
|
|
||||||
inner = raw.replace('<lora:', '').replace('>', '')
|
|
||||||
if ':' in inner:
|
|
||||||
parts = inner.rsplit(':', 1)
|
|
||||||
item[name_key] = parts[0]
|
|
||||||
try:
|
|
||||||
item[str_key] = float(parts[1])
|
|
||||||
except ValueError:
|
|
||||||
item[str_key] = 1.0
|
|
||||||
else:
|
|
||||||
item[name_key] = inner
|
|
||||||
if str_key not in item:
|
|
||||||
item[str_key] = 1.0
|
|
||||||
elif ':' in raw and raw:
|
|
||||||
# Combined 'name:strength' format → split
|
|
||||||
parts = raw.rsplit(':', 1)
|
|
||||||
try:
|
|
||||||
strength = float(parts[1])
|
|
||||||
item[name_key] = parts[0]
|
|
||||||
item[str_key] = strength
|
|
||||||
except ValueError:
|
|
||||||
# Not a valid strength, leave as-is
|
|
||||||
if str_key not in item:
|
|
||||||
item[str_key] = 1.0
|
|
||||||
elif raw:
|
|
||||||
# Name exists without colon, ensure strength key exists
|
|
||||||
if str_key not in item:
|
|
||||||
item[str_key] = 1.0
|
|
||||||
# If name is empty, don't add a strength key
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
|
|
||||||
t0 = time.time()
|
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
try:
|
try:
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
t1 = time.time()
|
return data, path.stat().st_mtime
|
||||||
_migrate_remove_keys(data)
|
|
||||||
_migrate_key_renames(data)
|
|
||||||
_migrate_lora_keys(data)
|
|
||||||
t2 = time.time()
|
|
||||||
mtime = path.stat().st_mtime
|
|
||||||
logger.info("load_json %s: read=%.3fs migrate=%.3fs total=%.3fs",
|
|
||||||
path.name, t1 - t0, t2 - t1, t2 - t0)
|
|
||||||
return data, mtime
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading JSON: {e}")
|
st.error(f"Error loading JSON: {e}")
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
|
|
||||||
def save_json(path: str | Path, data: dict[str, Any]) -> None:
|
def save_json(path, data):
|
||||||
t0 = time.time()
|
with open(path, 'w') as f:
|
||||||
path = Path(path)
|
|
||||||
tmp = path.with_suffix('.json.tmp')
|
|
||||||
with open(tmp, 'w') as f:
|
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
os.replace(tmp, path)
|
|
||||||
logger.info("save_json %s: %.3fs", path.name, time.time() - t0)
|
|
||||||
|
|
||||||
|
def get_file_mtime(path):
|
||||||
def snapshot_data(data: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Create a thread-safe deep copy via JSON roundtrip.
|
|
||||||
|
|
||||||
Must be called on the main thread before passing data to asyncio.to_thread,
|
|
||||||
to avoid 'dict changed size during iteration' when the UI mutates data.
|
|
||||||
"""
|
|
||||||
return json.loads(json.dumps(data))
|
|
||||||
|
|
||||||
def get_file_mtime(path: str | Path) -> float:
|
|
||||||
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return path.stat().st_mtime
|
return path.stat().st_mtime
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
def generate_templates(current_dir):
|
||||||
"""Dual-write helper: sync JSON data to the project database.
|
"""Creates dummy template files if folder is empty."""
|
||||||
|
save_json(current_dir / "template_i2v.json", DEFAULTS)
|
||||||
|
|
||||||
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]}
|
||||||
and saves the history_tree. All writes happen in a single transaction.
|
save_json(current_dir / "template_batch.json", batch_data)
|
||||||
"""
|
|
||||||
t0 = time.time()
|
|
||||||
if not db or not project_name:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
proj = db.get_project(project_name)
|
|
||||||
if not proj:
|
|
||||||
return
|
|
||||||
file_name = Path(file_path).stem
|
|
||||||
|
|
||||||
# Use a single transaction for atomicity
|
|
||||||
db.conn.execute("BEGIN IMMEDIATE")
|
|
||||||
try:
|
|
||||||
now = time.time()
|
|
||||||
df = db.get_data_file(proj["id"], file_name)
|
|
||||||
top_level = {k: v for k, v in data.items()
|
|
||||||
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
|
||||||
if not df:
|
|
||||||
cur = db.conn.execute(
|
|
||||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(proj["id"], file_name, "generic", json.dumps(top_level), now, now),
|
|
||||||
)
|
|
||||||
df_id = cur.lastrowid
|
|
||||||
else:
|
|
||||||
df_id = df["id"]
|
|
||||||
# Update top_level metadata
|
|
||||||
db.conn.execute(
|
|
||||||
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
|
||||||
(json.dumps(top_level), now, df_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sync sequences
|
|
||||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
|
||||||
if isinstance(batch_data, list):
|
|
||||||
new_seq_nums = set()
|
|
||||||
for item in batch_data:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
new_seq_nums.add(seq_num)
|
|
||||||
db.conn.execute(
|
|
||||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, seq_num, json.dumps(item), now),
|
|
||||||
)
|
|
||||||
# Remove sequences that no longer exist
|
|
||||||
if new_seq_nums:
|
|
||||||
placeholders = ','.join('?' * len(new_seq_nums))
|
|
||||||
db.conn.execute(
|
|
||||||
f"DELETE FROM sequences WHERE data_file_id = ? AND sequence_number NOT IN ({placeholders})",
|
|
||||||
(df_id, *new_seq_nums),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
|
||||||
|
|
||||||
# Sync history tree (extract snapshot data into separate table)
|
|
||||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
|
||||||
history_tree = data.get(KEY_HISTORY_TREE)
|
|
||||||
if history_tree and isinstance(history_tree, dict):
|
|
||||||
# Detect format: new has "snapshots", old has "nodes"
|
|
||||||
if "snapshots" in history_tree:
|
|
||||||
entries = history_tree.get("snapshots", {})
|
|
||||||
else:
|
|
||||||
entries = history_tree.get("nodes", {})
|
|
||||||
slim_tree = dict(history_tree)
|
|
||||||
slim_entries = {}
|
|
||||||
for eid, entry in entries.items():
|
|
||||||
snap = entry.get("data")
|
|
||||||
if snap:
|
|
||||||
db.conn.execute(
|
|
||||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
|
||||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, eid, json.dumps(snap), now),
|
|
||||||
)
|
|
||||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
|
||||||
# Write back slim version using the correct key
|
|
||||||
if "snapshots" in history_tree:
|
|
||||||
slim_tree["snapshots"] = slim_entries
|
|
||||||
else:
|
|
||||||
slim_tree["nodes"] = slim_entries
|
|
||||||
db.conn.execute(
|
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
|
||||||
"VALUES (?, ?, ?) "
|
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
|
||||||
(df_id, json.dumps(slim_tree), now),
|
|
||||||
)
|
|
||||||
# Clean up orphaned snapshots
|
|
||||||
current_ids = set(entries.keys())
|
|
||||||
if current_ids:
|
|
||||||
placeholders = ",".join("?" for _ in current_ids)
|
|
||||||
db.conn.execute(
|
|
||||||
f"DELETE FROM history_snapshots WHERE data_file_id = ? "
|
|
||||||
f"AND node_id NOT IN ({placeholders})",
|
|
||||||
(df_id, *current_ids),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
db.conn.execute(
|
|
||||||
"DELETE FROM history_snapshots WHERE data_file_id = ?",
|
|
||||||
(df_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.conn.execute("COMMIT")
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
db.conn.execute("ROLLBACK")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"sync_to_db failed: {e}")
|
|
||||||
return
|
|
||||||
batch_count = len(data.get(KEY_BATCH_DATA, []))
|
|
||||||
logger.info("sync_to_db %s (%d seqs): %.3fs",
|
|
||||||
Path(file_path).name, batch_count, time.time() - t0)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_templates(current_dir: Path) -> None:
|
|
||||||
"""Creates batch template files if folder is empty."""
|
|
||||||
first = copy.deepcopy(DEFAULTS)
|
|
||||||
first[KEY_SEQUENCE_NUMBER] = 1
|
|
||||||
save_json(current_dir / "batch_prompt_i2v.json", {KEY_BATCH_DATA: [first]})
|
|
||||||
|
|
||||||
first2 = copy.deepcopy(DEFAULTS)
|
|
||||||
first2[KEY_SEQUENCE_NUMBER] = 1
|
|
||||||
save_json(current_dir / "batch_prompt_vace_extend.json", {KEY_BATCH_DATA: [first2]})
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.binary_index_decoder",
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "BinaryIndexDecoder") return;
|
|
||||||
|
|
||||||
nodeType.prototype.onExecuted = function (output) {
|
|
||||||
if (!output?.values) return;
|
|
||||||
for (let i = 0; i < Math.min(output.values.length, this.outputs.length); i++) {
|
|
||||||
const val = output.values[i];
|
|
||||||
this.outputs[i].label = `${val} ${this.outputs[i].name}`;
|
|
||||||
this.outputs[i].color_on = (val === "true") ? "#4caf50" : "#888888";
|
|
||||||
this.outputs[i].color_off = (val === "true") ? "#4caf50" : "#888888";
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.project.dynamic",
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "ProjectLoaderDynamic") return;
|
|
||||||
|
|
||||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
origOnNodeCreated?.apply(this, arguments);
|
|
||||||
|
|
||||||
// Hide internal widgets (managed by JS)
|
|
||||||
for (const name of ["output_keys", "output_types"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do NOT remove default outputs synchronously here.
|
|
||||||
// During graph loading, ComfyUI creates all nodes (firing onNodeCreated)
|
|
||||||
// before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve
|
|
||||||
// links to our outputs during their configure step. If we remove outputs
|
|
||||||
// here, those nodes find no output slot and error out.
|
|
||||||
//
|
|
||||||
// Instead, defer cleanup: for loaded workflows onConfigure sets _configured
|
|
||||||
// before this runs; for new nodes the defaults are cleaned up.
|
|
||||||
this._configured = false;
|
|
||||||
|
|
||||||
// Add Refresh button
|
|
||||||
this.addWidget("button", "Refresh Outputs", null, () => {
|
|
||||||
this.refreshDynamicOutputs();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Auto-refresh with 500ms debounce on widget changes
|
|
||||||
this._refreshTimer = null;
|
|
||||||
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number", "refresh"];
|
|
||||||
for (const widgetName of autoRefreshWidgets) {
|
|
||||||
const w = this.widgets?.find(w => w.name === widgetName);
|
|
||||||
if (w) {
|
|
||||||
const origCallback = w.callback;
|
|
||||||
const node = this;
|
|
||||||
w.callback = function (...args) {
|
|
||||||
origCallback?.apply(this, args);
|
|
||||||
clearTimeout(node._refreshTimer);
|
|
||||||
node._refreshTimer = setTimeout(() => {
|
|
||||||
node.refreshDynamicOutputs();
|
|
||||||
}, 500);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
queueMicrotask(() => {
|
|
||||||
if (!this._configured) {
|
|
||||||
// New node (not loading) — remove the Python default outputs
|
|
||||||
// and add only the fixed total_sequences slot
|
|
||||||
while (this.outputs.length > 0) {
|
|
||||||
this.removeOutput(0);
|
|
||||||
}
|
|
||||||
this.addOutput("total_sequences", "INT");
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._setStatus = function (status, message) {
|
|
||||||
const baseTitle = "Project Loader (Dynamic)";
|
|
||||||
if (status === "ok") {
|
|
||||||
this.title = baseTitle;
|
|
||||||
this.color = undefined;
|
|
||||||
this.bgcolor = undefined;
|
|
||||||
} else if (status === "error") {
|
|
||||||
this.title = baseTitle + " - ERROR";
|
|
||||||
this.color = "#ff4444";
|
|
||||||
this.bgcolor = "#331111";
|
|
||||||
if (message) this.title = baseTitle + ": " + message;
|
|
||||||
} else if (status === "loading") {
|
|
||||||
this.title = baseTitle + " - Loading...";
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype.refreshDynamicOutputs = async function () {
|
|
||||||
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
|
|
||||||
const projectWidget = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
const fileWidget = this.widgets?.find(w => w.name === "file_name");
|
|
||||||
const seqWidget = this.widgets?.find(w => w.name === "sequence_number");
|
|
||||||
|
|
||||||
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
|
|
||||||
|
|
||||||
this._setStatus("loading");
|
|
||||||
|
|
||||||
try {
|
|
||||||
const resp = await api.fetchApi(
|
|
||||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!resp.ok) {
|
|
||||||
let errorMsg = `HTTP ${resp.status}`;
|
|
||||||
try {
|
|
||||||
const errData = await resp.json();
|
|
||||||
if (errData.message) errorMsg = errData.message;
|
|
||||||
} catch (_) {}
|
|
||||||
this._setStatus("error", errorMsg);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await resp.json();
|
|
||||||
const keys = data.keys;
|
|
||||||
const types = data.types;
|
|
||||||
|
|
||||||
// If the API returned an error or missing data, keep existing outputs and links intact
|
|
||||||
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
|
|
||||||
const errMsg = data.error ? data.message || data.error : "Missing keys/types";
|
|
||||||
this._setStatus("error", errMsg);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store keys and types in hidden widgets for persistence (JSON)
|
|
||||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
|
||||||
if (okWidget) okWidget.value = JSON.stringify(keys);
|
|
||||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
|
||||||
if (otWidget) otWidget.value = JSON.stringify(types);
|
|
||||||
|
|
||||||
// Slot 0 is always total_sequences (INT) — ensure it exists
|
|
||||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
|
||||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
|
||||||
}
|
|
||||||
this.outputs[0].type = "INT";
|
|
||||||
|
|
||||||
// Build a map of current dynamic output names to slot indices (skip slot 0)
|
|
||||||
const oldSlots = {};
|
|
||||||
for (let i = 1; i < this.outputs.length; i++) {
|
|
||||||
oldSlots[this.outputs[i].name] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build new dynamic outputs, reusing existing slots to preserve links
|
|
||||||
const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0
|
|
||||||
for (let k = 0; k < keys.length; k++) {
|
|
||||||
const key = keys[k];
|
|
||||||
const type = types[k] || "*";
|
|
||||||
if (key in oldSlots) {
|
|
||||||
const slot = this.outputs[oldSlots[key]];
|
|
||||||
slot.type = type;
|
|
||||||
slot.label = key;
|
|
||||||
newOutputs.push(slot);
|
|
||||||
delete oldSlots[key];
|
|
||||||
} else {
|
|
||||||
newOutputs.push({ name: key, label: key, type: type, links: null });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disconnect links on slots that are being removed
|
|
||||||
for (const name in oldSlots) {
|
|
||||||
const idx = oldSlots[name];
|
|
||||||
if (this.outputs[idx]?.links?.length) {
|
|
||||||
for (const linkId of [...this.outputs[idx].links]) {
|
|
||||||
this.graph?.removeLink(linkId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reassign the outputs array and fix link slot indices
|
|
||||||
this.outputs = newOutputs;
|
|
||||||
if (this.graph) {
|
|
||||||
for (let i = 0; i < this.outputs.length; i++) {
|
|
||||||
const links = this.outputs[i].links;
|
|
||||||
if (!links) continue;
|
|
||||||
for (const linkId of links) {
|
|
||||||
const link = this.graph.links[linkId];
|
|
||||||
if (link) link.origin_slot = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this._setStatus("ok");
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
} catch (e) {
|
|
||||||
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
|
|
||||||
this._setStatus("error", "Server unreachable");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Restore state on workflow load
|
|
||||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
|
||||||
nodeType.prototype.onConfigure = function (info) {
|
|
||||||
origOnConfigure?.apply(this, arguments);
|
|
||||||
this._configured = true;
|
|
||||||
|
|
||||||
// Hide internal widgets
|
|
||||||
for (const name of ["output_keys", "output_types"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
|
|
||||||
}
|
|
||||||
|
|
||||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
|
||||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
|
||||||
|
|
||||||
let keys = [];
|
|
||||||
let types = [];
|
|
||||||
if (okWidget?.value) {
|
|
||||||
try { keys = JSON.parse(okWidget.value); } catch (_) {
|
|
||||||
keys = okWidget.value.split(",").filter(k => k.trim());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (otWidget?.value) {
|
|
||||||
try { types = JSON.parse(otWidget.value); } catch (_) {
|
|
||||||
types = otWidget.value.split(",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure slot 0 is total_sequences (INT)
|
|
||||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
|
||||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
|
||||||
const node = this;
|
|
||||||
queueMicrotask(() => {
|
|
||||||
if (!node.graph) return;
|
|
||||||
for (const output of node.outputs) {
|
|
||||||
output.links = null;
|
|
||||||
}
|
|
||||||
for (const linkId in node.graph.links) {
|
|
||||||
const link = node.graph.links[linkId];
|
|
||||||
if (!link || link.origin_id !== node.id) continue;
|
|
||||||
link.origin_slot += 1;
|
|
||||||
const output = node.outputs[link.origin_slot];
|
|
||||||
if (output) {
|
|
||||||
if (!output.links) output.links = [];
|
|
||||||
output.links.push(link.id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
this.outputs[0].type = "INT";
|
|
||||||
this.outputs[0].name = "total_sequences";
|
|
||||||
|
|
||||||
if (keys.length > 0) {
|
|
||||||
for (let i = 0; i < keys.length; i++) {
|
|
||||||
const slotIdx = i + 1;
|
|
||||||
if (slotIdx < this.outputs.length) {
|
|
||||||
this.outputs[slotIdx].name = keys[i].trim();
|
|
||||||
this.outputs[slotIdx].label = keys[i].trim();
|
|
||||||
if (types[i]) this.outputs[slotIdx].type = types[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
while (this.outputs.length > keys.length + 1) {
|
|
||||||
this.removeOutput(this.outputs.length - 1);
|
|
||||||
}
|
|
||||||
} else if (this.outputs.length > 1) {
|
|
||||||
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
|
||||||
const dynamicOutputs = this.outputs.slice(1);
|
|
||||||
if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
|
|
||||||
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
@@ -1,329 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.project.key",
|
|
||||||
|
|
||||||
// Re-sync all ProjectKey nodes from their sources before queueing
|
|
||||||
// This fixes stale config when the user edits a ProjectSource after
|
|
||||||
// a ProjectKey already selected it.
|
|
||||||
async beforeQueuePrompt() {
|
|
||||||
if (!app.graph?._nodes) return;
|
|
||||||
for (const node of app.graph._nodes) {
|
|
||||||
if (node.type === "ProjectKey" && node._syncFromSource) {
|
|
||||||
node._syncFromSource();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "ProjectKey") return;
|
|
||||||
|
|
||||||
// Helper: properly hide a widget (works for all types including INT)
|
|
||||||
function hideWidget(widget) {
|
|
||||||
if (widget.origType === undefined) widget.origType = widget.type;
|
|
||||||
widget.type = "hidden";
|
|
||||||
widget.hidden = true;
|
|
||||||
widget.computeSize = () => [0, -4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper: replace a STRING widget with a proper combo widget
|
|
||||||
function replaceWithCombo(node, name, values, callback) {
|
|
||||||
const idx = node.widgets?.findIndex(w => w.name === name);
|
|
||||||
if (idx === -1 || idx === undefined) return null;
|
|
||||||
const oldWidget = node.widgets[idx];
|
|
||||||
const savedValue = oldWidget.value || "";
|
|
||||||
// Ensure values list is never empty (combo shows undefined otherwise)
|
|
||||||
const comboValues = values.length > 0 ? values : [""];
|
|
||||||
// Always preserve saved value — it may not be in the list yet (load-order race)
|
|
||||||
if (savedValue && !comboValues.includes(savedValue)) {
|
|
||||||
comboValues.unshift(savedValue);
|
|
||||||
}
|
|
||||||
const defaultValue = savedValue || comboValues[0];
|
|
||||||
// Remove old STRING widget
|
|
||||||
node.widgets.splice(idx, 1);
|
|
||||||
// Insert a real combo widget at the same position
|
|
||||||
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
|
||||||
// Move it from the end to the original position
|
|
||||||
if (node.widgets.length > 1) {
|
|
||||||
node.widgets.splice(node.widgets.length - 1, 1);
|
|
||||||
node.widgets.splice(idx, 0, combo);
|
|
||||||
}
|
|
||||||
return combo;
|
|
||||||
}
|
|
||||||
|
|
||||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
origOnNodeCreated?.apply(this, arguments);
|
|
||||||
this._configured = false;
|
|
||||||
|
|
||||||
// Hide the connection-config widgets (synced from source by JS)
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number", "key_type"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace source_label STRING with a proper combo widget
|
|
||||||
const node = this;
|
|
||||||
const sourceLabels = this._getSourceLabels?.() || [];
|
|
||||||
const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
// Set first available source or "none" placeholder
|
|
||||||
if (srcCombo) srcCombo.value = sourceLabels[0] || "";
|
|
||||||
|
|
||||||
// Replace key_name STRING with a proper combo widget
|
|
||||||
const keyCombo = replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node._applyKeySelection();
|
|
||||||
});
|
|
||||||
if (keyCombo) keyCombo.value = "";
|
|
||||||
|
|
||||||
queueMicrotask(() => {
|
|
||||||
if (!this._configured) {
|
|
||||||
// New node — set output to a generic slot
|
|
||||||
if (this.outputs.length === 0) {
|
|
||||||
this.addOutput("value", "*");
|
|
||||||
}
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Find all ProjectSource nodes and their labels (deduplicated) ---
|
|
||||||
nodeType.prototype._getSourceLabels = function () {
|
|
||||||
const seen = new Set();
|
|
||||||
const labels = [];
|
|
||||||
if (!this.graph) return labels;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value && !seen.has(lw.value)) {
|
|
||||||
seen.add(lw.value);
|
|
||||||
labels.push(lw.value);
|
|
||||||
} else if (lw?.value && seen.has(lw.value)) {
|
|
||||||
console.warn(`[ProjectKey] Duplicate source label "${lw.value}" (node ${node.id}) — only first will be used`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return labels;
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Find the ProjectSource node matching a label ---
|
|
||||||
nodeType.prototype._findSource = function (label) {
|
|
||||||
if (!this.graph || !label) return null;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value === label) return node;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Copy config from source node into hidden widgets ---
|
|
||||||
nodeType.prototype._syncFromSource = function () {
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
const source = this._findSource(srcWidget?.value);
|
|
||||||
if (!source) {
|
|
||||||
console.log(`[ProjectKey] _syncFromSource id=${this.id}: no source found for label="${srcWidget?.value}"`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const dst = this.widgets?.find(w => w.name === name);
|
|
||||||
const src = source.widgets?.find(w => w.name === name);
|
|
||||||
if (dst && src) {
|
|
||||||
dst.value = src.value;
|
|
||||||
console.log(`[ProjectKey] _syncFromSource id=${this.id}: ${name}="${src.value}"`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Fetch keys from API and populate key_name dropdown ---
|
|
||||||
nodeType.prototype._refreshKeys = async function () {
|
|
||||||
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
|
||||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
const fileW = this.widgets?.find(w => w.name === "file_name");
|
|
||||||
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
|
||||||
|
|
||||||
console.log(`[ProjectKey] _refreshKeys id=${this.id}: url="${urlW?.value}" project="${projW?.value}" file="${fileW?.value}" seq=${seqW?.value}`);
|
|
||||||
if (!urlW?.value || !projW?.value || !fileW?.value) {
|
|
||||||
console.log(`[ProjectKey] _refreshKeys: skipped (missing config)`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const resp = await api.fetchApi(
|
|
||||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}`
|
|
||||||
);
|
|
||||||
if (!resp.ok) return;
|
|
||||||
|
|
||||||
const data = await resp.json();
|
|
||||||
if (data.error || !Array.isArray(data.keys)) return;
|
|
||||||
|
|
||||||
// Store keys/types for lookup
|
|
||||||
this._availableKeys = data.keys;
|
|
||||||
this._availableTypes = data.types;
|
|
||||||
|
|
||||||
// Update key_name combo options only — never change the selection
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget) {
|
|
||||||
keyWidget.options.values = data.keys;
|
|
||||||
// Selection is sticky: user must change it manually
|
|
||||||
this._applyKeySelection();
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error("[ProjectKey] Failed to refresh keys:", e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Update output slot based on selected key ---
|
|
||||||
nodeType.prototype._applyKeySelection = function () {
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (!keyWidget?.value) return;
|
|
||||||
|
|
||||||
const keyIdx = (this._availableKeys || []).indexOf(keyWidget.value);
|
|
||||||
const keyType = keyIdx >= 0 ? (this._availableTypes[keyIdx] || "*") : "*";
|
|
||||||
|
|
||||||
// Update hidden key_type widget
|
|
||||||
const ktWidget = this.widgets?.find(w => w.name === "key_type");
|
|
||||||
if (ktWidget) ktWidget.value = keyType;
|
|
||||||
|
|
||||||
// Update output slot
|
|
||||||
if (this.outputs.length > 0) {
|
|
||||||
this.outputs[0].name = keyWidget.value;
|
|
||||||
this.outputs[0].label = keyWidget.value;
|
|
||||||
this.outputs[0].type = keyType;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.title = keyWidget.value ? `Key: ${keyWidget.value}` : "Project Key";
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Show live value on output slot after execution (INT/FLOAT/BOOL only) ---
|
|
||||||
nodeType.prototype.onExecuted = function (output) {
|
|
||||||
if (!this.outputs.length) return;
|
|
||||||
const val = output?.value?.[0];
|
|
||||||
if (val === undefined) return;
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
const name = keyWidget?.value || this.outputs[0].name;
|
|
||||||
this.outputs[0].label = `${val} ${name}`;
|
|
||||||
const slotType = this.outputs[0].type;
|
|
||||||
const TYPE_COLORS = { "INT": "#3d7eb5", "FLOAT": "#68a468", "BOOLEAN": null };
|
|
||||||
let color;
|
|
||||||
if (slotType === "BOOLEAN") {
|
|
||||||
color = (val === "true") ? "#4caf50" : "#888888";
|
|
||||||
} else {
|
|
||||||
color = TYPE_COLORS[slotType]
|
|
||||||
?? LGraphCanvas?.link_type_colors?.[slotType]
|
|
||||||
?? app.canvas?.default_connection_color_byType?.[slotType];
|
|
||||||
}
|
|
||||||
if (color) {
|
|
||||||
this.outputs[0].color_on = color;
|
|
||||||
this.outputs[0].color_off = color;
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Highlight all ProjectKey nodes sharing the same key_name on select ---
|
|
||||||
nodeType.prototype.onSelected = function () {
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
const myKey = keyWidget?.value;
|
|
||||||
if (!myKey || !this.graph) return;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node === this || node.type !== "ProjectKey") continue;
|
|
||||||
const kw = node.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (kw?.value !== myKey) continue;
|
|
||||||
node._savedColor = node.color;
|
|
||||||
node._savedBgColor = node.bgcolor;
|
|
||||||
node.color = "#c8a000";
|
|
||||||
node.bgcolor = "#4a3800";
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype.onDeselected = function () {
|
|
||||||
if (!this.graph) return;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type !== "ProjectKey" || !("_savedColor" in node)) continue;
|
|
||||||
node.color = node._savedColor;
|
|
||||||
node.bgcolor = node._savedBgColor;
|
|
||||||
delete node._savedColor;
|
|
||||||
delete node._savedBgColor;
|
|
||||||
}
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Sync config on click (lazy, no key refresh to avoid race) ---
|
|
||||||
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
|
||||||
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
|
||||||
origOnMouseDown?.apply(this, arguments);
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget) {
|
|
||||||
srcWidget.options.values = this._getSourceLabels();
|
|
||||||
}
|
|
||||||
// Sync config values from source (synchronous, safe)
|
|
||||||
this._syncFromSource();
|
|
||||||
};
|
|
||||||
|
|
||||||
// --- Restore state on workflow load ---
|
|
||||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
|
||||||
nodeType.prototype.onConfigure = function (info) {
|
|
||||||
origOnConfigure?.apply(this, arguments);
|
|
||||||
this._configured = true;
|
|
||||||
|
|
||||||
// Hide config widgets
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number", "key_type"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure source_label is a proper combo (may still be STRING from serialization)
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget && srcWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
} else if (srcWidget) {
|
|
||||||
srcWidget.options.values = this._getSourceLabels();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure key_name is a proper combo
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget && keyWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node._applyKeySelection();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-find widgets after possible replacement
|
|
||||||
const finalKeyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
|
|
||||||
// Update title from saved key
|
|
||||||
if (finalKeyWidget?.value) {
|
|
||||||
this.title = `Key: ${finalKeyWidget.value}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore output slot name from saved key_name
|
|
||||||
if (finalKeyWidget?.value && this.outputs.length > 0) {
|
|
||||||
this.outputs[0].name = finalKeyWidget.value;
|
|
||||||
this.outputs[0].label = finalKeyWidget.value;
|
|
||||||
const ktWidget = this.widgets?.find(w => w.name === "key_type");
|
|
||||||
if (ktWidget?.value) this.outputs[0].type = ktWidget.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
|
|
||||||
// Deferred: sync from source and refresh key dropdown once graph is ready
|
|
||||||
const node = this;
|
|
||||||
queueMicrotask(() => {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.project.resolution",
|
|
||||||
|
|
||||||
async beforeQueuePrompt() {
|
|
||||||
if (!app.graph?._nodes) return;
|
|
||||||
for (const node of app.graph._nodes) {
|
|
||||||
if (node.type === "ProjectResolution" && node._syncFromSource) {
|
|
||||||
node._syncFromSource();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "ProjectResolution") return;
|
|
||||||
|
|
||||||
function hideWidget(widget) {
|
|
||||||
if (widget.origType === undefined) widget.origType = widget.type;
|
|
||||||
widget.type = "hidden";
|
|
||||||
widget.hidden = true;
|
|
||||||
widget.computeSize = () => [0, -4];
|
|
||||||
}
|
|
||||||
|
|
||||||
function replaceWithCombo(node, name, values, callback) {
|
|
||||||
const idx = node.widgets?.findIndex(w => w.name === name);
|
|
||||||
if (idx === -1 || idx === undefined) return null;
|
|
||||||
const oldWidget = node.widgets[idx];
|
|
||||||
const savedValue = oldWidget.value || "";
|
|
||||||
const comboValues = values.length > 0 ? values : [""];
|
|
||||||
if (savedValue && !comboValues.includes(savedValue)) {
|
|
||||||
comboValues.unshift(savedValue);
|
|
||||||
}
|
|
||||||
const defaultValue = savedValue || comboValues[0];
|
|
||||||
node.widgets.splice(idx, 1);
|
|
||||||
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
|
||||||
if (node.widgets.length > 1) {
|
|
||||||
node.widgets.splice(node.widgets.length - 1, 1);
|
|
||||||
node.widgets.splice(idx, 0, combo);
|
|
||||||
}
|
|
||||||
return combo;
|
|
||||||
}
|
|
||||||
|
|
||||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
origOnNodeCreated?.apply(this, arguments);
|
|
||||||
this._configured = false;
|
|
||||||
|
|
||||||
// Hide synced config widgets — index stays visible, user wires it from loop node
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
const node = this;
|
|
||||||
const sourceLabels = this._getSourceLabels?.() || [];
|
|
||||||
const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
if (srcCombo) srcCombo.value = sourceLabels[0] || "";
|
|
||||||
|
|
||||||
const keyCombo = replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
});
|
|
||||||
if (keyCombo && !keyCombo.value) keyCombo.value = "resolutions";
|
|
||||||
|
|
||||||
queueMicrotask(() => {
|
|
||||||
if (!this._configured) {
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._getSourceLabels = function () {
|
|
||||||
const seen = new Set();
|
|
||||||
const labels = [];
|
|
||||||
if (!this.graph) return labels;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value && !seen.has(lw.value)) {
|
|
||||||
seen.add(lw.value);
|
|
||||||
labels.push(lw.value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return labels;
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._findSource = function (label) {
|
|
||||||
if (!this.graph || !label) return null;
|
|
||||||
for (const node of this.graph._nodes) {
|
|
||||||
if (node.type === "ProjectSource") {
|
|
||||||
const lw = node.widgets?.find(w => w.name === "label");
|
|
||||||
if (lw?.value === label) return node;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._syncFromSource = function () {
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
const source = this._findSource(srcWidget?.value);
|
|
||||||
if (!source) return;
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const dst = this.widgets?.find(w => w.name === name);
|
|
||||||
const src = source.widgets?.find(w => w.name === name);
|
|
||||||
if (dst && src) dst.value = src.value;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
nodeType.prototype._refreshKeys = async function () {
|
|
||||||
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
|
||||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
const fileW = this.widgets?.find(w => w.name === "file_name");
|
|
||||||
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
|
||||||
if (!urlW?.value || !projW?.value || !fileW?.value) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const resp = await api.fetchApi(
|
|
||||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}`
|
|
||||||
);
|
|
||||||
if (!resp.ok) return;
|
|
||||||
const data = await resp.json();
|
|
||||||
if (data.error || !Array.isArray(data.keys)) return;
|
|
||||||
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget) {
|
|
||||||
keyWidget.options.values = data.keys.length > 0 ? data.keys : [""];
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error("[ProjectResolution] Failed to refresh keys:", e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
|
||||||
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
|
||||||
origOnMouseDown?.apply(this, arguments);
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget) srcWidget.options.values = this._getSourceLabels();
|
|
||||||
this._syncFromSource();
|
|
||||||
};
|
|
||||||
|
|
||||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
|
||||||
nodeType.prototype.onConfigure = function (info) {
|
|
||||||
origOnConfigure?.apply(this, arguments);
|
|
||||||
this._configured = true;
|
|
||||||
|
|
||||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
|
||||||
const w = this.widgets?.find(w => w.name === name);
|
|
||||||
if (w) hideWidget(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
|
||||||
if (srcWidget && srcWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
} else if (srcWidget) {
|
|
||||||
srcWidget.options.values = this._getSourceLabels();
|
|
||||||
}
|
|
||||||
|
|
||||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (keyWidget && keyWidget.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "key_name", [], function (value) {
|
|
||||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const finalKeyWidget = this.widgets?.find(w => w.name === "key_name");
|
|
||||||
if (finalKeyWidget?.value) {
|
|
||||||
this.title = `Resolution: ${finalKeyWidget.value}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
|
||||||
|
|
||||||
const node = this;
|
|
||||||
queueMicrotask(() => {
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
});
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
@@ -1,200 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "json.manager.project.source",
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
|
||||||
if (nodeData.name !== "ProjectSource") return;
|
|
||||||
|
|
||||||
// Helper: replace a STRING widget with a proper combo widget
|
|
||||||
function replaceWithCombo(node, name, values, callback) {
|
|
||||||
const idx = node.widgets?.findIndex(w => w.name === name);
|
|
||||||
if (idx === -1 || idx === undefined) return null;
|
|
||||||
const oldWidget = node.widgets[idx];
|
|
||||||
const savedValue = oldWidget.value || "";
|
|
||||||
const comboValues = values.length > 0 ? values : [""];
|
|
||||||
// Always preserve saved value (may not be in list yet)
|
|
||||||
if (savedValue && !comboValues.includes(savedValue)) {
|
|
||||||
comboValues.unshift(savedValue);
|
|
||||||
}
|
|
||||||
const defaultValue = savedValue || comboValues[0];
|
|
||||||
node.widgets.splice(idx, 1);
|
|
||||||
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
|
||||||
if (node.widgets.length > 1) {
|
|
||||||
node.widgets.splice(node.widgets.length - 1, 1);
|
|
||||||
node.widgets.splice(idx, 0, combo);
|
|
||||||
}
|
|
||||||
return combo;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch active project from Manager and update project_name + title
|
|
||||||
async function refreshActiveProject(node) {
|
|
||||||
const urlW = node.widgets?.find(w => w.name === "manager_url");
|
|
||||||
if (!urlW?.value) return;
|
|
||||||
try {
|
|
||||||
const resp = await fetch(`${urlW.value}/api/active-project`);
|
|
||||||
if (!resp.ok) return;
|
|
||||||
const data = await resp.json();
|
|
||||||
const project = data.project || "";
|
|
||||||
const projW = node.widgets?.find(w => w.name === "project_name");
|
|
||||||
if (projW && projW.value !== project) {
|
|
||||||
projW.value = project;
|
|
||||||
await refreshFiles(node);
|
|
||||||
}
|
|
||||||
_updateTitle(node);
|
|
||||||
} catch (e) {
|
|
||||||
console.warn("[ProjectSource] Failed to fetch active project:", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function _updateTitle(node) {
|
|
||||||
const labelW = node.widgets?.find(w => w.name === "label");
|
|
||||||
const projW = node.widgets?.find(w => w.name === "project_name");
|
|
||||||
const label = labelW?.value || "";
|
|
||||||
const project = projW?.value || "?";
|
|
||||||
node.title = label ? `Source: ${label} [${project}]` : `Project Source [${project}]`;
|
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch file list from API and update file_name combo
|
|
||||||
async function refreshFiles(node) {
|
|
||||||
const urlW = node.widgets?.find(w => w.name === "manager_url");
|
|
||||||
const projW = node.widgets?.find(w => w.name === "project_name");
|
|
||||||
if (!urlW?.value || !projW?.value) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const resp = await api.fetchApi(
|
|
||||||
`/json_manager/list_project_files?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}`
|
|
||||||
);
|
|
||||||
if (!resp.ok) return;
|
|
||||||
const data = await resp.json();
|
|
||||||
const fileList = (data.files || []).map(f => f.name || f);
|
|
||||||
console.log(`[ProjectSource] refreshFiles: got ${fileList.length} files:`, fileList);
|
|
||||||
|
|
||||||
const fileW = node.widgets?.find(w => w.name === "file_name");
|
|
||||||
if (fileW) {
|
|
||||||
const currentValue = fileW.value;
|
|
||||||
fileW.options.values = fileList.length > 0 ? fileList : [""];
|
|
||||||
// Keep current selection if still valid
|
|
||||||
if (currentValue && fileList.includes(currentValue)) {
|
|
||||||
fileW.value = currentValue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error("[ProjectSource] Failed to refresh files:", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify all ProjectKey nodes referencing this source to re-sync
|
|
||||||
function notifyRelays(sourceNode) {
|
|
||||||
if (!sourceNode.graph?._nodes) return;
|
|
||||||
const labelW = sourceNode.widgets?.find(w => w.name === "label");
|
|
||||||
if (!labelW?.value) return;
|
|
||||||
console.log(`[ProjectSource] notifyRelays: label="${labelW.value}", scanning ${sourceNode.graph._nodes.length} nodes`);
|
|
||||||
let matched = 0;
|
|
||||||
for (const node of sourceNode.graph._nodes) {
|
|
||||||
if (node.type === "ProjectKey" && node._syncFromSource && node._refreshKeys) {
|
|
||||||
const srcW = node.widgets?.find(w => w.name === "source_label");
|
|
||||||
console.log(`[ProjectSource] ProjectKey id=${node.id} source_label="${srcW?.value}"`);
|
|
||||||
if (srcW?.value === labelW.value) {
|
|
||||||
matched++;
|
|
||||||
node._syncFromSource();
|
|
||||||
node._refreshKeys();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
console.log(`[ProjectSource] notifyRelays: matched ${matched} relays`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
origOnNodeCreated?.apply(this, arguments);
|
|
||||||
|
|
||||||
const node = this;
|
|
||||||
|
|
||||||
// Hide project_name — it is auto-filled from the Manager's active project
|
|
||||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
if (projW) {
|
|
||||||
if (projW.origType === undefined) projW.origType = projW.type;
|
|
||||||
projW.type = "hidden";
|
|
||||||
projW.hidden = true;
|
|
||||||
projW.computeSize = () => [0, -4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace file_name STRING with a combo
|
|
||||||
replaceWithCombo(this, "file_name", [], function (value) {
|
|
||||||
notifyRelays(node);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Hook manager_url to refresh active project + files + notify relays
|
|
||||||
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
|
||||||
if (urlW) {
|
|
||||||
const origCb = urlW.callback;
|
|
||||||
urlW.callback = function (...args) {
|
|
||||||
origCb?.apply(this, args);
|
|
||||||
refreshActiveProject(node).then(() => notifyRelays(node));
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hook sequence_number to notify relays
|
|
||||||
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
|
||||||
if (seqW) {
|
|
||||||
const origCb = seqW.callback;
|
|
||||||
seqW.callback = function (...args) {
|
|
||||||
origCb?.apply(this, args);
|
|
||||||
notifyRelays(node);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update title when label changes
|
|
||||||
const labelWidget = this.widgets?.find(w => w.name === "label");
|
|
||||||
if (labelWidget) {
|
|
||||||
const origCallback = labelWidget.callback;
|
|
||||||
labelWidget.callback = function (...args) {
|
|
||||||
origCallback?.apply(this, args);
|
|
||||||
_updateTitle(node);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-fetch active project on creation
|
|
||||||
queueMicrotask(() => refreshActiveProject(node));
|
|
||||||
};
|
|
||||||
|
|
||||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
|
||||||
nodeType.prototype.onConfigure = function (info) {
|
|
||||||
origOnConfigure?.apply(this, arguments);
|
|
||||||
|
|
||||||
// Hide project_name (may have been serialized as visible)
|
|
||||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
|
||||||
if (projW) {
|
|
||||||
if (projW.origType === undefined) projW.origType = projW.type;
|
|
||||||
projW.type = "hidden";
|
|
||||||
projW.hidden = true;
|
|
||||||
projW.computeSize = () => [0, -4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure file_name is a combo (may be STRING from serialization)
|
|
||||||
const fileW = this.widgets?.find(w => w.name === "file_name");
|
|
||||||
if (fileW && fileW.type !== "combo") {
|
|
||||||
const node = this;
|
|
||||||
replaceWithCombo(this, "file_name", [], function (value) {
|
|
||||||
notifyRelays(node);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
_updateTitle(this);
|
|
||||||
|
|
||||||
// Deferred: fetch active project (and files) once graph is ready
|
|
||||||
const node = this;
|
|
||||||
queueMicrotask(() => refreshActiveProject(node));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Re-check active project on click (picks up changes made in the Manager)
|
|
||||||
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
|
||||||
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
|
||||||
origOnMouseDown?.apply(this, arguments);
|
|
||||||
refreshActiveProject(this);
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
Reference in New Issue
Block a user