diff --git a/tests/test_trials.py b/tests/test_trials.py new file mode 100644 index 0000000..21df0fb --- /dev/null +++ b/tests/test_trials.py @@ -0,0 +1,27 @@ +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import patch +from tracker import UsageTracker, DEFAULT_TRIAL_BUDGET + + +@pytest.fixture +def tracker(tmp_path): + return UsageTracker(db_path=str(tmp_path / "test.db")) + + +def test_start_trial_initializes(tracker): + tracker.start_trial("Some-Pack") + trials = tracker.get_trials() + assert len(trials) == 1 + t = trials[0] + assert t["package"] == "Some-Pack" + assert t["unused_boot_days"] == 0 + assert t["budget"] == DEFAULT_TRIAL_BUDGET + assert t["days_remaining"] == DEFAULT_TRIAL_BUDGET + assert t["expired"] is False + + +def test_start_trial_is_idempotent_resets(tracker): + tracker.start_trial("Some-Pack") + tracker.start_trial("Some-Pack") + assert len(tracker.get_trials()) == 1 diff --git a/tracker.py b/tracker.py index a854a72..ef27109 100644 --- a/tracker.py +++ b/tracker.py @@ -48,6 +48,15 @@ CREATE TABLE IF NOT EXISTS model_usage ( last_seen TEXT NOT NULL ); +CREATE TABLE IF NOT EXISTS trial_packages ( + package TEXT PRIMARY KEY, + enabled_at TEXT NOT NULL, + last_use_day TEXT NOT NULL, + last_boot_day TEXT NOT NULL, + unused_boot_days INTEGER NOT NULL DEFAULT 0, + budget INTEGER NOT NULL DEFAULT 7 +); + CREATE INDEX IF NOT EXISTS idx_node_usage_package ON node_usage(package); CREATE INDEX IF NOT EXISTS idx_prompt_log_timestamp ON prompt_log(timestamp); CREATE INDEX IF NOT EXISTS idx_model_usage_type ON model_usage(model_type); @@ -62,6 +71,9 @@ EXCLUDED_PACKAGES = { } +DEFAULT_TRIAL_BUDGET = 7 + + def _classify_age(timestamp, one_month_ago, two_months_ago, recent_status): """Classify an ISO timestamp into a removal tier. @@ -375,6 +387,51 @@ class UsageTracker: finally: conn.close() + def start_trial(self, package, budget=DEFAULT_TRIAL_BUDGET): + """Begin/restart a temporary-enable trial. The enable day is not counted.""" + now = datetime.now(timezone.utc) + today = now.date().isoformat() + with self._lock: + self._ensure_db() + conn = self._connect() + try: + conn.execute( + """INSERT INTO trial_packages + (package, enabled_at, last_use_day, last_boot_day, unused_boot_days, budget) + VALUES (?, ?, ?, ?, 0, ?) + ON CONFLICT(package) DO UPDATE SET + enabled_at = excluded.enabled_at, + last_use_day = excluded.last_use_day, + last_boot_day = excluded.last_boot_day, + unused_boot_days = 0, + budget = excluded.budget""", + (package, now.isoformat(), today, today, budget), + ) + conn.commit() + finally: + conn.close() + + def get_trials(self): + """Return trial rows with computed days_remaining/expired.""" + with self._lock: + self._ensure_db() + conn = self._connect() + try: + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT package, enabled_at, last_use_day, last_boot_day, " + "unused_boot_days, budget FROM trial_packages" + ).fetchall() + finally: + conn.close() + result = [] + for r in rows: + d = dict(r) + d["days_remaining"] = max(0, d["budget"] - d["unused_boot_days"]) + d["expired"] = d["unused_boot_days"] >= d["budget"] + result.append(d) + return result + def reset(self): """Clear all tracked data.""" with self._lock: