feat: extend prompt handler and add /nodes-stats/models endpoint

This commit is contained in:
2026-04-09 18:04:16 +02:00
parent 38e95c150a
commit ed39c5918a
+26 -4
View File
@@ -4,7 +4,7 @@ import threading
from aiohttp import web from aiohttp import web
from server import PromptServer from server import PromptServer
from .mapper import NodePackageMapper from .mapper import NodePackageMapper, ModelMapper
from .tracker import UsageTracker from .tracker import UsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,10 +14,11 @@ WEB_DIRECTORY = "./js"
mapper = NodePackageMapper() mapper = NodePackageMapper()
tracker = UsageTracker() tracker = UsageTracker()
model_mapper = ModelMapper()
def on_prompt_handler(json_data): def on_prompt_handler(json_data):
"""Called on every prompt submission. Extracts class_types and records usage.""" """Called on every prompt submission. Extracts class_types and queues recording."""
try: try:
prompt = json_data.get("prompt", {}) prompt = json_data.get("prompt", {})
class_types = set() class_types = set()
@@ -26,9 +27,11 @@ def on_prompt_handler(json_data):
if ct: if ct:
class_types.add(ct) class_types.add(ct)
if class_types: if class_types:
# Pass the full prompt to the thread — model extraction (which calls
# INPUT_TYPES() on every node) happens off the main request thread.
threading.Thread( threading.Thread(
target=tracker.record_usage, target=_record_prompt,
args=(class_types, mapper), args=(class_types, prompt),
daemon=True, daemon=True,
).start() ).start()
except Exception: except Exception:
@@ -36,6 +39,13 @@ def on_prompt_handler(json_data):
return json_data return json_data
def _record_prompt(class_types, prompt):
tracker.record_usage(class_types, mapper)
models = model_mapper.extract_models_from_prompt(prompt)
if models:
tracker.record_model_usage(models)
PromptServer.instance.add_on_prompt_handler(on_prompt_handler) PromptServer.instance.add_on_prompt_handler(on_prompt_handler)
@@ -62,11 +72,23 @@ async def get_node_stats(request):
return web.json_response({"error": "internal error"}, status=500) return web.json_response({"error": "internal error"}, status=500)
@routes.get("/nodes-stats/models")
async def get_model_stats(request):
try:
installed_by_type = model_mapper.get_all_models()
stats = tracker.get_model_stats(installed_by_type)
return web.json_response(stats)
except Exception:
logger.error("nodes-stats: error getting model stats", exc_info=True)
return web.json_response({"error": "internal error"}, status=500)
@routes.post("/nodes-stats/reset") @routes.post("/nodes-stats/reset")
async def reset_stats(request): async def reset_stats(request):
try: try:
tracker.reset() tracker.reset()
mapper.invalidate() mapper.invalidate()
model_mapper.invalidate()
return web.json_response({"status": "ok"}) return web.json_response({"status": "ok"})
except Exception: except Exception:
logger.error("nodes-stats: error resetting stats", exc_info=True) logger.error("nodes-stats: error resetting stats", exc_info=True)