diff --git a/tests/test_model_tracker.py b/tests/test_model_tracker.py
new file mode 100644
index 0000000..b59eecc
--- /dev/null
+++ b/tests/test_model_tracker.py
@@ -0,0 +1,101 @@
+import pytest
+import tempfile
+import os
+from tracker import UsageTracker
+
+
+@pytest.fixture
+def tracker(tmp_path):
+    return UsageTracker(db_path=str(tmp_path / "test.db"))
+
+
+def test_record_and_retrieve_model_usage(tracker):
+    tracker.record_model_usage([("dreamshaper.safetensors", "checkpoints")])
+    tracker.record_model_usage([("dreamshaper.safetensors", "checkpoints")])
+
+    raw = tracker.get_raw_model_stats()
+    assert len(raw) == 1
+    assert raw[0]["model_name"] == "dreamshaper.safetensors"
+    assert raw[0]["model_type"] == "checkpoints"
+    assert raw[0]["count"] == 2
+
+
+def test_record_multiple_models(tracker):
+    tracker.record_model_usage([
+        ("dreamshaper.safetensors", "checkpoints"),
+        ("vae.safetensors", "vae"),
+    ])
+    raw = tracker.get_raw_model_stats()
+    assert len(raw) == 2
+
+
+def test_reset_clears_model_usage(tracker):
+    tracker.record_model_usage([("model.safetensors", "checkpoints")])
+    tracker.reset()
+    assert tracker.get_raw_model_stats() == []
+
+
+def test_empty_models_returns_empty(tracker):
+    assert tracker.get_raw_model_stats() == []
+
+
+def test_get_model_stats_used(tracker):
+    tracker.record_model_usage([("model.safetensors", "checkpoints")])
+    installed = {"checkpoints": ["model.safetensors"]}
+    result = tracker.get_model_stats(installed)
+    assert len(result) == 1
+    assert result[0]["model_type"] == "checkpoints"
+    assert result[0]["models"][0]["status"] == "used"
+    assert result[0]["models"][0]["count"] == 1
+
+
+def test_get_model_stats_never_used_new(tracker):
+    installed = {"checkpoints": ["unused.safetensors"]}
+    result = tracker.get_model_stats(installed)
+    assert result[0]["models"][0]["status"] == "unused_new"
+    assert result[0]["models"][0]["count"] == 0
+
+
+def test_get_model_stats_uninstalled(tracker):
+    tracker.record_model_usage([("gone.safetensors", "checkpoints")])
+    installed = {}  # no longer on disk
+    result = tracker.get_model_stats(installed)
+    assert result[0]["models"][0]["status"] == "uninstalled"
+    assert result[0]["models"][0]["installed"] is False
+
+
+def test_get_model_stats_sorted_by_status(tracker):
+    tracker.record_model_usage([("active.safetensors", "checkpoints")])
+    installed = {"checkpoints": ["active.safetensors", "unused.safetensors"]}
+    result = tracker.get_model_stats(installed)
+    models = result[0]["models"]
+    statuses = [m["status"] for m in models]
+    # unused_new (2) comes before used (3) in STATUS_ORDER
+    assert statuses.index("unused_new") < statuses.index("used")
+
+
+def test_same_filename_is_tracked_per_model_type(tracker):
+    tracker.record_model_usage([("shared.safetensors", "checkpoints")])
+    installed = {
+        "checkpoints": ["shared.safetensors"],
+        "vae": ["shared.safetensors"],
+    }
+    result = tracker.get_model_stats(installed)
+    by_type = {group["model_type"]: group["models"][0] for group in result}
+    assert by_type["checkpoints"]["count"] == 1
+    assert by_type["checkpoints"]["status"] == "used"
+    # The VAE only shares a filename; it has never been used itself.
+    assert by_type["vae"]["count"] == 0
+    assert by_type["vae"]["status"] == "unused_new"
+
+
+def test_record_same_filename_under_two_types(tracker):
+    tracker.record_model_usage([
+        ("shared.safetensors", "checkpoints"),
+        ("shared.safetensors", "vae"),
+    ])
+    raw = tracker.get_raw_model_stats()
+    assert [(r["model_type"], r["count"]) for r in raw] == [
+        ("checkpoints", 1),
+        ("vae", 1),
+    ]
diff --git a/tracker.py b/tracker.py
index 82cf7e9..08e0908 100644
--- a/tracker.py
+++ b/tracker.py
@@ -40,8 +40,18 @@ CREATE TABLE IF NOT EXISTS prompt_log (
     class_types TEXT NOT NULL
 );
 
+CREATE TABLE IF NOT EXISTS model_usage (
+    model_name TEXT NOT NULL,
+    model_type TEXT NOT NULL,
+    count INTEGER NOT NULL DEFAULT 0,
+    first_seen TEXT NOT NULL,
+    last_seen TEXT NOT NULL,
+    PRIMARY KEY (model_name, model_type)
+);
+
 CREATE INDEX IF NOT EXISTS idx_node_usage_package ON node_usage(package);
 CREATE INDEX IF NOT EXISTS idx_prompt_log_timestamp ON prompt_log(timestamp);
