"""Celery tasks for OCR processing of uploaded screenshots.""" import asyncio import logging import json from uuid import UUID from app.workers.celery_app import celery_app logger = logging.getLogger(__name__) def _run_async(coro): """Run an async coroutine from synchronous Celery task context.""" loop = asyncio.new_event_loop() try: return loop.run_until_complete(coro) finally: loop.close() @celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3) def process_image_ocr(self, image_id: str): """Process a single image: classify page, extract fields, save to DB.""" _run_async(process_image_ocr_async(image_id)) async def process_images_ocr_batch_async(image_ids: list[str], max_concurrency: int) -> None: """Process many images with bounded OCR concurrency.""" if not image_ids: return # De-duplicate in-memory to prevent repeated processing of same image id # in a single batch submission. image_ids = list(dict.fromkeys(image_ids)) concurrency = max(1, max_concurrency) semaphore = asyncio.Semaphore(concurrency) async def _run_one(image_id: str) -> None: async with semaphore: try: await process_image_ocr_async(image_id) except Exception: # Keep batch processing alive even if one image fails. logger.exception("Image %s OCR failed in batch", image_id) await asyncio.gather(*[_run_one(image_id) for image_id in image_ids]) async def process_image_ocr_async(image_id_str: str): from app.core.database import async_session_factory from sqlalchemy import delete from app.models.evidence_image import EvidenceImage, OcrStatus from app.models.ocr_block import OcrBlock from app.models.transaction import TransactionRecord from app.services.ocr_service import classify_page, extract_transaction_fields from app.services.parser_service import parse_extracted_fields image_id = UUID(image_id_str) async with async_session_factory() as db: image = await db.get(EvidenceImage, image_id) if not image: logger.error("Image %s not found", image_id) return image.ocr_status = OcrStatus.processing await db.flush() # Re-run OCR for this image should replace old OCR blocks/records. await db.execute(delete(OcrBlock).where(OcrBlock.image_id == image.id)) await db.execute(delete(TransactionRecord).where(TransactionRecord.evidence_image_id == image.id)) try: source_app, page_type = await classify_page(image.file_path) image.source_app = source_app image.page_type = page_type raw_fields, raw_ocr_text = await extract_transaction_fields(image.file_path, source_app, page_type) is_empty_extract = raw_fields is None or raw_fields == {} or raw_fields == [] # save raw OCR block (direct model output for debugging) raw_block_content = (raw_ocr_text or "").strip() if raw_block_content: block = OcrBlock( image_id=image.id, content=raw_block_content, bbox={}, seq_order=0, confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, ) db.add(block) elif not is_empty_extract: block = OcrBlock( image_id=image.id, content=json.dumps(raw_fields, ensure_ascii=False), bbox={}, seq_order=0, confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, ) db.add(block) # parse into transaction records records = parse_extracted_fields(raw_fields, image.case_id, image.id, source_app) for r in records: db.add(r) image.ocr_status = OcrStatus.done await db.commit() logger.info("Image %s processed: %d transactions", image_id, len(records)) except Exception as e: image.ocr_status = OcrStatus.failed await db.commit() logger.error("Image %s OCR failed: %s", image_id, e) raise