from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
from app.repositories.base import BaseRepository


class ImageRepository(BaseRepository[EvidenceImage]):
    """Read-side query helpers for ``EvidenceImage`` rows.

    All queries run through ``self.session``, which is presumably stored by
    ``BaseRepository.__init__`` (not visible in this file -- confirm there).
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to ``session`` for the ``EvidenceImage`` model."""
        super().__init__(EvidenceImage, session)

    async def find_by_hash(self, file_hash: str) -> EvidenceImage | None:
        """Return the image whose ``file_hash`` equals ``file_hash``, if any.

        The lookup is not scoped to a case. Because the result is read with
        ``scalar_one_or_none()``, SQLAlchemy raises ``MultipleResultsFound``
        when more than one row carries the same hash.
        """
        result = await self.session.execute(
            select(EvidenceImage).where(EvidenceImage.file_hash == file_hash)
        )
        return result.scalar_one_or_none()

    async def find_by_hash_in_case(
        self, case_id: UUID, file_hashes: list[str]
    ) -> EvidenceImage | None:
        """Return one image in ``case_id`` whose hash is in ``file_hashes``.

        Returns ``None`` when ``file_hashes`` is empty (no query is issued)
        or when nothing matches. When several rows match, the one with the
        earliest ``uploaded_at`` is returned.
        """
        if not file_hashes:
            return None
        result = await self.session.execute(
            select(EvidenceImage)
            .where(
                EvidenceImage.case_id == case_id,
                EvidenceImage.file_hash.in_(file_hashes),
            )
            # An IN filter can match several rows; scalar_one_or_none() would
            # raise MultipleResultsFound in that case. Order and cap the
            # query so a single, deterministic row comes back instead.
            .order_by(EvidenceImage.uploaded_at.asc())
            .limit(1)
        )
        return result.scalars().first()

    async def list_by_case(
        self,
        case_id: UUID,
        source_app: SourceApp | None = None,
        page_type: PageType | None = None,
    ) -> list[EvidenceImage]:
        """List the images of a case, newest upload first.

        ``source_app`` and ``page_type`` each add an equality filter when
        they are not ``None``.
        """
        query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
        # Compare against None explicitly so that an enum member with a falsy
        # value cannot silently switch the filter off.
        if source_app is not None:
            query = query.where(EvidenceImage.source_app == source_app)
        if page_type is not None:
            query = query.where(EvidenceImage.page_type == page_type)
        query = query.order_by(EvidenceImage.uploaded_at.desc())
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def count_by_case(self, case_id: UUID) -> int:
        """Return how many images belong to ``case_id`` (0 when there are none)."""
        result = await self.session.execute(
            select(func.count())
            .select_from(EvidenceImage)
            .where(EvidenceImage.case_id == case_id)
        )
        return result.scalar() or 0

    async def list_for_ocr(
        self, case_id: UUID, include_done: bool = False
    ) -> list[EvidenceImage]:
        """List the images of a case that may be submitted for OCR, newest first.

        With ``include_done=False`` only images whose status is ``pending``
        or ``failed`` are returned. With ``include_done=True`` every image
        whose status is not ``processing`` is returned.
        """
        query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
        # Always exclude currently-processing images to avoid duplicate OCR
        # submission from different trigger paths (upload/workspace/screenshots).
        if include_done:
            query = query.where(EvidenceImage.ocr_status != OcrStatus.processing)
        else:
            query = query.where(
                EvidenceImage.ocr_status.in_([OcrStatus.pending, OcrStatus.failed])
            )
        result = await self.session.execute(
            query.order_by(EvidenceImage.uploaded_at.desc())
        )
        return list(result.scalars().all())

    async def list_by_ids_in_case(
        self, case_id: UUID, image_ids: list[UUID]
    ) -> list[EvidenceImage]:
        """Return the images in ``case_id`` whose id is in ``image_ids``.

        Results are ordered newest upload first. An empty ``image_ids``
        returns ``[]`` without querying. Ids that do not belong to the case
        are simply absent from the result.
        """
        if not image_ids:
            return []
        result = await self.session.execute(
            select(EvidenceImage)
            .where(EvidenceImage.case_id == case_id, EvidenceImage.id.in_(image_ids))
            .order_by(EvidenceImage.uploaded_at.desc())
        )
        return list(result.scalars().all())