+CREATE INDEX IF NOT EXISTS idx_model_usage_type ON model_usage(model_type);
 """
 
 
@@ -116,6 +126,53 @@ class UsageTracker:
             finally:
                 conn.close()
 
+    def record_model_usage(self, models):
+        """Record usage of model files from a single prompt.
+
+        models: list of (model_name, model_type) tuples
+
+        Each tuple bumps the counter for that (model_name, model_type) pair,
+        so identically named files of different types are tracked separately.
+        """
+        if not models:
+            return
+        now = datetime.now(timezone.utc).isoformat()
+        with self._lock:
+            self._ensure_db()
+            conn = self._connect()
+            try:
+                for model_name, model_type in models:
+                    conn.execute(
+                        """INSERT INTO model_usage (model_name, model_type, count, first_seen, last_seen)
+                           VALUES (?, ?, 1, ?, ?)
+                           ON CONFLICT(model_name, model_type) DO UPDATE SET
+                               count = count + 1,
+                               last_seen = excluded.last_seen""",
+                        (model_name, model_type, now, now),
+                    )
+                conn.commit()
+            finally:
+                conn.close()
+
+    def get_raw_model_stats(self):
+        """Return raw per-model usage rows from DB.
+
+        Rows are dicts ordered by count (highest first), then by name and type
+        so that ties come back in a stable order.
+        """
+        with self._lock:
+            self._ensure_db()
+            conn = self._connect()
+            try:
+                conn.row_factory = sqlite3.Row
+                rows = conn.execute(
+                    "SELECT model_name, model_type, count, first_seen, last_seen "
+                    "FROM model_usage ORDER BY count DESC, model_name, model_type"
+                ).fetchall()
+                return [dict(r) for r in rows]
+            finally:
+                conn.close()
+
     def get_node_stats(self):
         """Return raw per-node usage data."""
         with self._lock:
@@ -216,6 +273,108 @@ class UsageTracker:
         result.sort(key=lambda p: p["total_executions"])
         return result
 
+    def get_model_stats(self, installed_by_type):
+        """Return per-type grouped model stats with tier classification.
+
+        installed_by_type: {model_type: [model_name, ...]} from ModelMapper
+
+        Returns a list of {"model_type", "models"} dicts sorted by type; each
+        type's models are sorted by status tier, then by name. Usage rows are
+        matched on (model_name, model_type), so a file is only credited with
+        runs recorded under its own type.
+        """
+        db_rows = self.get_raw_model_stats()
+        db_models = {(r["model_name"], r["model_type"]): r for r in db_rows}
+
+        # Cutoffs are compared as strings: last_seen is stored via the same
+        # UTC isoformat() call, so lexical order matches chronological order.
+        # NOTE(review): assumes _get_first_prompt_time() uses that format too.
+        now = datetime.now(timezone.utc)
+        one_month_ago = (now - timedelta(days=30)).isoformat()
+        two_months_ago = (now - timedelta(days=60)).isoformat()
+        tracking_start = self._get_first_prompt_time()
+
+        STATUS_ORDER = {
+            "safe_to_remove": 0,
+            "consider_removing": 1,
+            "unused_new": 2,
+            "used": 3,
+            "uninstalled": 4,
+        }
+
+        result_by_type = {}
+        installed_keys = set()
+
+        # Process installed models
+        for model_type, filenames in installed_by_type.items():
+            entries = []
+            for model_name in filenames:
+                key = (model_name, model_type)
+                installed_keys.add(key)
+                row = db_models.get(key)
+                if row is not None:
+                    last_seen = row["last_seen"]
+                    if last_seen < two_months_ago:
+                        status = "safe_to_remove"
+                    elif last_seen < one_month_ago:
+                        status = "consider_removing"
+                    else:
+                        status = "used"
+                    entry = {
+                        "model_name": model_name,
+                        "model_type": model_type,
+                        "count": row["count"],
+                        "first_seen": row["first_seen"],
+                        "last_seen": last_seen,
+                        "installed": True,
+                        "status": status,
+                    }
+                else:
+                    # Never used: only suggest removal once tracking has run
+                    # long enough for the absence of usage to mean something.
+                    if tracking_start is None:
+                        status = "unused_new"
+                    elif tracking_start < two_months_ago:
+                        status = "safe_to_remove"
+                    elif tracking_start < one_month_ago:
+                        status = "consider_removing"
+                    else:
+                        status = "unused_new"
+                    entry = {
+                        "model_name": model_name,
+                        "model_type": model_type,
+                        "count": 0,
+                        "first_seen": None,
+                        "last_seen": None,
+                        "installed": True,
+                        "status": status,
+                    }
+                entries.append(entry)
+            result_by_type[model_type] = entries
+
+        # Add uninstalled (in DB but not on disk)
+        for key, row in db_models.items():
+            if key not in installed_keys:
+                model_name, model_type = key
+                result_by_type.setdefault(model_type, []).append({
+                    "model_name": model_name,
+                    "model_type": model_type,
+                    "count": row["count"],
+                    "first_seen": row["first_seen"],
+                    "last_seen": row["last_seen"],
+                    "installed": False,
+                    "status": "uninstalled",
+                })
+
+        # Sort each type's models by status tier then name
+        result = []
+        for model_type in sorted(result_by_type):
+            models = result_by_type[model_type]
+            models.sort(key=lambda m: (STATUS_ORDER.get(m["status"], 5), m["model_name"]))
+            result.append({"model_type": model_type, "models": models})
+
+        return result
+
     def _get_first_prompt_time(self):
         """Return the timestamp of the earliest recorded prompt, or None."""
         with self._lock:
@@ -237,6 +396,7 @@ class UsageTracker:
             try:
                 conn.execute("DELETE FROM node_usage")
                 conn.execute("DELETE FROM prompt_log")
+                conn.execute("DELETE FROM model_usage")
                 conn.commit()
             finally:
                 conn.close()