refactor: unify active assignment checks and add project-type guards

- Replace all UnitAssignment "active" checks from `status == "active"` to
  `assigned_until == None` in both project_locations.py and projects.py.
  This aligns with the canonical definition: active = no end date set.
  (status field is still set in sync, but is no longer the query criterion)

- Add `_require_sound_project()` helper to both routers and call it at the
  top of every sound-monitoring-specific endpoint (FTP browser, FTP downloads,
  RND file viewer, all Excel report endpoints, combined report wizard,
  upload-all, NRL live status, NRL data upload). Vibration projects hitting
  these endpoints now receive a clear 400 instead of silently failing or
  returning empty results.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:12:38 +00:00
parent 33e962e73d
commit e8e155556a
2 changed files with 71 additions and 17 deletions

View File

@@ -35,6 +35,19 @@ from backend.templates_config import templates
router = APIRouter(prefix="/api/projects/{project_id}", tags=["project-locations"])
# ============================================================================
# Shared helpers
# ============================================================================
def _require_sound_project(project) -> None:
"""Raise 400 if the project is not a sound_monitoring project."""
if not project or project.project_type_id != "sound_monitoring":
raise HTTPException(
status_code=400,
detail="This feature is only available for Sound Monitoring projects.",
)
# ============================================================================
# Session period helpers
# ============================================================================
@@ -98,11 +111,11 @@ async def get_project_locations(
# Enrich with assignment info
locations_data = []
for location in locations:
# Get active assignment
# Get active assignment (active = assigned_until IS NULL)
assignment = db.query(UnitAssignment).filter(
and_(
UnitAssignment.location_id == location.id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).first()
@@ -258,11 +271,11 @@ async def delete_location(
if not location:
raise HTTPException(status_code=404, detail="Location not found")
# Check if location has active assignments
# Check if location has active assignments (active = assigned_until IS NULL)
active_assignments = db.query(UnitAssignment).filter(
and_(
UnitAssignment.location_id == location_id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).count()
@@ -569,9 +582,9 @@ async def get_available_units(
)
).all()
# Filter out units that already have active assignments
# Filter out units that already have active assignments (active = assigned_until IS NULL)
assigned_unit_ids = db.query(UnitAssignment.unit_id).filter(
UnitAssignment.status == "active"
UnitAssignment.assigned_until == None
).distinct().all()
assigned_unit_ids = [uid[0] for uid in assigned_unit_ids]
@@ -747,6 +760,9 @@ async def upload_nrl_data(
from datetime import datetime
# Verify project and location exist
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(
id=location_id, project_id=project_id
).first()
@@ -925,15 +941,18 @@ async def get_nrl_live_status(
Fetch cached status from SLMM for the unit assigned to this NRL and
return a compact HTML status card. Used in the NRL overview tab for
connected NRLs. Gracefully shows an offline message if SLMM is unreachable.
Sound Monitoring projects only.
"""
import os
import httpx
# Find the assigned unit
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
# Find the assigned unit (active = assigned_until IS NULL)
assignment = db.query(UnitAssignment).filter(
and_(
UnitAssignment.location_id == location_id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).first()

View File

@@ -45,6 +45,21 @@ router = APIRouter(prefix="/api/projects", tags=["projects"])
logger = logging.getLogger(__name__)
# ============================================================================
# Shared helpers
# ============================================================================
def _require_sound_project(project: Project) -> None:
"""Raise 400 if the project is not a sound_monitoring project.
Call this at the top of any endpoint that only makes sense for sound projects
(report generation, FTP browser, RND file viewer, etc.)."""
if not project or project.project_type_id != "sound_monitoring":
raise HTTPException(
status_code=400,
detail="This feature is only available for Sound Monitoring projects.",
)
# ============================================================================
# RND file normalization — maps AU2 (older Rion) column names to the NL-43
# equivalents so report generation and the web viewer work for both formats.
@@ -398,11 +413,11 @@ async def get_projects_list(
project_id=project.id
).scalar()
# Count assigned units
# Count assigned units (active = assigned_until IS NULL)
unit_count = db.query(func.count(UnitAssignment.id)).filter(
and_(
UnitAssignment.project_id == project.id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).scalar()
@@ -806,11 +821,11 @@ async def get_project_dashboard(
# Get locations
locations = db.query(MonitoringLocation).filter_by(project_id=project_id).all()
# Get assigned units with details
# Get assigned units with details (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter(
and_(
UnitAssignment.project_id == project_id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).all()
@@ -899,11 +914,11 @@ async def get_project_units(
"""
from backend.models import DataFile
# Get all assignments for this project
# Get all assignments for this project (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter(
and_(
UnitAssignment.project_id == project_id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).all()
@@ -1160,15 +1175,18 @@ async def get_ftp_browser(
):
"""
Get FTP browser interface for downloading files from assigned SLMs.
Returns HTML partial with FTP browser.
Returns HTML partial with FTP browser. Sound Monitoring projects only.
"""
from backend.models import DataFile
# Get all assignments for this project
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
# Get all assignments for this project (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter(
and_(
UnitAssignment.project_id == project_id,
UnitAssignment.status == "active",
UnitAssignment.assigned_until == None,
)
).all()
@@ -1202,6 +1220,7 @@ async def ftp_download_to_server(
"""
Download a file from an SLM to the server via FTP.
Creates a DataFile record and stores the file in data/Projects/{project_id}/
Sound Monitoring projects only.
"""
import httpx
import os
@@ -1209,6 +1228,8 @@ async def ftp_download_to_server(
from pathlib import Path
from backend.models import DataFile
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
data = await request.json()
unit_id = data.get("unit_id")
remote_path = data.get("remote_path")
@@ -1367,12 +1388,15 @@ async def ftp_download_folder_to_server(
Download an entire folder from an SLM to the server via FTP.
Extracts all files from the ZIP and preserves folder structure.
Creates individual DataFile records for each file.
Sound Monitoring projects only.
"""
import httpx
import os
import hashlib
import zipfile
import io
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
from pathlib import Path
from backend.models import DataFile
@@ -1915,6 +1939,7 @@ async def view_rnd_file(
# Get project info
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
# Get location info if available
location = None
@@ -1958,12 +1983,15 @@ async def get_rnd_data(
"""
Get parsed RND file data as JSON.
Returns the measurement data for charts and tables.
Sound Monitoring projects only.
"""
from backend.models import DataFile
from pathlib import Path
import csv
import io
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
# Get the file record
file_record = db.query(DataFile).filter_by(id=file_id).first()
if not file_record:
@@ -2120,6 +2148,7 @@ async def generate_excel_report(
# Get related data for report context
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Build full file path
@@ -2550,6 +2579,7 @@ async def preview_report_data(
# Get related data for report context
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Build full file path
@@ -2761,6 +2791,7 @@ async def generate_report_from_preview(
raise HTTPException(status_code=403, detail="File does not belong to this project")
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Extract data from request
@@ -3041,6 +3072,7 @@ async def generate_combined_excel_report(
project = db.query(Project).filter_by(id=project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
# Get all sessions with measurement files
sessions = db.query(MonitoringSession).filter_by(project_id=project_id).all()
@@ -3386,6 +3418,7 @@ async def combined_report_wizard(
project = db.query(Project).filter_by(id=project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
sessions = db.query(MonitoringSession).filter_by(project_id=project_id).order_by(MonitoringSession.started_at).all()
@@ -3655,6 +3688,7 @@ async def generate_combined_from_preview(
project = db.query(Project).filter_by(id=project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
report_title = data.get("report_title", "Background Noise Study")
project_name = data.get("project_name", project.name)
@@ -4130,6 +4164,7 @@ async def upload_all_project_data(
project = db.query(Project).filter_by(id=project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
# Load all sound monitoring locations for this project
locations = db.query(MonitoringLocation).filter_by(