from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.transaction import TransactionRecord
from app.repositories.base import BaseRepository


class TransactionRepository(BaseRepository[TransactionRecord]):
    """Query helpers for ``TransactionRecord`` rows belonging to a case."""

    def __init__(self, session: AsyncSession):
        super().__init__(TransactionRecord, session)

    @staticmethod
    def _case_conditions(case_id: UUID, filter_type: str | None = None) -> list:
        """Build the WHERE clauses selecting a case's records.

        Args:
            case_id: Case whose records are selected.
            filter_type: ``"unique"`` keeps rows whose ``is_duplicate`` is
                False, ``"duplicate"`` keeps rows whose ``is_duplicate`` is
                True; any other value (including ``None``) adds no filter.

        Returns:
            A list of SQL expressions intended to be AND-ed together.
        """
        conditions = [TransactionRecord.case_id == case_id]
        if filter_type == "unique":
            conditions.append(TransactionRecord.is_duplicate.is_(False))
        elif filter_type == "duplicate":
            conditions.append(TransactionRecord.is_duplicate.is_(True))
        return conditions

    async def list_by_case(
        self,
        case_id: UUID,
        filter_type: str | None = None,
    ) -> tuple[list[TransactionRecord], int]:
        """Return a case's records (oldest ``trade_time`` first) and their count.

        Args:
            case_id: Case whose records are listed.
            filter_type: ``"unique"``, ``"duplicate"``, or anything else for
                no duplicate-flag filtering.

        Returns:
            ``(records, total)``, where ``total`` comes from a separate
            ``COUNT(*)`` query over the same conditions.
        """
        # Both statements share one condition list, so the count cannot
        # drift from the rows it is supposed to describe.
        conditions = self._case_conditions(case_id, filter_type)

        count_q = (
            select(func.count())
            .select_from(TransactionRecord)
            .where(*conditions)
        )
        # NOTE(review): no LIMIT is applied, so total should equal
        # len(records) unless rows change between the two queries.
        total = (await self.session.execute(count_q)).scalar() or 0

        query = (
            select(TransactionRecord)
            .where(*conditions)
            .order_by(TransactionRecord.trade_time.asc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all()), total

    async def get_all_by_case(self, case_id: UUID) -> list[TransactionRecord]:
        """Return every record of a case, ordered by ascending ``trade_time``."""
        result = await self.session.execute(
            select(TransactionRecord)
            .where(*self._case_conditions(case_id))
            .order_by(TransactionRecord.trade_time.asc())
        )
        return list(result.scalars().all())