diff --git a/__init__.py b/__init__.py index 3f1c494..8f9a991 100644 --- a/__init__.py +++ b/__init__.py @@ -44,6 +44,13 @@ def _record_prompt(class_types, prompt): tracker.record_usage(class_types, mapper) except Exception: logger.warning("nodes-stats: error recording node usage", exc_info=True) + try: + packages = {mapper.get_package(ct) for ct in class_types} + packages.discard("__builtin__") + packages.discard("__unknown__") + tracker.reset_trials_for(packages) + except Exception: + logger.warning("nodes-stats: error resetting trials", exc_info=True) try: models = model_mapper.extract_models_from_prompt(prompt) if models: @@ -52,6 +59,13 @@ def _record_prompt(class_types, prompt): logger.warning("nodes-stats: error recording model usage", exc_info=True) +# Age temporary-enable trials once per process start (one "boot"). +try: + tracker.tick_boot_days() +except Exception: + logger.warning("nodes-stats: error ticking trial boot days", exc_info=True) + + PromptServer.instance.add_on_prompt_handler(on_prompt_handler) @@ -99,3 +113,40 @@ async def reset_stats(request): except Exception: logger.error("nodes-stats: error resetting stats", exc_info=True) return web.json_response({"error": "internal error"}, status=500) + + +@routes.get("/nodes-stats/trials") +async def get_trials(request): + try: + return web.json_response(tracker.get_trials()) + except Exception: + logger.error("nodes-stats: error getting trials", exc_info=True) + return web.json_response({"error": "internal error"}, status=500) + + +@routes.post("/nodes-stats/trials/start") +async def start_trial(request): + try: + data = await request.json() + package = data.get("package") + if not package: + return web.json_response({"error": "package required"}, status=400) + tracker.start_trial(package) + return web.json_response({"status": "ok"}) + except Exception: + logger.error("nodes-stats: error starting trial", exc_info=True) + return web.json_response({"error": "internal error"}, status=500) + + +@routes.post("/nodes-stats/trials/stop") +async def stop_trial(request): + try: + data = await request.json() + package = data.get("package") + if not package: + return web.json_response({"error": "package required"}, status=400) + tracker.stop_trial(package) + return web.json_response({"status": "ok"}) + except Exception: + logger.error("nodes-stats: error stopping trial", exc_info=True) + return web.json_response({"error": "internal error"}, status=500)