refactor: unify active assignment checks and add project-type guards

- Replace all UnitAssignment "active" checks from `status == "active"` to
  `assigned_until == None` in both project_locations.py and projects.py.
  This aligns with the canonical definition: active = no end date set.
  (status field is still set in sync, but is no longer the query criterion)

- Add `_require_sound_project()` helper to both routers and call it at the
  top of every sound-monitoring-specific endpoint (FTP browser, FTP downloads,
  RND file viewer, all Excel report endpoints, combined report wizard,
  upload-all, NRL live status, NRL data upload). Vibration projects hitting
  these endpoints now receive a clear 400 instead of silently failing or
  returning empty results.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:12:38 +00:00
parent 33e962e73d
commit e8e155556a
2 changed files with 71 additions and 17 deletions

View File

@@ -35,6 +35,19 @@ from backend.templates_config import templates
router = APIRouter(prefix="/api/projects/{project_id}", tags=["project-locations"]) router = APIRouter(prefix="/api/projects/{project_id}", tags=["project-locations"])
# ============================================================================
# Shared helpers
# ============================================================================
def _require_sound_project(project) -> None:
"""Raise 400 if the project is not a sound_monitoring project."""
if not project or project.project_type_id != "sound_monitoring":
raise HTTPException(
status_code=400,
detail="This feature is only available for Sound Monitoring projects.",
)
# ============================================================================ # ============================================================================
# Session period helpers # Session period helpers
# ============================================================================ # ============================================================================
@@ -98,11 +111,11 @@ async def get_project_locations(
# Enrich with assignment info # Enrich with assignment info
locations_data = [] locations_data = []
for location in locations: for location in locations:
# Get active assignment # Get active assignment (active = assigned_until IS NULL)
assignment = db.query(UnitAssignment).filter( assignment = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.location_id == location.id, UnitAssignment.location_id == location.id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).first() ).first()
@@ -258,11 +271,11 @@ async def delete_location(
if not location: if not location:
raise HTTPException(status_code=404, detail="Location not found") raise HTTPException(status_code=404, detail="Location not found")
# Check if location has active assignments # Check if location has active assignments (active = assigned_until IS NULL)
active_assignments = db.query(UnitAssignment).filter( active_assignments = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.location_id == location_id, UnitAssignment.location_id == location_id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).count() ).count()
@@ -569,9 +582,9 @@ async def get_available_units(
) )
).all() ).all()
# Filter out units that already have active assignments # Filter out units that already have active assignments (active = assigned_until IS NULL)
assigned_unit_ids = db.query(UnitAssignment.unit_id).filter( assigned_unit_ids = db.query(UnitAssignment.unit_id).filter(
UnitAssignment.status == "active" UnitAssignment.assigned_until == None
).distinct().all() ).distinct().all()
assigned_unit_ids = [uid[0] for uid in assigned_unit_ids] assigned_unit_ids = [uid[0] for uid in assigned_unit_ids]
@@ -747,6 +760,9 @@ async def upload_nrl_data(
from datetime import datetime from datetime import datetime
# Verify project and location exist # Verify project and location exist
project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by( location = db.query(MonitoringLocation).filter_by(
id=location_id, project_id=project_id id=location_id, project_id=project_id
).first() ).first()
@@ -925,15 +941,18 @@ async def get_nrl_live_status(
Fetch cached status from SLMM for the unit assigned to this NRL and Fetch cached status from SLMM for the unit assigned to this NRL and
return a compact HTML status card. Used in the NRL overview tab for return a compact HTML status card. Used in the NRL overview tab for
connected NRLs. Gracefully shows an offline message if SLMM is unreachable. connected NRLs. Gracefully shows an offline message if SLMM is unreachable.
Sound Monitoring projects only.
""" """
import os import os
import httpx import httpx
# Find the assigned unit _require_sound_project(db.query(Project).filter_by(id=project_id).first())
# Find the assigned unit (active = assigned_until IS NULL)
assignment = db.query(UnitAssignment).filter( assignment = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.location_id == location_id, UnitAssignment.location_id == location_id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).first() ).first()

View File

@@ -45,6 +45,21 @@ router = APIRouter(prefix="/api/projects", tags=["projects"])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ============================================================================
# Shared helpers
# ============================================================================
def _require_sound_project(project: Project) -> None:
"""Raise 400 if the project is not a sound_monitoring project.
Call this at the top of any endpoint that only makes sense for sound projects
(report generation, FTP browser, RND file viewer, etc.)."""
if not project or project.project_type_id != "sound_monitoring":
raise HTTPException(
status_code=400,
detail="This feature is only available for Sound Monitoring projects.",
)
# ============================================================================ # ============================================================================
# RND file normalization — maps AU2 (older Rion) column names to the NL-43 # RND file normalization — maps AU2 (older Rion) column names to the NL-43
# equivalents so report generation and the web viewer work for both formats. # equivalents so report generation and the web viewer work for both formats.
@@ -398,11 +413,11 @@ async def get_projects_list(
project_id=project.id project_id=project.id
).scalar() ).scalar()
# Count assigned units # Count assigned units (active = assigned_until IS NULL)
unit_count = db.query(func.count(UnitAssignment.id)).filter( unit_count = db.query(func.count(UnitAssignment.id)).filter(
and_( and_(
UnitAssignment.project_id == project.id, UnitAssignment.project_id == project.id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).scalar() ).scalar()
@@ -806,11 +821,11 @@ async def get_project_dashboard(
# Get locations # Get locations
locations = db.query(MonitoringLocation).filter_by(project_id=project_id).all() locations = db.query(MonitoringLocation).filter_by(project_id=project_id).all()
# Get assigned units with details # Get assigned units with details (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter( assignments = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.project_id == project_id, UnitAssignment.project_id == project_id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).all() ).all()
@@ -899,11 +914,11 @@ async def get_project_units(
""" """
from backend.models import DataFile from backend.models import DataFile
# Get all assignments for this project # Get all assignments for this project (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter( assignments = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.project_id == project_id, UnitAssignment.project_id == project_id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).all() ).all()
@@ -1160,15 +1175,18 @@ async def get_ftp_browser(
): ):
""" """
Get FTP browser interface for downloading files from assigned SLMs. Get FTP browser interface for downloading files from assigned SLMs.
Returns HTML partial with FTP browser. Returns HTML partial with FTP browser. Sound Monitoring projects only.
""" """
from backend.models import DataFile from backend.models import DataFile
# Get all assignments for this project project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
# Get all assignments for this project (active = assigned_until IS NULL)
assignments = db.query(UnitAssignment).filter( assignments = db.query(UnitAssignment).filter(
and_( and_(
UnitAssignment.project_id == project_id, UnitAssignment.project_id == project_id,
UnitAssignment.status == "active", UnitAssignment.assigned_until == None,
) )
).all() ).all()
@@ -1202,6 +1220,7 @@ async def ftp_download_to_server(
""" """
Download a file from an SLM to the server via FTP. Download a file from an SLM to the server via FTP.
Creates a DataFile record and stores the file in data/Projects/{project_id}/ Creates a DataFile record and stores the file in data/Projects/{project_id}/
Sound Monitoring projects only.
""" """
import httpx import httpx
import os import os
@@ -1209,6 +1228,8 @@ async def ftp_download_to_server(
from pathlib import Path from pathlib import Path
from backend.models import DataFile from backend.models import DataFile
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
data = await request.json() data = await request.json()
unit_id = data.get("unit_id") unit_id = data.get("unit_id")
remote_path = data.get("remote_path") remote_path = data.get("remote_path")
@@ -1367,12 +1388,15 @@ async def ftp_download_folder_to_server(
Download an entire folder from an SLM to the server via FTP. Download an entire folder from an SLM to the server via FTP.
Extracts all files from the ZIP and preserves folder structure. Extracts all files from the ZIP and preserves folder structure.
Creates individual DataFile records for each file. Creates individual DataFile records for each file.
Sound Monitoring projects only.
""" """
import httpx import httpx
import os import os
import hashlib import hashlib
import zipfile import zipfile
import io import io
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
from pathlib import Path from pathlib import Path
from backend.models import DataFile from backend.models import DataFile
@@ -1915,6 +1939,7 @@ async def view_rnd_file(
# Get project info # Get project info
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
# Get location info if available # Get location info if available
location = None location = None
@@ -1958,12 +1983,15 @@ async def get_rnd_data(
""" """
Get parsed RND file data as JSON. Get parsed RND file data as JSON.
Returns the measurement data for charts and tables. Returns the measurement data for charts and tables.
Sound Monitoring projects only.
""" """
from backend.models import DataFile from backend.models import DataFile
from pathlib import Path from pathlib import Path
import csv import csv
import io import io
_require_sound_project(db.query(Project).filter_by(id=project_id).first())
# Get the file record # Get the file record
file_record = db.query(DataFile).filter_by(id=file_id).first() file_record = db.query(DataFile).filter_by(id=file_id).first()
if not file_record: if not file_record:
@@ -2120,6 +2148,7 @@ async def generate_excel_report(
# Get related data for report context # Get related data for report context
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Build full file path # Build full file path
@@ -2550,6 +2579,7 @@ async def preview_report_data(
# Get related data for report context # Get related data for report context
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Build full file path # Build full file path
@@ -2761,6 +2791,7 @@ async def generate_report_from_preview(
raise HTTPException(status_code=403, detail="File does not belong to this project") raise HTTPException(status_code=403, detail="File does not belong to this project")
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
_require_sound_project(project)
location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None location = db.query(MonitoringLocation).filter_by(id=session.location_id).first() if session.location_id else None
# Extract data from request # Extract data from request
@@ -3041,6 +3072,7 @@ async def generate_combined_excel_report(
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
# Get all sessions with measurement files # Get all sessions with measurement files
sessions = db.query(MonitoringSession).filter_by(project_id=project_id).all() sessions = db.query(MonitoringSession).filter_by(project_id=project_id).all()
@@ -3386,6 +3418,7 @@ async def combined_report_wizard(
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
sessions = db.query(MonitoringSession).filter_by(project_id=project_id).order_by(MonitoringSession.started_at).all() sessions = db.query(MonitoringSession).filter_by(project_id=project_id).order_by(MonitoringSession.started_at).all()
@@ -3655,6 +3688,7 @@ async def generate_combined_from_preview(
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
report_title = data.get("report_title", "Background Noise Study") report_title = data.get("report_title", "Background Noise Study")
project_name = data.get("project_name", project.name) project_name = data.get("project_name", project.name)
@@ -4130,6 +4164,7 @@ async def upload_all_project_data(
project = db.query(Project).filter_by(id=project_id).first() project = db.query(Project).filter_by(id=project_id).first()
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
_require_sound_project(project)
# Load all sound monitoring locations for this project # Load all sound monitoring locations for this project
locations = db.query(MonitoringLocation).filter_by( locations = db.query(MonitoringLocation).filter_by(