Add SQLite project database + ComfyUI connector nodes
- db.py: ProjectDB class with SQLite schema (projects, data_files, sequences, history_trees), WAL mode, CRUD, import, and query helpers - api_routes.py: REST API endpoints on NiceGUI/FastAPI for ComfyUI to query project data over the network - project_loader.py: ComfyUI nodes (ProjectLoaderDynamic, Standard, VACE, LoRA) that fetch data from NiceGUI REST API via HTTP - web/project_dynamic.js: Frontend JS for dynamic project loader node - tab_projects_ng.py: Projects management tab in NiceGUI UI - state.py: Added db, current_project, db_enabled fields - main.py: DB init, API route registration, projects tab - utils.py: sync_to_db() dual-write helper - tab_batch_ng.py, tab_raw_ng.py, tab_timeline_ng.py: dual-write sync calls after save_json when project DB is enabled - __init__.py: Merged project node class mappings - tests/test_db.py: 30 tests for database layer - tests/test_project_loader.py: 17 tests for ComfyUI connector nodes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
201
tests/test_project_loader.py
Normal file
201
tests/test_project_loader.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from project_loader import (
|
||||
ProjectLoaderDynamic,
|
||||
ProjectLoaderStandard,
|
||||
ProjectLoaderVACE,
|
||||
ProjectLoaderLoRA,
|
||||
_fetch_json,
|
||||
_fetch_data,
|
||||
_fetch_keys,
|
||||
MAX_DYNAMIC_OUTPUTS,
|
||||
)
|
||||
|
||||
|
||||
def _mock_urlopen(data: dict):
|
||||
"""Create a mock context manager for urllib.request.urlopen."""
|
||||
response = MagicMock()
|
||||
response.read.return_value = json.dumps(data).encode()
|
||||
response.__enter__ = lambda s: s
|
||||
response.__exit__ = MagicMock(return_value=False)
|
||||
return response
|
||||
|
||||
|
||||
class TestFetchHelpers:
|
||||
def test_fetch_json_success(self):
|
||||
data = {"key": "value"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result == data
|
||||
|
||||
def test_fetch_json_failure(self):
|
||||
import urllib.error
|
||||
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result == {}
|
||||
|
||||
def test_fetch_data_builds_url(self):
|
||||
data = {"prompt": "hello"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1)
|
||||
assert result == data
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url
|
||||
|
||||
def test_fetch_keys_builds_url(self):
|
||||
data = {"keys": ["prompt"], "types": ["STRING"]}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1)
|
||||
assert result == data
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url
|
||||
|
||||
def test_fetch_data_strips_trailing_slash(self):
|
||||
data = {"prompt": "hello"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
_fetch_data("http://localhost:8080/", "proj1", "file1", 1)
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "//api" not in called_url
|
||||
|
||||
|
||||
class TestProjectLoaderDynamic:
|
||||
def test_load_dynamic_with_keys(self):
|
||||
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="prompt,seed,cfg"
|
||||
)
|
||||
assert result[0] == "hello"
|
||||
assert result[1] == 42
|
||||
assert result[2] == 1.5
|
||||
assert len(result) == MAX_DYNAMIC_OUTPUTS
|
||||
|
||||
def test_load_dynamic_empty_keys(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=""
|
||||
)
|
||||
assert all(v == "" for v in result)
|
||||
|
||||
def test_load_dynamic_missing_key(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="nonexistent"
|
||||
)
|
||||
assert result[0] == ""
|
||||
|
||||
def test_load_dynamic_bool_becomes_string(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="flag"
|
||||
)
|
||||
assert result[0] == "true"
|
||||
|
||||
def test_input_types_has_manager_url(self):
|
||||
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
||||
assert "manager_url" in inputs["required"]
|
||||
assert "project_name" in inputs["required"]
|
||||
assert "file_name" in inputs["required"]
|
||||
assert "sequence_number" in inputs["required"]
|
||||
|
||||
def test_category(self):
|
||||
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||
|
||||
|
||||
class TestProjectLoaderStandard:
|
||||
def test_load_standard(self):
|
||||
data = {
|
||||
"general_prompt": "hello",
|
||||
"general_negative": "bad",
|
||||
"current_prompt": "specific",
|
||||
"negative": "neg",
|
||||
"camera": "pan",
|
||||
"flf": 0.5,
|
||||
"seed": 42,
|
||||
"video file path": "/v.mp4",
|
||||
"reference image path": "/r.png",
|
||||
"flf image path": "/f.png",
|
||||
}
|
||||
node = ProjectLoaderStandard()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png")
|
||||
|
||||
def test_load_standard_defaults(self):
|
||||
node = ProjectLoaderStandard()
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[0] == "" # general_prompt
|
||||
assert result[5] == 0.0 # flf
|
||||
assert result[6] == 0 # seed
|
||||
|
||||
|
||||
class TestProjectLoaderVACE:
|
||||
def test_load_vace(self):
|
||||
data = {
|
||||
"general_prompt": "hello",
|
||||
"general_negative": "bad",
|
||||
"current_prompt": "specific",
|
||||
"negative": "neg",
|
||||
"camera": "pan",
|
||||
"flf": 0.5,
|
||||
"seed": 42,
|
||||
"frame_to_skip": 81,
|
||||
"input_a_frames": 16,
|
||||
"input_b_frames": 16,
|
||||
"reference path": "/ref",
|
||||
"reference switch": 1,
|
||||
"vace schedule": 2,
|
||||
"video file path": "/v.mp4",
|
||||
"reference image path": "/r.png",
|
||||
}
|
||||
node = ProjectLoaderVACE()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_vace("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[7] == 81 # frame_to_skip
|
||||
assert result[12] == 2 # vace_schedule
|
||||
|
||||
|
||||
class TestProjectLoaderLoRA:
|
||||
def test_load_loras(self):
|
||||
data = {
|
||||
"lora 1 high": "<lora:model1:1.0>",
|
||||
"lora 1 low": "<lora:model1:0.5>",
|
||||
"lora 2 high": "",
|
||||
"lora 2 low": "",
|
||||
"lora 3 high": "",
|
||||
"lora 3 low": "",
|
||||
}
|
||||
node = ProjectLoaderLoRA()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[0] == "<lora:model1:1.0>"
|
||||
assert result[1] == "<lora:model1:0.5>"
|
||||
|
||||
def test_load_loras_empty(self):
|
||||
node = ProjectLoaderLoRA()
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert all(v == "" for v in result)
|
||||
|
||||
|
||||
class TestNodeMappings:
|
||||
def test_mappings_exist(self):
|
||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4
|
||||
Reference in New Issue
Block a user