68 lines
2.7 KiB
Python
68 lines
2.7 KiB
Python
from uuid import UUID
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
|
|
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class ImageRepository(BaseRepository[EvidenceImage]):
    """Query helpers for ``EvidenceImage`` rows, built on ``BaseRepository``."""

    def __init__(self, session: AsyncSession):
        """Bind the repository to *session* for the ``EvidenceImage`` model."""
        super().__init__(EvidenceImage, session)

    async def find_by_hash(self, file_hash: str) -> EvidenceImage | None:
        """Return the image whose ``file_hash`` equals *file_hash*, or ``None``.

        The lookup is not restricted to a case. ``scalar_one_or_none`` raises
        ``sqlalchemy.exc.MultipleResultsFound`` if more than one row matches.
        NOTE(review): presumably ``file_hash`` is unique across cases; confirm
        against the model, otherwise callers must handle that exception.
        """
        result = await self.session.execute(
            select(EvidenceImage).where(EvidenceImage.file_hash == file_hash)
        )
        return result.scalar_one_or_none()

    async def find_by_hash_in_case(
        self, case_id: UUID, file_hashes: list[str]
    ) -> EvidenceImage | None:
        """Return one image in *case_id* whose hash is in *file_hashes*.

        Returns ``None`` when *file_hashes* is empty or nothing matches. When
        several rows match, the most recently uploaded one is returned.
        """
        if not file_hashes:
            return None
        query = (
            select(EvidenceImage)
            .where(
                EvidenceImage.case_id == case_id,
                EvidenceImage.file_hash.in_(file_hashes),
            )
            # Several hashes can match several rows; pick one deterministically
            # instead of letting scalar_one_or_none() raise MultipleResultsFound.
            .order_by(EvidenceImage.uploaded_at.desc())
            .limit(1)
        )
        result = await self.session.execute(query)
        return result.scalars().first()

    async def list_by_case(
        self,
        case_id: UUID,
        source_app: SourceApp | None = None,
        page_type: PageType | None = None,
    ) -> list[EvidenceImage]:
        """List images of *case_id*, newest upload first.

        *source_app* and *page_type* narrow the result when given; ``None``
        means "do not filter on this column".
        """
        query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
        # Compare against None explicitly so a falsy enum member still filters.
        if source_app is not None:
            query = query.where(EvidenceImage.source_app == source_app)
        if page_type is not None:
            query = query.where(EvidenceImage.page_type == page_type)
        query = query.order_by(EvidenceImage.uploaded_at.desc())
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def count_by_case(self, case_id: UUID) -> int:
        """Return the number of images stored for *case_id*."""
        result = await self.session.execute(
            select(func.count())
            .select_from(EvidenceImage)
            .where(EvidenceImage.case_id == case_id)
        )
        return result.scalar() or 0

    async def list_for_ocr(
        self, case_id: UUID, include_done: bool = False
    ) -> list[EvidenceImage]:
        """List images of *case_id* for OCR processing, newest upload first.

        Unless *include_done* is true, rows whose ``ocr_status`` equals
        ``OcrStatus.done`` are excluded.
        NOTE(review): in SQL, ``!=`` does not match NULL, so rows with a NULL
        ``ocr_status`` (if the column is nullable) are excluded too; confirm.
        """
        query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
        if not include_done:
            query = query.where(EvidenceImage.ocr_status != OcrStatus.done)
        result = await self.session.execute(
            query.order_by(EvidenceImage.uploaded_at.desc())
        )
        return list(result.scalars().all())

    async def list_by_ids_in_case(
        self, case_id: UUID, image_ids: list[UUID]
    ) -> list[EvidenceImage]:
        """List images of *case_id* whose id is in *image_ids*, newest first.

        Returns an empty list without querying when *image_ids* is empty. Ids
        that belong to a different case are not returned.
        """
        if not image_ids:
            return []
        result = await self.session.execute(
            select(EvidenceImage)
            .where(EvidenceImage.case_id == case_id, EvidenceImage.id.in_(image_ids))
            .order_by(EvidenceImage.uploaded_at.desc())
        )
        return list(result.scalars().all())
|