diff --git a/project_loader.py b/project_loader.py index c0ce0de..eaa6a88 100644 --- a/project_loader.py +++ b/project_loader.py @@ -231,13 +231,61 @@ class ProjectSource: return () +class ProjectKey: + """Single-output relay — fetches one key from a ProjectSource.""" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source_label": ("STRING", {"default": "", "multiline": False}), + "key_name": ("STRING", {"default": "", "multiline": False}), + "key_type": ("STRING", {"default": "STRING", "multiline": False}), + }, + "optional": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + } + + RETURN_TYPES = (any_type,) + RETURN_NAMES = ("value",) + FUNCTION = "fetch_key" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + def fetch_key(self, source_label, key_name, key_type, + manager_url="http://localhost:8080", project_name="", + file_name="", sequence_number=1): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + msg = data.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch data: {msg}") + + val = data.get(key_name, "") + + if key_type == "INT": + return (to_int(val),) + elif key_type == "FLOAT": + return (to_float(val),) + elif isinstance(val, bool): + return (str(val).lower(),) + elif isinstance(val, (int, float)): + return (val,) + else: + return (str(val),) + + # --- Mappings --- PROJECT_NODE_CLASS_MAPPINGS = { "ProjectLoaderDynamic": ProjectLoaderDynamic, "ProjectSource": ProjectSource, + "ProjectKey": ProjectKey, } PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { "ProjectLoaderDynamic": "Project Loader (Dynamic)", "ProjectSource": "Project Source", + "ProjectKey": "Project Key", } diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 1a3a98c..1caae73 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -235,9 +235,108 @@ class TestProjectSource: assert ProjectSource.CATEGORY == "utils/json/project" +class TestProjectKey: + def test_input_types(self): + from project_loader import ProjectKey + inputs = ProjectKey.INPUT_TYPES() + assert "source_label" in inputs["required"] + assert "key_name" in inputs["required"] + assert "key_type" in inputs["required"] + + def test_single_output(self): + from project_loader import ProjectKey + assert len(ProjectKey.RETURN_TYPES) == 1 + assert len(ProjectKey.RETURN_NAMES) == 1 + + def test_fetch_key_string(self): + from project_loader import ProjectKey + node = ProjectKey() + data = {"prompt": "hello", "seed": 42} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_key( + source_label="my_source", + key_name="prompt", + key_type="STRING", + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=1, + ) + assert result == ("hello",) + + def test_fetch_key_int_coercion(self): + from project_loader import ProjectKey + node = ProjectKey() + data = {"seed": "42"} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_key( + source_label="my_source", + key_name="seed", + key_type="INT", + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=1, + ) + assert result == (42,) + + def test_fetch_key_float_coercion(self): + from project_loader import ProjectKey + node = ProjectKey() + data = {"cfg": "1.5"} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_key( + source_label="my_source", + key_name="cfg", + key_type="FLOAT", + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=1, + ) + assert result == (1.5,) + + def test_fetch_key_missing_key(self): + from project_loader import ProjectKey + node = ProjectKey() + with patch("project_loader._fetch_data", return_value={}): + result = node.fetch_key( + source_label="my_source", + key_name="nonexistent", + key_type="STRING", + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=1, + ) + assert result == ("",) + + def test_fetch_key_network_error(self): + from project_loader import ProjectKey + node = ProjectKey() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_data", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch"): + node.fetch_key( + source_label="my_source", + key_name="prompt", + key_type="STRING", + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=1, + ) + + def test_category(self): + from project_loader import ProjectKey + assert ProjectKey.CATEGORY == "utils/json/project" + + class TestNodeMappings: def test_mappings_exist(self): from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS - assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1 - assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1 + assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 3 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 3