feat: extend prompt handler and add /nodes-stats/models endpoint
This commit is contained in:
+26
-4
@@ -4,7 +4,7 @@ import threading
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
|
||||||
from .mapper import NodePackageMapper
|
from .mapper import NodePackageMapper, ModelMapper
|
||||||
from .tracker import UsageTracker
|
from .tracker import UsageTracker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -14,10 +14,11 @@ WEB_DIRECTORY = "./js"
|
|||||||
|
|
||||||
mapper = NodePackageMapper()
|
mapper = NodePackageMapper()
|
||||||
tracker = UsageTracker()
|
tracker = UsageTracker()
|
||||||
|
model_mapper = ModelMapper()
|
||||||
|
|
||||||
|
|
||||||
def on_prompt_handler(json_data):
|
def on_prompt_handler(json_data):
|
||||||
"""Called on every prompt submission. Extracts class_types and records usage."""
|
"""Called on every prompt submission. Extracts class_types and queues recording."""
|
||||||
try:
|
try:
|
||||||
prompt = json_data.get("prompt", {})
|
prompt = json_data.get("prompt", {})
|
||||||
class_types = set()
|
class_types = set()
|
||||||
@@ -26,9 +27,11 @@ def on_prompt_handler(json_data):
|
|||||||
if ct:
|
if ct:
|
||||||
class_types.add(ct)
|
class_types.add(ct)
|
||||||
if class_types:
|
if class_types:
|
||||||
|
# Pass the full prompt to the thread — model extraction (which calls
|
||||||
|
# INPUT_TYPES() on every node) happens off the main request thread.
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=tracker.record_usage,
|
target=_record_prompt,
|
||||||
args=(class_types, mapper),
|
args=(class_types, prompt),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
).start()
|
).start()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -36,6 +39,13 @@ def on_prompt_handler(json_data):
|
|||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
|
|
||||||
|
def _record_prompt(class_types, prompt):
|
||||||
|
tracker.record_usage(class_types, mapper)
|
||||||
|
models = model_mapper.extract_models_from_prompt(prompt)
|
||||||
|
if models:
|
||||||
|
tracker.record_model_usage(models)
|
||||||
|
|
||||||
|
|
||||||
PromptServer.instance.add_on_prompt_handler(on_prompt_handler)
|
PromptServer.instance.add_on_prompt_handler(on_prompt_handler)
|
||||||
|
|
||||||
|
|
||||||
@@ -62,11 +72,23 @@ async def get_node_stats(request):
|
|||||||
return web.json_response({"error": "internal error"}, status=500)
|
return web.json_response({"error": "internal error"}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.get("/nodes-stats/models")
|
||||||
|
async def get_model_stats(request):
|
||||||
|
try:
|
||||||
|
installed_by_type = model_mapper.get_all_models()
|
||||||
|
stats = tracker.get_model_stats(installed_by_type)
|
||||||
|
return web.json_response(stats)
|
||||||
|
except Exception:
|
||||||
|
logger.error("nodes-stats: error getting model stats", exc_info=True)
|
||||||
|
return web.json_response({"error": "internal error"}, status=500)
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/nodes-stats/reset")
|
@routes.post("/nodes-stats/reset")
|
||||||
async def reset_stats(request):
|
async def reset_stats(request):
|
||||||
try:
|
try:
|
||||||
tracker.reset()
|
tracker.reset()
|
||||||
mapper.invalidate()
|
mapper.invalidate()
|
||||||
|
model_mapper.invalidate()
|
||||||
return web.json_response({"status": "ok"})
|
return web.json_response({"status": "ok"})
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("nodes-stats: error resetting stats", exc_info=True)
|
logger.error("nodes-stats: error resetting stats", exc_info=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user