From a1a80bbb4d2a3c5e2a614966b5215d71875cbd30 Mon Sep 17 00:00:00 2001 From: serversdwn Date: Mon, 16 Feb 2026 04:25:51 +0000 Subject: [PATCH] add: new persistent connection approach, env variables for tcp keepalive and persist, added connection pool class. --- app/main.py | 7 + app/routers.py | 28 +++ app/services.py | 451 ++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 417 insertions(+), 69 deletions(-) diff --git a/app/main.py b/app/main.py index 406d7fc..176de97 100644 --- a/app/main.py +++ b/app/main.py @@ -29,7 +29,11 @@ logger.info("Database tables initialized") @asynccontextmanager async def lifespan(app: FastAPI): """Manage application lifecycle - startup and shutdown events.""" + from app.services import _connection_pool + # Startup + logger.info("Starting TCP connection pool cleanup task...") + _connection_pool.start_cleanup() logger.info("Starting background poller...") await poller.start() logger.info("Background poller started") @@ -40,6 +44,9 @@ async def lifespan(app: FastAPI): logger.info("Stopping background poller...") await poller.stop() logger.info("Background poller stopped") + logger.info("Closing TCP connection pool...") + await _connection_pool.close_all() + logger.info("TCP connection pool closed") app = FastAPI( diff --git a/app/routers.py b/app/routers.py index 40ecaf8..89d8ce7 100644 --- a/app/routers.py +++ b/app/routers.py @@ -93,6 +93,34 @@ class PollingConfigPayload(BaseModel): poll_enabled: bool | None = Field(None, description="Enable or disable background polling for this device") +# ============================================================================ +# TCP CONNECTION POOL ENDPOINTS (must be before /{unit_id} routes) +# ============================================================================ + +@router.get("/_connections/status") +async def get_connection_pool_status(): + """Get status of the persistent TCP connection pool. 
+ + Returns information about cached connections, keepalive settings, + and per-device connection age/idle times. + """ + from app.services import _connection_pool + return {"status": "ok", "pool": _connection_pool.get_stats()} + + +@router.post("/_connections/flush") +async def flush_connection_pool(): + """Close all cached TCP connections. + + Useful for debugging or forcing fresh connections to all devices. + """ + from app.services import _connection_pool + await _connection_pool.close_all() + # Restart cleanup task since close_all cancels it + _connection_pool.start_cleanup() + return {"status": "ok", "message": "All cached connections closed"} + + # ============================================================================ # GLOBAL POLLING STATUS ENDPOINT (must be before /{unit_id} routes) # ============================================================================ diff --git a/app/services.py b/app/services.py index 7e4c554..d85cb4b 100644 --- a/app/services.py +++ b/app/services.py @@ -1,20 +1,22 @@ """ NL43 TCP connector and snapshot persistence. -Implements simple per-request TCP calls to avoid long-lived socket complexity. -Extend to pooled connections/DRD streaming later. +Implements persistent per-device TCP connections with OS-level keepalive +to reduce handshake overhead and survive cellular modem NAT timeouts. +Falls back to per-request connections on error with transparent retry. 
""" import asyncio import contextlib import logging +import socket import time import os import zipfile import tempfile -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timezone, timedelta -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Tuple from sqlalchemy.orm import Session from ftplib import FTP from pathlib import Path @@ -234,6 +236,293 @@ async def _get_device_lock(device_key: str) -> asyncio.Lock: return _device_locks[device_key] +# --------------------------------------------------------------------------- +# Persistent TCP connection pool with OS-level keepalive +# --------------------------------------------------------------------------- + +# Configuration via environment variables +TCP_PERSISTENT_ENABLED = os.getenv("TCP_PERSISTENT_ENABLED", "true").lower() == "true" +TCP_IDLE_TTL = float(os.getenv("TCP_IDLE_TTL", "120")) # Close idle connections after N seconds +TCP_MAX_AGE = float(os.getenv("TCP_MAX_AGE", "300")) # Force reconnect after N seconds +TCP_KEEPALIVE_IDLE = int(os.getenv("TCP_KEEPALIVE_IDLE", "15")) # Seconds idle before probes +TCP_KEEPALIVE_INTERVAL = int(os.getenv("TCP_KEEPALIVE_INTERVAL", "10")) # Seconds between probes +TCP_KEEPALIVE_COUNT = int(os.getenv("TCP_KEEPALIVE_COUNT", "3")) # Failed probes before dead + +logger.info( + f"TCP connection pool: persistent={TCP_PERSISTENT_ENABLED}, " + f"idle_ttl={TCP_IDLE_TTL}s, max_age={TCP_MAX_AGE}s, " + f"keepalive_idle={TCP_KEEPALIVE_IDLE}s, keepalive_interval={TCP_KEEPALIVE_INTERVAL}s, " + f"keepalive_count={TCP_KEEPALIVE_COUNT}" +) + + +@dataclass +class DeviceConnection: + """Tracks a cached TCP connection and its metadata.""" + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + device_key: str + host: str + port: int + created_at: float = field(default_factory=time.time) + last_used_at: float = field(default_factory=time.time) + + +class ConnectionPool: + """Per-device persistent TCP 
connection cache with OS-level keepalive. + + Each NL-43 device supports only one TCP connection at a time. This pool + caches that single connection per device key and reuses it across commands, + avoiding repeated TCP handshakes over high-latency cellular links. + + Keepalive probes keep cellular NAT tables alive and detect dead connections + before the next command attempt. + """ + + def __init__( + self, + enable_persistent: bool = True, + idle_ttl: float = 120.0, + max_age: float = 300.0, + keepalive_idle: int = 15, + keepalive_interval: int = 10, + keepalive_count: int = 3, + ): + self._connections: Dict[str, DeviceConnection] = {} + self._lock = asyncio.Lock() + self._enable_persistent = enable_persistent + self._idle_ttl = idle_ttl + self._max_age = max_age + self._keepalive_idle = keepalive_idle + self._keepalive_interval = keepalive_interval + self._keepalive_count = keepalive_count + self._cleanup_task: Optional[asyncio.Task] = None + + # -- lifecycle ---------------------------------------------------------- + + def start_cleanup(self): + """Start background task that evicts stale connections.""" + if self._enable_persistent and self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("Connection pool cleanup task started") + + async def close_all(self): + """Close all cached connections (called at shutdown).""" + if self._cleanup_task is not None: + self._cleanup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._cleanup_task + self._cleanup_task = None + + async with self._lock: + for key, conn in list(self._connections.items()): + await self._close_connection(conn, reason="shutdown") + self._connections.clear() + logger.info("Connection pool: all connections closed") + + # -- public API --------------------------------------------------------- + + async def acquire( + self, device_key: str, host: str, port: int, timeout: float + ) -> Tuple[asyncio.StreamReader, 
asyncio.StreamWriter, bool]: + """Get a connection for a device (cached or fresh). + + Returns: + (reader, writer, from_cache) — from_cache is True if reused. + """ + if self._enable_persistent: + async with self._lock: + conn = self._connections.pop(device_key, None) + + if conn is not None: + if self._is_alive(conn): + self._drain_buffer(conn.reader) + conn.last_used_at = time.time() + logger.debug(f"Pool hit for {device_key} (age={time.time() - conn.created_at:.0f}s)") + return conn.reader, conn.writer, True + else: + await self._close_connection(conn, reason="stale") + + # Open fresh connection + reader, writer = await self._open_connection(host, port, timeout) + logger.debug(f"New connection opened for {device_key}") + return reader, writer, False + + async def release(self, device_key: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, host: str, port: int): + """Return a connection to the pool for reuse.""" + if not self._enable_persistent: + self._close_writer(writer) + return + + # Check transport is still healthy before caching + if writer.transport.is_closing() or reader.at_eof(): + self._close_writer(writer) + return + + conn = DeviceConnection( + reader=reader, + writer=writer, + device_key=device_key, + host=host, + port=port, + ) + + async with self._lock: + # Evict any existing connection for this device (shouldn't happen + # under normal locking, but be safe) + old = self._connections.pop(device_key, None) + if old is not None: + await self._close_connection(old, reason="replaced") + self._connections[device_key] = conn + + async def discard(self, device_key: str): + """Close and remove a connection from the pool (called on errors).""" + async with self._lock: + conn = self._connections.pop(device_key, None) + if conn is not None: + await self._close_connection(conn, reason="discarded") + logger.debug(f"Pool discard for {device_key}") + + def get_stats(self) -> dict: + """Return pool status for diagnostics.""" + now = time.time() + 
connections = {} + for key, conn in self._connections.items(): + connections[key] = { + "host": conn.host, + "port": conn.port, + "age_seconds": round(now - conn.created_at, 1), + "idle_seconds": round(now - conn.last_used_at, 1), + "alive": self._is_alive(conn), + } + return { + "enabled": self._enable_persistent, + "active_connections": len(self._connections), + "idle_ttl": self._idle_ttl, + "max_age": self._max_age, + "keepalive_idle": self._keepalive_idle, + "keepalive_interval": self._keepalive_interval, + "keepalive_count": self._keepalive_count, + "connections": connections, + } + + # -- internals ---------------------------------------------------------- + + async def _open_connection( + self, host: str, port: int, timeout: float + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """Open a new TCP connection with keepalive options set.""" + try: + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout + ) + except asyncio.TimeoutError: + raise ConnectionError(f"Failed to connect to device at {host}:{port}") + except Exception as e: + raise ConnectionError(f"Failed to connect to device: {e}") + + # Set TCP keepalive on the underlying socket + self._set_keepalive(writer) + return reader, writer + + def _set_keepalive(self, writer: asyncio.StreamWriter): + """Configure OS-level TCP keepalive on the connection socket.""" + try: + sock = writer.transport.get_extra_info("socket") + if sock is None: + logger.warning("Could not access underlying socket for keepalive") + return + + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + # Linux-specific keepalive tuning + if hasattr(socket, "TCP_KEEPIDLE"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, self._keepalive_idle) + if hasattr(socket, "TCP_KEEPINTVL"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, self._keepalive_interval) + if hasattr(socket, "TCP_KEEPCNT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 
self._keepalive_count) + + logger.debug( + f"TCP keepalive set: idle={self._keepalive_idle}s, " + f"interval={self._keepalive_interval}s, count={self._keepalive_count}" + ) + except OSError as e: + logger.warning(f"Failed to set TCP keepalive: {e}") + + def _is_alive(self, conn: DeviceConnection) -> bool: + """Check whether a cached connection is still usable.""" + now = time.time() + + # Age / idle checks + if now - conn.last_used_at > self._idle_ttl: + logger.debug(f"Connection {conn.device_key} idle too long ({now - conn.last_used_at:.0f}s > {self._idle_ttl}s)") + return False + if now - conn.created_at > self._max_age: + logger.debug(f"Connection {conn.device_key} too old ({now - conn.created_at:.0f}s > {self._max_age}s)") + return False + + # Transport-level checks + transport = conn.writer.transport + if transport.is_closing(): + logger.debug(f"Connection {conn.device_key} transport is closing") + return False + if conn.reader.at_eof(): + logger.debug(f"Connection {conn.device_key} reader at EOF") + return False + + return True + + @staticmethod + def _drain_buffer(reader: asyncio.StreamReader): + """Drain any pending bytes (e.g. 
'$' prompt) from an idle connection.""" + buf = reader._buffer # noqa: SLF001 — internal but stable across CPython + if buf: + pending = bytes(buf) + buf.clear() + logger.debug(f"Drained {len(pending)} bytes from cached connection: {pending!r}") + + @staticmethod + def _close_writer(writer: asyncio.StreamWriter): + """Close a writer, suppressing errors.""" + try: + writer.close() + except Exception: + pass + + async def _close_connection(self, conn: DeviceConnection, reason: str = ""): + """Fully close a cached connection.""" + logger.debug(f"Closing connection {conn.device_key} ({reason})") + conn.writer.close() + with contextlib.suppress(Exception): + await conn.writer.wait_closed() + + async def _cleanup_loop(self): + """Periodically evict idle/expired connections.""" + try: + while True: + await asyncio.sleep(30) + async with self._lock: + for key in list(self._connections): + conn = self._connections[key] + if not self._is_alive(conn): + del self._connections[key] + await self._close_connection(conn, reason="cleanup") + except asyncio.CancelledError: + pass + + +# Module-level pool singleton +_connection_pool = ConnectionPool( + enable_persistent=TCP_PERSISTENT_ENABLED, + idle_ttl=TCP_IDLE_TTL, + max_age=TCP_MAX_AGE, + keepalive_idle=TCP_KEEPALIVE_IDLE, + keepalive_interval=TCP_KEEPALIVE_INTERVAL, + keepalive_count=TCP_KEEPALIVE_COUNT, +) + + class NL43Client: def __init__(self, host: str, port: int, timeout: float = 5.0, ftp_username: str = None, ftp_password: str = None, ftp_port: int = 21): self.host = host @@ -275,72 +564,97 @@ class NL43Client: return await self._send_command_unlocked(cmd) async def _send_command_unlocked(self, cmd: str) -> str: - """Internal: send command without acquiring device lock (lock must be held by caller).""" + """Internal: send command without acquiring device lock (lock must be held by caller). + + Uses the connection pool to reuse cached TCP connections when possible. 
+ If a cached connection fails, retries once with a fresh connection. + """ await self._enforce_rate_limit() logger.info(f"Sending command to {self.device_key}: {cmd.strip()}") try: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(self.host, self.port), timeout=self.timeout + reader, writer, from_cache = await _connection_pool.acquire( + self.device_key, self.host, self.port, self.timeout ) - except asyncio.TimeoutError: - logger.error(f"Connection timeout to {self.device_key}") - raise ConnectionError(f"Failed to connect to device at {self.host}:{self.port}") - except Exception as e: - logger.error(f"Connection failed to {self.device_key}: {e}") - raise ConnectionError(f"Failed to connect to device: {str(e)}") + except ConnectionError: + logger.error(f"Connection failed to {self.device_key}") + raise try: - writer.write(cmd.encode("ascii")) - await writer.drain() - - # Read first line (result code) - first_line_data = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=self.timeout) - result_code = first_line_data.decode(errors="ignore").strip() - - # Remove leading $ prompt if present - if result_code.startswith("$"): - result_code = result_code[1:].strip() - - logger.info(f"Result code from {self.device_key}: {result_code}") - - # Check result code - if result_code == "R+0000": - # Success - for query commands, read the second line with actual data - is_query = cmd.strip().endswith("?") - if is_query: - data_line = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=self.timeout) - response = data_line.decode(errors="ignore").strip() - logger.debug(f"Data line from {self.device_key}: {response}") - return response - else: - # Setting command - return success code - return result_code - elif result_code == "R+0001": - raise ValueError("Command error - device did not recognize command") - elif result_code == "R+0002": - raise ValueError("Parameter error - invalid parameter value") - elif result_code == "R+0003": - raise 
ValueError("Spec/type error - command not supported by this device model") - elif result_code == "R+0004": - raise ValueError("Status error - device is in wrong state for this command") - else: - raise ValueError(f"Unknown result code: {result_code}") - - except asyncio.TimeoutError: - logger.error(f"Response timeout from {self.device_key}") - raise TimeoutError(f"Device did not respond within {self.timeout}s") - except Exception as e: - logger.error(f"Communication error with {self.device_key}: {e}") - raise - finally: - writer.close() - with contextlib.suppress(Exception): - await writer.wait_closed() - # Record completion time for rate limiting — NL43 requires ≥1s - # after response before next command, so measure from connection close + response = await self._execute_command(reader, writer, cmd) + # Success — return connection to pool for reuse + await _connection_pool.release(self.device_key, reader, writer, self.host, self.port) _last_command_time[self.device_key] = time.time() + return response + + except Exception as e: + # Discard the bad connection + await _connection_pool.discard(self.device_key) + ConnectionPool._close_writer(writer) + + if from_cache: + # Retry once with a fresh connection — the cached one may have gone stale + logger.warning(f"Cached connection failed for {self.device_key}, retrying fresh: {e}") + await self._enforce_rate_limit() + + try: + reader, writer, _ = await _connection_pool.acquire( + self.device_key, self.host, self.port, self.timeout + ) + except ConnectionError: + logger.error(f"Retry connection also failed to {self.device_key}") + raise + + try: + response = await self._execute_command(reader, writer, cmd) + await _connection_pool.release(self.device_key, reader, writer, self.host, self.port) + _last_command_time[self.device_key] = time.time() + return response + except Exception: + await _connection_pool.discard(self.device_key) + ConnectionPool._close_writer(writer) + raise + else: + raise + + async def 
_execute_command(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, cmd: str) -> str: + """Send a command over an existing connection and parse the NL43 response.""" + writer.write(cmd.encode("ascii")) + await writer.drain() + + # Read first line (result code) + first_line_data = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=self.timeout) + result_code = first_line_data.decode(errors="ignore").strip() + + # Remove leading $ prompt if present + if result_code.startswith("$"): + result_code = result_code[1:].strip() + + logger.info(f"Result code from {self.device_key}: {result_code}") + + # Check result code + if result_code == "R+0000": + # Success — for query commands, read the second line with actual data + is_query = cmd.strip().endswith("?") + if is_query: + data_line = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=self.timeout) + response = data_line.decode(errors="ignore").strip() + logger.debug(f"Data line from {self.device_key}: {response}") + return response + else: + # Setting command — return success code + return result_code + elif result_code == "R+0001": + raise ValueError("Command error - device did not recognize command") + elif result_code == "R+0002": + raise ValueError("Parameter error - invalid parameter value") + elif result_code == "R+0003": + raise ValueError("Spec/type error - command not supported by this device model") + elif result_code == "R+0004": + raise ValueError("Status error - device is in wrong state for this command") + else: + raise ValueError(f"Unknown result code: {result_code}") async def request_dod(self) -> NL43Snapshot: """Request DOD (Data Output Display) snapshot from device. 
@@ -582,20 +896,19 @@ class NL43Client: # Acquire per-device lock - held for entire streaming session device_lock = await _get_device_lock(self.device_key) async with device_lock: + # Evict any cached connection — streaming needs its own dedicated socket + await _connection_pool.discard(self.device_key) await self._enforce_rate_limit() logger.info(f"Starting DRD stream for {self.device_key}") try: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(self.host, self.port), timeout=self.timeout + reader, writer = await _connection_pool._open_connection( + self.host, self.port, self.timeout ) - except asyncio.TimeoutError: - logger.error(f"DRD stream connection timeout to {self.device_key}") - raise ConnectionError(f"Failed to connect to device at {self.host}:{self.port}") - except Exception as e: - logger.error(f"DRD stream connection failed to {self.device_key}: {e}") - raise ConnectionError(f"Failed to connect to device: {str(e)}") + except ConnectionError: + logger.error(f"DRD stream connection failed to {self.device_key}") + raise try: # Start DRD streaming