diff --git a/mapper.py b/mapper.py index 66a7657..81e8d22 100644 --- a/mapper.py +++ b/mapper.py @@ -66,3 +66,116 @@ class NodePackageMapper: def invalidate(self): """Force rebuild on next access (e.g. after node reload).""" self._map = None + + +# Folder types that are not model files and should not be tracked +EXCLUDED_FOLDER_TYPES = { + "loras", + "configs", + "custom_nodes", + "temp", + "output", + "input", + "annotators", + "assets", +} + + +class ModelMapper: + """Tracks which folder_paths model types exist and resolves filenames to types.""" + + def __init__(self): + self._folder_files = None # {folder_type: frozenset(filenames)} + self._reverse = None # {filename: folder_type} + + def _build(self): + try: + import folder_paths + + self._folder_files = {} + for folder_type in folder_paths.folder_names_and_paths: + if folder_type in EXCLUDED_FOLDER_TYPES: + continue + try: + files = folder_paths.get_filename_list(folder_type) + except Exception: + files = [] + if files: + self._folder_files[folder_type] = frozenset(files) + + # Reverse map: filename -> folder_type (last write wins on collision) + self._reverse = {} + for folder_type, files in self._folder_files.items(): + for f in files: + self._reverse[f] = folder_type + + except Exception: + logger.warning("ModelMapper: failed to build model map", exc_info=True) + self._folder_files = {} + self._reverse = {} + + def _ensure(self): + if self._folder_files is None: + self._build() + + def get_model_type(self, filename): + """Return the folder type for a filename, or None if not tracked.""" + self._ensure() + return self._reverse.get(filename) + + def get_all_models(self): + """Return {folder_type: [filename, ...]} for all tracked types.""" + self._ensure() + return {k: sorted(v) for k, v in self._folder_files.items()} + + def extract_models_from_prompt(self, prompt): + """Scan a prompt dict and return (model_name, model_type) pairs. + + For each node, inspects INPUT_TYPES() to find list-type (folder dropdown) + inputs, then resolves the selected value against the folder_paths reverse map. + """ + self._ensure() + try: + import nodes as comfy_nodes + except ImportError: + return [] + + seen = set() + results = [] + + for node_data in prompt.values(): + class_type = node_data.get("class_type") + node_inputs = node_data.get("inputs", {}) + if not class_type or not node_inputs: + continue + + node_cls = comfy_nodes.NODE_CLASS_MAPPINGS.get(class_type) + if node_cls is None: + continue + + try: + input_types = node_cls.INPUT_TYPES() + except Exception: + continue + + for category in ("required", "optional"): + for input_name, input_def in input_types.get(category, {}).items(): + if not isinstance(input_def, (list, tuple)) or not input_def: + continue + # ComfyUI folder dropdowns have a list as their type + if not isinstance(input_def[0], list): + continue + value = node_inputs.get(input_name) + if not isinstance(value, str) or value in seen: + continue + model_type = self.get_model_type(value) + if model_type: + seen.add(value) + results.append((value, model_type)) + + return results + + def invalidate(self): + """Force rebuild on next access.""" + self._folder_files = None + self._reverse = None diff --git a/tests/test_model_mapper.py b/tests/test_model_mapper.py new file mode 100644 index 0000000..0df081a --- /dev/null +++ b/tests/test_model_mapper.py @@ -0,0 +1,96 @@ +import pytest +from unittest.mock import patch, MagicMock +from mapper import ModelMapper + + +FAKE_FOLDER_NAMES = { + "checkpoints": ([], {}), + "vae": ([], {}), + "loras": ([], {}), + "configs": ([], {}), +} + +FAKE_FILES = { + "checkpoints": ["dream.safetensors", "v15.ckpt"], + "vae": ["vae.safetensors"], + "loras": ["style.safetensors"], +} + + +def _make_mapper(): + # conftest.py already put a MagicMock in sys.modules["folder_paths"], + # so we can configure it directly here. + import folder_paths as fp + fp.folder_names_and_paths = FAKE_FOLDER_NAMES + fp.get_filename_list.side_effect = lambda t: FAKE_FILES.get(t, []) + m = ModelMapper() + m._build() + return m + + +def test_get_model_type_known(): + m = _make_mapper() + assert m.get_model_type("dream.safetensors") == "checkpoints" + assert m.get_model_type("vae.safetensors") == "vae" + + +def test_loras_excluded(): + m = _make_mapper() + assert m.get_model_type("style.safetensors") is None + + +def test_get_all_models(): + m = _make_mapper() + all_models = m.get_all_models() + assert "checkpoints" in all_models + assert "vae" in all_models + assert "loras" not in all_models + assert "dream.safetensors" in all_models["checkpoints"] + + +def test_unknown_filename_returns_none(): + m = _make_mapper() + assert m.get_model_type("nonexistent.ckpt") is None + + +def test_extract_models_from_prompt(): + m = _make_mapper() + + fake_node_cls = MagicMock() + fake_node_cls.INPUT_TYPES.return_value = { + "required": { + "ckpt_name": (["dream.safetensors", "v15.ckpt"],), + "steps": ("INT", {"default": 20}), + } + } + + fake_prompt = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "dream.safetensors", "steps": 20}, + } + } + + import nodes as comfy_nodes + comfy_nodes.NODE_CLASS_MAPPINGS = {"CheckpointLoaderSimple": fake_node_cls} + results = m.extract_models_from_prompt(fake_prompt) + + assert ("dream.safetensors", "checkpoints") in results + + +def test_extract_models_skips_non_list_inputs(): + m = _make_mapper() + + fake_node_cls = MagicMock() + fake_node_cls.INPUT_TYPES.return_value = { + "required": { + "text": ("STRING", {}), + } + } + fake_prompt = {"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "hello"}}} + + import nodes as comfy_nodes + comfy_nodes.NODE_CLASS_MAPPINGS = {"CLIPTextEncode": fake_node_cls} + results = m.extract_models_from_prompt(fake_prompt) + + assert results == []