from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.assessment import FraudAssessment, ConfidenceLevel
from app.repositories.base import BaseRepository


class AssessmentRepository(BaseRepository[FraudAssessment]):
    """Repository for querying :class:`FraudAssessment` rows."""

    def __init__(self, session: AsyncSession):
        super().__init__(FraudAssessment, session)

    async def list_by_case(
        self,
        case_id: UUID,
        confidence_level: ConfidenceLevel | None = None,
    ) -> tuple[list[FraudAssessment], int]:
        """Return the assessments belonging to a case, plus their total count.

        Args:
            case_id: The case whose assessments are listed.
            confidence_level: Optional filter. When given, only assessments
                with this confidence level are included in both the returned
                list and the count.

        Returns:
            A ``(assessments, total)`` tuple.
            - ``assessments`` is ordered by ``created_at`` ascending. Each
              item has its ``transaction`` relationship eagerly loaded via
              ``selectinload``.
            - ``total`` is the number of rows matching the same filters.
        """
        # Build the filters once so the row query and the count query
        # always apply identical criteria.
        conditions = [FraudAssessment.case_id == case_id]
        # Compare against None explicitly. A truthiness check would silently
        # drop the filter for an enum member that evaluates falsy, such as an
        # int-mixin member with value 0 or a str-mixin member with value "".
        if confidence_level is not None:
            conditions.append(FraudAssessment.confidence_level == confidence_level)

        count_q = select(func.count()).select_from(FraudAssessment).where(*conditions)
        total = (await self.session.execute(count_q)).scalar() or 0

        query = (
            select(FraudAssessment)
            .options(selectinload(FraudAssessment.transaction))
            .where(*conditions)
            .order_by(FraudAssessment.created_at.asc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all()), total