""" sfm/cache.py — Persistent SQLite cache for SFM device data. Caching strategy ---------------- +------------------+----------------------------------+-------------------------+ | Data | Mutability | Invalidation | +------------------+----------------------------------+-------------------------+ | Device info | Effectively immutable (firmware, | Manual clear / force | | (serial, model, | serial never change) | refresh query param | | compliance cfg) | | | +------------------+----------------------------------+-------------------------+ | Event headers | Append-only (new events added, | Fetch new ones when | | (peaks, ts, | old never modified) | device event count > | | project info) | | cached count | +------------------+----------------------------------+-------------------------+ | Full waveforms | Immutable once recorded | Never (permanent cache) | | (raw ADC samples)| | | +------------------+----------------------------------+-------------------------+ | Monitor status | Frequently changing | TTL = 30 seconds | | (battery, memory)| | | +------------------+----------------------------------+-------------------------+ Keys ---- All cached rows are keyed by (host, tcp_port) for TCP connections, or (port, baud) for serial connections. Within a device, events are keyed by index (0-based). The device serial number is stored once we learn it, and used for display / debugging only — the network address is the primary routing key (same as how the rest of the SFM code operates). """ from __future__ import annotations import json import logging import time from pathlib import Path from typing import Optional try: import sqlalchemy as sa from sqlalchemy import orm except ImportError: raise ImportError( "sqlalchemy is required for the SFM cache.\n" "Install it with: pip install sqlalchemy" ) log = logging.getLogger("sfm.cache") # ── Schema ──────────────────────────────────────────────────────────────────── Base = orm.declarative_base() _MONITOR_STATUS_TTL = 30 # seconds class CachedDevice(Base): """ Device identity + compliance config, keyed by connection address. Stores the full serialised JSON blob returned by /device/info so the endpoint can return it verbatim on a cache hit without re-connecting. """ __tablename__ = "cached_devices" # Connection key — either TCP (host+port) or serial (port+baud) conn_key = sa.Column(sa.String, primary_key=True) # e.g. "tcp:1.2.3.4:12345" serial = sa.Column(sa.String, nullable=True) # e.g. "BE11529" info_json = sa.Column(sa.Text, nullable=False) # full /device/info response JSON updated_at = sa.Column(sa.Float, nullable=False) # Unix timestamp of last write # When a config write happens we set this flag so the next /device/info call # fetches fresh data instead of serving stale compliance config. config_dirty = sa.Column(sa.Boolean, default=False, nullable=False) class CachedEvent(Base): """ Per-event header + peak values + project info, keyed by (conn_key, index). Events are immutable once recorded on the device; once we have an event in the cache it never needs to be re-downloaded unless explicitly requested. The two extra columns `waveform_key` and `event_timestamp` are an integrity stamp: when set_event() / set_waveform() are called with a different (waveform_key, event_timestamp) for the same (conn_key, index), we know the device was erased and re-recorded — the cached row no longer refers to the same physical event and the entire device's cache is flushed before the new entry is written. This catches the post-erase key-reuse bug where the device's first new event (key 01110000) collides with the first event we previously downloaded. """ __tablename__ = "cached_events" conn_key = sa.Column(sa.String, primary_key=True) index = sa.Column(sa.Integer, primary_key=True) event_json = sa.Column(sa.Text, nullable=False) # serialised Event dict cached_at = sa.Column(sa.Float, nullable=False) # Unix timestamp waveform_key = sa.Column(sa.String, nullable=True) # 8-hex device key event_timestamp = sa.Column(sa.String, nullable=True) # ISO-8601 from 0C class CachedWaveform(Base): """ Full raw ADC waveform for a single event (SUB 5A full download). These are large (up to several MB) and expensive to fetch over cellular. Once downloaded they are immutable and cached permanently — but the cache row is invalidated when the device is erased and a new event lands at the same index (see CachedEvent docstring). """ __tablename__ = "cached_waveforms" conn_key = sa.Column(sa.String, primary_key=True) index = sa.Column(sa.Integer, primary_key=True) waveform_json = sa.Column(sa.Text, nullable=False) # full /device/event/{idx}/waveform response JSON cached_at = sa.Column(sa.Float, nullable=False) waveform_key = sa.Column(sa.String, nullable=True) # 8-hex device key event_timestamp = sa.Column(sa.String, nullable=True) # ISO-8601 from 0C class CachedMonitorStatus(Base): """ Monitor status (battery, memory, is_monitoring) with a short TTL. These change frequently during field operations so we keep them only for MONITOR_STATUS_TTL seconds before re-fetching from the device. """ __tablename__ = "cached_monitor_status" conn_key = sa.Column(sa.String, primary_key=True) status_json = sa.Column(sa.Text, nullable=False) cached_at = sa.Column(sa.Float, nullable=False) # ── Cache store ─────────────────────────────────────────────────────────────── class SFMCache: """ SQLite-backed cache for SFM device data. Usage ----- cache = SFMCache() # stores in sfm/data/sfm_cache.db by default cache = SFMCache(":memory:") # in-memory (tests / ephemeral mode) All public methods accept a *conn_key* string — use make_conn_key() to build a consistent key from the transport parameters. """ def __init__(self, db_path: str | Path | None = None) -> None: in_memory = (db_path == ":memory:") if db_path is None: # Default: alongside this file in sfm/data/ db_path = Path(__file__).parent / "data" / "sfm_cache.db" if not in_memory: db_path = Path(db_path) db_path.parent.mkdir(parents=True, exist_ok=True) url = "sqlite:///:memory:" if in_memory else f"sqlite:///{db_path}" engine = sa.create_engine(url, connect_args={"check_same_thread": False}) Base.metadata.create_all(engine) self._Session = orm.sessionmaker(bind=engine) # In-place schema migration: add the (waveform_key, event_timestamp) # integrity-stamp columns to legacy cache DBs that predate the # post-erase eviction logic. ALTER TABLE ADD COLUMN is idempotent # via the column-presence check below. with engine.begin() as conn: for table in ("cached_events", "cached_waveforms"): cols = { r[1] for r in conn.exec_driver_sql(f"PRAGMA table_info({table})").fetchall() } for new_col, ddl in ( ("waveform_key", "TEXT"), ("event_timestamp", "TEXT"), ): if new_col not in cols: log.info("cache schema: %s ADD COLUMN %s %s", table, new_col, ddl) conn.exec_driver_sql(f"ALTER TABLE {table} ADD COLUMN {new_col} {ddl}") log.info("SFM cache opened: %s", db_path) # ── Connection key ──────────────────────────────────────────────────────── @staticmethod def make_conn_key( host: Optional[str], tcp_port: int, port: Optional[str], baud: int, ) -> str: """Return a stable string key for this transport configuration.""" if host: return f"tcp:{host}:{tcp_port}" return f"serial:{port}:{baud}" # ── Device info ─────────────────────────────────────────────────────────── def get_device_info(self, conn_key: str) -> Optional[dict]: """ Return cached device info dict, or None if not cached / config_dirty. """ with self._Session() as s: row = s.get(CachedDevice, conn_key) if row is None or row.config_dirty: return None return json.loads(row.info_json) def set_device_info(self, conn_key: str, info: dict) -> None: """Store device info and clear any dirty flag.""" with self._Session() as s: row = s.get(CachedDevice, conn_key) serial = info.get("serial") if row is None: row = CachedDevice( conn_key=conn_key, serial=serial, info_json=json.dumps(info), updated_at=time.time(), config_dirty=False, ) s.add(row) else: row.serial = serial row.info_json = json.dumps(info) row.updated_at = time.time() row.config_dirty = False s.commit() log.debug("cached device info for %s (serial=%s)", conn_key, serial) def mark_config_dirty(self, conn_key: str) -> None: """ Called after a successful POST /device/config write. Forces the next /device/info call to re-read compliance config from the device instead of serving the now-stale cached version. """ with self._Session() as s: row = s.get(CachedDevice, conn_key) if row: row.config_dirty = True s.commit() log.debug("marked config dirty for %s", conn_key) # ── Events ──────────────────────────────────────────────────────────────── def get_cached_event_count(self, conn_key: str) -> int: """Return the number of events we have cached for this device.""" with self._Session() as s: return s.query(CachedEvent).filter_by(conn_key=conn_key).count() def get_all_events(self, conn_key: str) -> Optional[list[dict]]: """ Return all cached events as a list of dicts, sorted by index. Returns None if nothing is cached yet. """ with self._Session() as s: rows = ( s.query(CachedEvent) .filter_by(conn_key=conn_key) .order_by(CachedEvent.index) .all() ) if not rows: return None return [json.loads(r.event_json) for r in rows] def get_event(self, conn_key: str, index: int) -> Optional[dict]: """Return a single cached event by index, or None if not cached.""" with self._Session() as s: row = s.get(CachedEvent, (conn_key, index)) return json.loads(row.event_json) if row else None @staticmethod def _event_signature(ev: dict) -> tuple[Optional[str], Optional[str]]: """ Extract the (waveform_key_hex, timestamp_iso) integrity stamp from a serialised event dict. Either field may be None if the source Event was missing it; the comparison logic in set_events/set_waveform treats "both sides have a value AND they differ" as the only eviction trigger, so partial data never spuriously flushes cache. """ key = ev.get("waveform_key") or ev.get("_waveform_key") if isinstance(key, (bytes, bytearray)): key = bytes(key).hex() ts = ev.get("timestamp") if isinstance(ts, dict): # _serialise_timestamp returns a dict like {"iso": "...", ...} ts = ts.get("iso") or ts.get("string") or None return (key if isinstance(key, str) else None, ts if isinstance(ts, str) else None) def _maybe_flush_on_mismatch( self, s, conn_key: str, index: int, new_key: Optional[str], new_ts: Optional[str], ) -> bool: """ Check whether the cached entry at (conn_key, index) has a different (waveform_key, timestamp) than the incoming one. If so, treat it as a post-erase key-reuse signal and flush ALL cached events/waveforms for this device, then return True. Returns False when no flush was needed. """ if not new_key and not new_ts: return False # nothing to compare against existing = s.get(CachedEvent, (conn_key, index)) if existing is None: existing = s.get(CachedWaveform, (conn_key, index)) if existing is None: return False old_key = existing.waveform_key old_ts = existing.event_timestamp # Only flush when both sides have populated values and they differ. differs = ( (new_key and old_key and new_key != old_key) or (new_ts and old_ts and new_ts != old_ts) ) if not differs: return False log.warning( "cache: device %s — index %d (key=%s, ts=%s) replaces (key=%s, ts=%s); " "flushing all cached events/waveforms for this device " "(post-erase key reuse detected)", conn_key, index, new_key, new_ts, old_key, old_ts, ) s.query(CachedEvent).filter_by(conn_key=conn_key).delete() s.query(CachedWaveform).filter_by(conn_key=conn_key).delete() return True def set_events(self, conn_key: str, events: list[dict]) -> None: """ Upsert a list of event dicts. Existing rows are updated; new rows are inserted. This is used to add newly-discovered events to the cache. Eviction: if any incoming event has a different (waveform_key, timestamp) than the row currently cached at the same index, we flush the entire device's cache before inserting the new entries. Catches post-erase key reuse where index 0 silently switches identity. """ now = time.time() with self._Session() as s: # Eviction check: scan incoming events for any (index, key, ts) # that conflicts with a cached row. A single conflict triggers # a full device-wide flush so we don't end up with a mixed-era # cache. for ev in events: key, ts = self._event_signature(ev) if self._maybe_flush_on_mismatch(s, conn_key, ev["index"], key, ts): s.commit() break # cache is now empty for this device; carry on for ev in events: idx = ev["index"] key, ts = self._event_signature(ev) row = s.get(CachedEvent, (conn_key, idx)) if row is None: row = CachedEvent( conn_key=conn_key, index=idx, event_json=json.dumps(ev), cached_at=now, waveform_key=key, event_timestamp=ts, ) s.add(row) log.debug("cached new event %d for %s", idx, conn_key) else: # Refresh in case project_info was backfilled after initial store row.event_json = json.dumps(ev) if key: row.waveform_key = key if ts: row.event_timestamp = ts s.commit() # ── Waveforms ───────────────────────────────────────────────────────────── def get_waveform(self, conn_key: str, index: int) -> Optional[dict]: """Return a cached full waveform response dict, or None if not cached.""" with self._Session() as s: row = s.get(CachedWaveform, (conn_key, index)) if row is None: return None log.debug("waveform cache hit: %s event %d", conn_key, index) return json.loads(row.waveform_json) def set_waveform(self, conn_key: str, index: int, waveform: dict) -> None: """ Store a full waveform response dict permanently. Like set_events, this checks the (waveform_key, timestamp) signature of the incoming entry against what's currently cached at the same index. A mismatch flushes the entire device's cache before insert. """ key, ts = self._event_signature(waveform) with self._Session() as s: self._maybe_flush_on_mismatch(s, conn_key, index, key, ts) row = s.get(CachedWaveform, (conn_key, index)) if row is None: row = CachedWaveform( conn_key=conn_key, index=index, waveform_json=json.dumps(waveform), cached_at=time.time(), waveform_key=key, event_timestamp=ts, ) s.add(row) else: row.waveform_json = json.dumps(waveform) row.cached_at = time.time() if key: row.waveform_key = key if ts: row.event_timestamp = ts s.commit() log.debug("cached waveform for %s event %d (key=%s, ts=%s)", conn_key, index, key, ts) # ── Monitor status ──────────────────────────────────────────────────────── def get_monitor_status(self, conn_key: str) -> Optional[dict]: """Return cached monitor status if it's within TTL, else None.""" with self._Session() as s: row = s.get(CachedMonitorStatus, conn_key) if row is None: return None age = time.time() - row.cached_at if age > _MONITOR_STATUS_TTL: log.debug("monitor status expired (age=%.1fs) for %s", age, conn_key) return None return json.loads(row.status_json) def set_monitor_status(self, conn_key: str, status: dict) -> None: """Store monitor status.""" with self._Session() as s: row = s.get(CachedMonitorStatus, conn_key) if row is None: row = CachedMonitorStatus( conn_key=conn_key, status_json=json.dumps(status), cached_at=time.time(), ) s.add(row) else: row.status_json = json.dumps(status) row.cached_at = time.time() s.commit() def invalidate_monitor_status(self, conn_key: str) -> None: """ Called after start/stop monitoring so the next status poll re-reads from device. """ with self._Session() as s: row = s.get(CachedMonitorStatus, conn_key) if row: s.delete(row) s.commit() # ── Cache management ────────────────────────────────────────────────────── def clear_device(self, conn_key: str) -> dict: """ Remove all cached data for a device. Returns counts of deleted rows. """ counts = {} with self._Session() as s: counts["device_info"] = s.query(CachedDevice).filter_by(conn_key=conn_key).delete() counts["events"] = s.query(CachedEvent).filter_by(conn_key=conn_key).delete() counts["waveforms"] = s.query(CachedWaveform).filter_by(conn_key=conn_key).delete() counts["monitor_status"] = s.query(CachedMonitorStatus).filter_by(conn_key=conn_key).delete() s.commit() log.info("cleared cache for %s: %s", conn_key, counts) return counts def stats(self) -> dict: """Return row counts for all cache tables (for /cache/stats endpoint).""" with self._Session() as s: return { "devices": s.query(CachedDevice).count(), "events": s.query(CachedEvent).count(), "waveforms": s.query(CachedWaveform).count(), "monitor_status": s.query(CachedMonitorStatus).count(), } # ── Module-level singleton ──────────────────────────────────────────────────── # Instantiated once when the module is imported; shared across all requests. _cache: Optional[SFMCache] = None def get_cache() -> SFMCache: """Return the module-level cache singleton, initialising it on first call.""" global _cache if _cache is None: _cache = SFMCache() return _cache