Files
fund-tracer/backend/app/services/ocr_service.py

146 lines
6.1 KiB
Python
Raw Normal View History

2026-03-11 16:28:04 +08:00
"""OCR and multimodal extraction service.
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
When API keys are not configured, falls back to a mock implementation that
returns placeholder data (sufficient for demo / competition).
"""
import json
import logging
from pathlib import Path
import httpx
from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__)
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Return the source app and page type detected for a screenshot.

    The remote classifier is used only when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are set; otherwise the local mock
    classifier handles the request.
    """
    llm_configured = settings.LLM_API_KEY and settings.LLM_API_URL
    if not llm_configured:
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Return the structured transaction fields read from a screenshot.

    The remote extractor is used only when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are set; otherwise the local mock
    extractor supplies placeholder data.
    """
    llm_configured = settings.LLM_API_KEY and settings.LLM_API_URL
    if not llm_configured:
        return _extract_mock(image_path, source_app, page_type)
    return await _extract_via_api(image_path, source_app, page_type)
# ── real API implementation ──────────────────────────────────────────────
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot by posting it to the configured LLM endpoint.

    The image at ``settings.upload_path / image_path`` is base64-encoded and
    sent as a data URL, together with a prompt asking for a JSON object that
    holds ``source_app`` and ``page_type``. The reply text is read from
    ``choices[0].message.content`` of the response JSON.

    Args:
        image_path: Screenshot path, joined onto ``settings.upload_path``.

    Returns:
        ``(source_app, page_type)``. A missing file, an I/O or HTTP error, or
        an unparseable reply yields ``(SourceApp.other, PageType.unknown)``.
        A single unrecognised field falls back on its own, so one bad value
        does not discard the other, valid one.
    """
    import base64
    import mimetypes

    fallback = (SourceApp.other, PageType.unknown)
    full_path = settings.upload_path / image_path
    # is_file() rather than exists(): a directory would make the read fail.
    if not full_path.is_file():
        return fallback

    # Label the data URL with the real image type. Default to PNG when the
    # extension is missing or does not map to an image type.
    guessed_type, _ = mimetypes.guess_type(full_path.name)
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/png"

    prompt = (
        "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
        "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown"
        "只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}"
    )
    try:
        # Read inside the try so an OSError degrades to the fallback as well.
        image_b64 = base64.b64encode(full_path.read_bytes()).decode()
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 200,
                },
            )
        resp.raise_for_status()
        text = resp.json()["choices"][0]["message"]["content"]
        # Drop a surrounding ``` / ```json fence, if present, before parsing.
        data = json.loads(text.strip().strip("`").removeprefix("json").strip())
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    except Exception as e:
        # Best-effort boundary: classification must never break the caller.
        logger.warning("classify_page API failed: %s", e)
        return fallback

    # Validate each field separately so one bad value keeps the other.
    try:
        source_app = SourceApp(data.get("source_app", "other"))
    except ValueError:
        source_app = SourceApp.other
    try:
        page_type = PageType(data.get("page_type", "unknown"))
    except ValueError:
        page_type = PageType.unknown
    return source_app, page_type
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Extract transaction fields by posting a screenshot to the LLM endpoint.

    The image at ``settings.upload_path / image_path`` is base64-encoded and
    sent as a data URL. The prompt names the source app and page type and
    lists the fields to return. The reply text is read from
    ``choices[0].message.content`` of the response JSON.

    Args:
        image_path: Screenshot path, joined onto ``settings.upload_path``.
        source_app: Source app; its ``.value`` is interpolated into the prompt.
        page_type: Page type; its ``.value`` is interpolated into the prompt.

    Returns:
        The parsed JSON reply. A missing file, an I/O or HTTP error, an
        unparseable reply, or a reply that is neither an object nor an array
        yields ``{}``.

        NOTE(review): the prompt asks for a JSON array when the screenshot
        holds several transactions, so a ``list`` can be returned despite the
        ``dict`` annotation. Confirm callers handle both shapes.
    """
    import base64
    import mimetypes

    full_path = settings.upload_path / image_path
    # is_file() rather than exists(): a directory would make the read fail.
    if not full_path.is_file():
        return {}

    # Label the data URL with the real image type. Default to PNG when the
    # extension is missing or does not map to an image type.
    guessed_type, _ = mimetypes.guess_type(full_path.name)
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/png"

    prompt = (
        f"这是一张来自{source_app.value}{page_type.value}截图。"
        "请提取其中的交易信息返回JSON格式字段包括"
        "trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), "
        "direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), "
        "self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。"
        "如果截图包含多笔交易返回JSON数组。否则返回单个JSON对象。"
    )
    try:
        # Read inside the try so an OSError degrades to the fallback as well.
        image_b64 = base64.b64encode(full_path.read_bytes()).decode()
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 2000,
                },
            )
        resp.raise_for_status()
        text = resp.json()["choices"][0]["message"]["content"]
        # Drop a surrounding ``` / ```json fence, if present, before parsing.
        parsed = json.loads(text.strip().strip("`").removeprefix("json").strip())
        # A bare scalar (string, number, null) is not usable transaction data.
        if not isinstance(parsed, (dict, list)):
            raise ValueError(f"expected a JSON object or array, got {type(parsed).__name__}")
        return parsed
    except Exception as e:
        # Best-effort boundary: extraction must never break the caller.
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}
# ── mock fallback ────────────────────────────────────────────────────────
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Guess the source app and page type from keywords in the image path.

    The lower-cased path is checked for substrings in a fixed order. The
    first rule with any matching keyword wins, and a path matching no rule
    is reported as ``(SourceApp.other, PageType.unknown)``.
    """
    lowered = image_path.lower()
    # Ordered (keywords, result) rules; earlier rules take precedence.
    rules = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    for keywords, verdict in rules:
        if any(word in lowered for word in keywords):
            return verdict
    return SourceApp.other, PageType.unknown
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Return a fixed placeholder transaction for use without an LLM API.

    Every field is constant except ``order_no``, which is derived from
    ``image_path`` so that different images get different mock order numbers.

    Args:
        image_path: Screenshot path; only used to derive ``order_no``.
        source_app: Unused; accepted to mirror the real extractor's signature.
        page_type: Unused; accepted to mirror the real extractor's signature.

    Returns:
        A dict with the keys ``trade_time``, ``amount``, ``direction``,
        ``counterparty_name``, ``counterparty_account``,
        ``self_account_tail_no``, ``order_no``, ``remark`` and ``confidence``.
    """
    import zlib

    # The built-in hash() of a str is salted per process, so it would give the
    # same image a different order number after every restart or in another
    # worker. CRC32 is stable across processes.
    stable_id = zlib.crc32(image_path.encode("utf-8")) % 100000
    return {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": f"MOCK-{stable_id:05d}",
        "remark": "模拟交易",
        "confidence": 0.80,
    }