2026-03-11 16:28:04 +08:00
|
|
|
"""Celery tasks for OCR processing of uploaded screenshots."""
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
2026-03-12 12:32:29 +08:00
|
|
|
import json
|
2026-03-11 16:28:04 +08:00
|
|
|
from uuid import UUID
|
|
|
|
|
|
|
|
|
|
from app.workers.celery_app import celery_app
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_async(coro):
|
|
|
|
|
"""Run an async coroutine from synchronous Celery task context."""
|
|
|
|
|
loop = asyncio.new_event_loop()
|
|
|
|
|
try:
|
|
|
|
|
return loop.run_until_complete(coro)
|
|
|
|
|
finally:
|
|
|
|
|
loop.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
def process_image_ocr(self, image_id: str):
    """Celery entry point: OCR a single uploaded screenshot.

    Builds the ``process_image_ocr_async`` pipeline, which does the following:
    - classifies the page;
    - extracts fields;
    - saves OCR blocks and transaction records.

    It then drives the pipeline to completion on a private event loop via
    ``_run_async``. Any exception from the pipeline propagates out of the
    task unchanged. The task returns ``None``.

    NOTE(review): ``bind=True`` and ``max_retries=3`` are declared, but this
    body never calls ``self.retry``, so no automatic retry is triggered here.
    Confirm that is intended.
    """
    pipeline = process_image_ocr_async(image_id)
    _run_async(pipeline)
|
2026-03-11 16:28:04 +08:00
|
|
|
|
|
|
|
|
|
2026-03-12 19:57:30 +08:00
|
|
|
async def process_images_ocr_batch_async(image_ids: list[str], max_concurrency: int) -> None:
    """Process many images with bounded OCR concurrency.

    Each image is handled by ``process_image_ocr_async`` in its own coroutine.
    At most ``max_concurrency`` of those run at once. A failure for one image
    is logged and does not abort the rest of the batch.

    Duplicate ids are collapsed, keeping first-seen order. Otherwise the same
    image could be processed by two overlapping coroutines whose
    delete-then-insert transactions race each other and may leave duplicate
    OCR block / transaction rows.

    Args:
        image_ids: Image UUIDs as strings. ``None`` or an empty list is a no-op.
        max_concurrency: Upper bound on simultaneously running OCR jobs.
            Values below 1 are treated as 1.
    """
    if not image_ids:
        return

    # dict.fromkeys drops duplicates while preserving insertion order.
    unique_ids = list(dict.fromkeys(image_ids))
    semaphore = asyncio.Semaphore(max(1, max_concurrency))

    async def _run_one(image_id: str) -> None:
        async with semaphore:
            try:
                await process_image_ocr_async(image_id)
            except Exception:
                # Keep batch processing alive even if one image fails.
                logger.exception("Image %s OCR failed in batch", image_id)

    await asyncio.gather(*(_run_one(image_id) for image_id in unique_ids))
|
|
|
|
|
|
|
|
|
|
|
2026-03-12 12:32:29 +08:00
|
|
|
async def process_image_ocr_async(image_id_str: str):
    """OCR one EvidenceImage and persist the results in a single DB session.

    Steps:
    1. Mark the image ``processing``.
    2. Delete OcrBlock / TransactionRecord rows from any previous run.
    3. Classify the page (``source_app``, ``page_type``).
    4. Extract transaction fields and store the raw model output as one
       debug OcrBlock.
    5. Parse the fields into TransactionRecord rows.
    6. Mark the image ``done`` and commit.

    Args:
        image_id_str: The image's UUID as a string.

    Returns:
        None. If no EvidenceImage has this id, an error is logged and the
        function returns without raising.

    Raises:
        ValueError: if ``image_id_str`` is not a valid UUID (from ``UUID()``).
        Exception: any error from the OCR / parse / DB steps is re-raised
            unchanged after a best-effort attempt to mark the image ``failed``.
    """
    from app.core.database import async_session_factory
    from sqlalchemy import delete
    from app.models.evidence_image import EvidenceImage, OcrStatus
    from app.models.ocr_block import OcrBlock
    from app.models.transaction import TransactionRecord
    from app.services.ocr_service import classify_page, extract_transaction_fields
    from app.services.parser_service import parse_extracted_fields

    image_id = UUID(image_id_str)

    async with async_session_factory() as db:
        image = await db.get(EvidenceImage, image_id)
        if not image:
            logger.error("Image %s not found", image_id)
            return

        try:
            # NOTE(review): this status is only flushed, and it is overwritten
            # with done/failed before the single commit in this function. The
            # "processing" value itself is therefore never committed here.
            image.ocr_status = OcrStatus.processing
            await db.flush()

            # Re-run OCR for this image should replace old OCR blocks/records.
            await db.execute(delete(OcrBlock).where(OcrBlock.image_id == image.id))
            await db.execute(
                delete(TransactionRecord).where(TransactionRecord.evidence_image_id == image.id)
            )

            source_app, page_type = await classify_page(image.file_path)
            image.source_app = source_app
            image.page_type = page_type

            raw_fields, raw_ocr_text = await extract_transaction_fields(
                image.file_path, source_app, page_type
            )
            is_empty_extract = raw_fields is None or raw_fields == {} or raw_fields == []

            # Save one raw OCR block (direct model output, for debugging).
            # Prefer the model's text; otherwise fall back to a JSON dump of a
            # non-empty extraction.
            raw_block_content = (raw_ocr_text or "").strip()
            if not raw_block_content and not is_empty_extract:
                raw_block_content = json.dumps(raw_fields, ensure_ascii=False)
            if raw_block_content:
                confidence = (
                    raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5
                )
                db.add(
                    OcrBlock(
                        image_id=image.id,
                        content=raw_block_content,
                        bbox={},
                        seq_order=0,
                        confidence=confidence,
                    )
                )

            # Parse into transaction records.
            records = parse_extracted_fields(raw_fields, image.case_id, image.id, source_app)
            db.add_all(records)

            image.ocr_status = OcrStatus.done
            await db.commit()
        except Exception as exc:
            logger.error("Image %s OCR failed: %s", image_id, exc)
            try:
                if not db.is_active:
                    # A failed flush/commit left the transaction needing a
                    # rollback. Committing now would raise PendingRollbackError
                    # and hide `exc`. Rolling back also restores the previous
                    # OCR rows, and it expires `image`, so reload it before
                    # writing the status.
                    await db.rollback()
                    await db.refresh(image)
                # Otherwise, as before, work done so far (deletes, any debug
                # block) is committed together with the failed status.
                image.ocr_status = OcrStatus.failed
                await db.commit()
            except Exception:
                logger.exception("Image %s: could not persist failed OCR status", image_id)
            # Re-raise the original OCR error, not any status-update error.
            raise
        else:
            logger.info("Image %s processed: %d transactions", image_id, len(records))
|