diff --git a/tests/test_trials.py b/tests/test_trials.py index 8029fe3..2f36c41 100644 --- a/tests/test_trials.py +++ b/tests/test_trials.py @@ -53,3 +53,19 @@ def test_tick_reaches_expiry(tracker): assert t["unused_boot_days"] == DEFAULT_TRIAL_BUDGET assert t["expired"] is True assert t["days_remaining"] == 0 + + +def test_reset_zeroes_counter(tracker): + tracker.start_trial("Pack") + with patch("tracker.datetime") as m: + m.now.return_value = _ahead(1) + tracker.tick_boot_days() + assert tracker.get_trials()[0]["unused_boot_days"] == 1 + tracker.reset_trials_for({"Pack", "Not-On-Trial"}) + assert tracker.get_trials()[0]["unused_boot_days"] == 0 + + +def test_reset_empty_is_noop(tracker): + tracker.start_trial("Pack") + tracker.reset_trials_for(set()) + assert tracker.get_trials()[0]["unused_boot_days"] == 0 diff --git a/tracker.py b/tracker.py index 76ce991..0ac5b68 100644 --- a/tracker.py +++ b/tracker.py @@ -450,6 +450,25 @@ class UsageTracker: finally: conn.close() + def reset_trials_for(self, packages): + """Reset the unused-day counter for any of these packages that are on trial.""" + if not packages: + return + today = datetime.now(timezone.utc).date().isoformat() + with self._lock: + self._ensure_db() + conn = self._connect() + try: + conn.executemany( + """UPDATE trial_packages + SET unused_boot_days = 0, last_use_day = ? + WHERE package = ?""", + [(today, p) for p in packages], + ) + conn.commit() + finally: + conn.close() + def reset(self): """Clear all tracked data.""" with self._lock: