first commit
This commit is contained in:
0
backend/app/workers/__init__.py
Normal file
0
backend/app/workers/__init__.py
Normal file
38
backend/app/workers/analysis_tasks.py
Normal file
38
backend/app/workers/analysis_tasks.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Celery task: full-case analysis pipeline."""
|
||||
import asyncio
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.workers.analysis_tasks.run_full_analysis", bind=True, max_retries=2)
|
||||
def run_full_analysis(self, case_id_str: str):
|
||||
_run_async(_run(case_id_str))
|
||||
|
||||
|
||||
async def _run(case_id_str: str):
|
||||
from app.core.database import async_session_factory
|
||||
from app.services.analysis_pipeline import run_analysis_sync
|
||||
|
||||
case_id = UUID(case_id_str)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
await run_analysis_sync(case_id, db)
|
||||
await db.commit()
|
||||
logger.info("Full analysis completed for case %s", case_id)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Analysis failed for case %s: %s", case_id, e)
|
||||
raise
|
||||
25
backend/app/workers/celery_app.py
Normal file
25
backend/app/workers/celery_app.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"fund_tracer",
|
||||
broker=settings.REDIS_URL,
|
||||
backend=settings.REDIS_URL,
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
timezone="Asia/Shanghai",
|
||||
enable_utc=True,
|
||||
task_track_started=True,
|
||||
task_routes={
|
||||
"app.workers.ocr_tasks.*": {"queue": "ocr"},
|
||||
"app.workers.analysis_tasks.*": {"queue": "analysis"},
|
||||
"app.workers.report_tasks.*": {"queue": "reports"},
|
||||
},
|
||||
)
|
||||
|
||||
celery_app.autodiscover_tasks(["app.workers"])
|
||||
74
backend/app/workers/ocr_tasks.py
Normal file
74
backend/app/workers/ocr_tasks.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Celery tasks for OCR processing of uploaded screenshots."""
|
||||
import asyncio
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from synchronous Celery task context."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
|
||||
def process_image_ocr(self, image_id: str):
|
||||
"""Process a single image: classify page, extract fields, save to DB."""
|
||||
_run_async(_process(image_id))
|
||||
|
||||
|
||||
async def _process(image_id_str: str):
|
||||
from app.core.database import async_session_factory
|
||||
from app.models.evidence_image import EvidenceImage, OcrStatus
|
||||
from app.models.ocr_block import OcrBlock
|
||||
from app.services.ocr_service import classify_page, extract_transaction_fields
|
||||
from app.services.parser_service import parse_extracted_fields
|
||||
|
||||
image_id = UUID(image_id_str)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
image = await db.get(EvidenceImage, image_id)
|
||||
if not image:
|
||||
logger.error("Image %s not found", image_id)
|
||||
return
|
||||
|
||||
image.ocr_status = OcrStatus.processing
|
||||
await db.flush()
|
||||
|
||||
try:
|
||||
source_app, page_type = await classify_page(image.file_path)
|
||||
image.source_app = source_app
|
||||
image.page_type = page_type
|
||||
|
||||
raw_fields = await extract_transaction_fields(image.file_path, source_app, page_type)
|
||||
|
||||
# save raw OCR block
|
||||
block = OcrBlock(
|
||||
image_id=image.id,
|
||||
content=str(raw_fields),
|
||||
bbox={},
|
||||
seq_order=0,
|
||||
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,
|
||||
)
|
||||
db.add(block)
|
||||
|
||||
# parse into transaction records
|
||||
records = parse_extracted_fields(raw_fields, image.case_id, image.id, source_app)
|
||||
for r in records:
|
||||
db.add(r)
|
||||
|
||||
image.ocr_status = OcrStatus.done
|
||||
await db.commit()
|
||||
logger.info("Image %s processed: %d transactions", image_id, len(records))
|
||||
|
||||
except Exception as e:
|
||||
image.ocr_status = OcrStatus.failed
|
||||
await db.commit()
|
||||
logger.error("Image %s OCR failed: %s", image_id, e)
|
||||
raise
|
||||
41
backend/app/workers/report_tasks.py
Normal file
41
backend/app/workers/report_tasks.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Celery task: async report generation."""
|
||||
import asyncio
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.workers.report_tasks.generate_report_async", bind=True)
|
||||
def generate_report_async(self, case_id_str: str, report_type: str):
|
||||
_run_async(_run(case_id_str, report_type))
|
||||
|
||||
|
||||
async def _run(case_id_str: str, report_type: str):
|
||||
from app.core.database import async_session_factory
|
||||
from app.models.report import ReportType
|
||||
from app.schemas.report import ReportCreate
|
||||
from app.services.report_service import generate_report
|
||||
|
||||
case_id = UUID(case_id_str)
|
||||
body = ReportCreate(report_type=ReportType(report_type))
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
report = await generate_report(case_id, body, db)
|
||||
await db.commit()
|
||||
logger.info("Report generated for case %s: %s", case_id, report.file_path)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Report generation failed: %s", e)
|
||||
raise
|
||||
Reference in New Issue
Block a user