feat: add ModelMapper with folder_paths introspection
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -66,3 +66,116 @@ class NodePackageMapper:
|
||||
def invalidate(self):
|
||||
"""Force rebuild on next access (e.g. after node reload)."""
|
||||
self._map = None
|
||||
|
||||
|
||||
# Folder types that are not model files and should not be tracked
|
||||
EXCLUDED_FOLDER_TYPES = {
|
||||
"loras",
|
||||
"configs",
|
||||
"custom_nodes",
|
||||
"temp",
|
||||
"output",
|
||||
"input",
|
||||
"annotators",
|
||||
"assets",
|
||||
}
|
||||
|
||||
|
||||
class ModelMapper:
|
||||
"""Tracks which folder_paths model types exist and resolves filenames to types."""
|
||||
|
||||
def __init__(self):
|
||||
self._folder_files = None # {folder_type: frozenset(filenames)}
|
||||
self._reverse = None # {filename: folder_type}
|
||||
|
||||
def _build(self):
|
||||
try:
|
||||
import folder_paths
|
||||
|
||||
self._folder_files = {}
|
||||
for folder_type in folder_paths.folder_names_and_paths:
|
||||
if folder_type in EXCLUDED_FOLDER_TYPES:
|
||||
continue
|
||||
try:
|
||||
files = folder_paths.get_filename_list(folder_type)
|
||||
except Exception:
|
||||
files = []
|
||||
if files:
|
||||
self._folder_files[folder_type] = frozenset(files)
|
||||
|
||||
# Reverse map: filename -> folder_type (last write wins on collision)
|
||||
self._reverse = {}
|
||||
for folder_type, files in self._folder_files.items():
|
||||
for f in files:
|
||||
self._reverse[f] = folder_type
|
||||
|
||||
except Exception:
|
||||
logger.warning("ModelMapper: failed to build model map", exc_info=True)
|
||||
self._folder_files = {}
|
||||
self._reverse = {}
|
||||
|
||||
def _ensure(self):
|
||||
if self._folder_files is None:
|
||||
self._build()
|
||||
|
||||
def get_model_type(self, filename):
|
||||
"""Return the folder type for a filename, or None if not tracked."""
|
||||
self._ensure()
|
||||
return self._reverse.get(filename)
|
||||
|
||||
def get_all_models(self):
|
||||
"""Return {folder_type: [filename, ...]} for all tracked types."""
|
||||
self._ensure()
|
||||
return {k: sorted(v) for k, v in self._folder_files.items()}
|
||||
|
||||
def extract_models_from_prompt(self, prompt):
|
||||
"""Scan a prompt dict and return (model_name, model_type) pairs.
|
||||
|
||||
For each node, inspects INPUT_TYPES() to find list-type (folder dropdown)
|
||||
inputs, then resolves the selected value against the folder_paths reverse map.
|
||||
"""
|
||||
self._ensure()
|
||||
try:
|
||||
import nodes as comfy_nodes
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
seen = set()
|
||||
results = []
|
||||
|
||||
for node_data in prompt.values():
|
||||
class_type = node_data.get("class_type")
|
||||
node_inputs = node_data.get("inputs", {})
|
||||
if not class_type or not node_inputs:
|
||||
continue
|
||||
|
||||
node_cls = comfy_nodes.NODE_CLASS_MAPPINGS.get(class_type)
|
||||
if node_cls is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
input_types = node_cls.INPUT_TYPES()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for category in ("required", "optional"):
|
||||
for input_name, input_def in input_types.get(category, {}).items():
|
||||
if not isinstance(input_def, (list, tuple)) or not input_def:
|
||||
continue
|
||||
# ComfyUI folder dropdowns have a list as their type
|
||||
if not isinstance(input_def[0], list):
|
||||
continue
|
||||
value = node_inputs.get(input_name)
|
||||
if not isinstance(value, str) or value in seen:
|
||||
continue
|
||||
model_type = self.get_model_type(value)
|
||||
if model_type:
|
||||
seen.add(value)
|
||||
results.append((value, model_type))
|
||||
|
||||
return results
|
||||
|
||||
def invalidate(self):
|
||||
"""Force rebuild on next access."""
|
||||
self._folder_files = None
|
||||
self._reverse = None
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from mapper import ModelMapper
|
||||
|
||||
|
||||
FAKE_FOLDER_NAMES = {
|
||||
"checkpoints": ([], {}),
|
||||
"vae": ([], {}),
|
||||
"loras": ([], {}),
|
||||
"configs": ([], {}),
|
||||
}
|
||||
|
||||
FAKE_FILES = {
|
||||
"checkpoints": ["dream.safetensors", "v15.ckpt"],
|
||||
"vae": ["vae.safetensors"],
|
||||
"loras": ["style.safetensors"],
|
||||
}
|
||||
|
||||
|
||||
def _make_mapper():
|
||||
# conftest.py already put a MagicMock in sys.modules["folder_paths"],
|
||||
# so we can configure it directly here.
|
||||
import folder_paths as fp
|
||||
fp.folder_names_and_paths = FAKE_FOLDER_NAMES
|
||||
fp.get_filename_list.side_effect = lambda t: FAKE_FILES.get(t, [])
|
||||
m = ModelMapper()
|
||||
m._build()
|
||||
return m
|
||||
|
||||
|
||||
def test_get_model_type_known():
|
||||
m = _make_mapper()
|
||||
assert m.get_model_type("dream.safetensors") == "checkpoints"
|
||||
assert m.get_model_type("vae.safetensors") == "vae"
|
||||
|
||||
|
||||
def test_loras_excluded():
|
||||
m = _make_mapper()
|
||||
assert m.get_model_type("style.safetensors") is None
|
||||
|
||||
|
||||
def test_get_all_models():
|
||||
m = _make_mapper()
|
||||
all_models = m.get_all_models()
|
||||
assert "checkpoints" in all_models
|
||||
assert "vae" in all_models
|
||||
assert "loras" not in all_models
|
||||
assert "dream.safetensors" in all_models["checkpoints"]
|
||||
|
||||
|
||||
def test_unknown_filename_returns_none():
|
||||
m = _make_mapper()
|
||||
assert m.get_model_type("nonexistent.ckpt") is None
|
||||
|
||||
|
||||
def test_extract_models_from_prompt():
|
||||
m = _make_mapper()
|
||||
|
||||
fake_node_cls = MagicMock()
|
||||
fake_node_cls.INPUT_TYPES.return_value = {
|
||||
"required": {
|
||||
"ckpt_name": (["dream.safetensors", "v15.ckpt"],),
|
||||
"steps": ("INT", {"default": 20}),
|
||||
}
|
||||
}
|
||||
|
||||
fake_prompt = {
|
||||
"1": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": "dream.safetensors", "steps": 20},
|
||||
}
|
||||
}
|
||||
|
||||
import nodes as comfy_nodes
|
||||
comfy_nodes.NODE_CLASS_MAPPINGS = {"CheckpointLoaderSimple": fake_node_cls}
|
||||
results = m.extract_models_from_prompt(fake_prompt)
|
||||
|
||||
assert ("dream.safetensors", "checkpoints") in results
|
||||
|
||||
|
||||
def test_extract_models_skips_non_list_inputs():
|
||||
m = _make_mapper()
|
||||
|
||||
fake_node_cls = MagicMock()
|
||||
fake_node_cls.INPUT_TYPES.return_value = {
|
||||
"required": {
|
||||
"text": ("STRING", {}),
|
||||
}
|
||||
}
|
||||
fake_prompt = {"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "hello"}}}
|
||||
|
||||
import nodes as comfy_nodes
|
||||
comfy_nodes.NODE_CLASS_MAPPINGS = {"CLIPTextEncode": fake_node_cls}
|
||||
results = m.extract_models_from_prompt(fake_prompt)
|
||||
|
||||
assert results == []
|
||||
Reference in New Issue
Block a user