Files
fund-tracer/backend/app/services/ocr_service.py
2026-03-11 16:28:04 +08:00

146 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""OCR and multimodal extraction service.
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
When API keys are not configured, falls back to a mock implementation that
returns placeholder data (sufficient for demo / competition).
"""
import json
import logging
from pathlib import Path
import httpx
from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__)
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Determine which app a screenshot came from and what kind of page it shows.

    Uses the remote API-backed classifier when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise falls back to the
    local mock classifier.

    Args:
        image_path: Screenshot path, passed through unchanged to the chosen
            backend.

    Returns:
        A ``(source_app, page_type)`` pair.
    """
    api_configured = settings.LLM_API_KEY and settings.LLM_API_URL
    if not api_configured:
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Pull structured transaction data out of a screenshot.

    Uses the remote API-backed extractor when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise returns the local
    mock record.

    Args:
        image_path: Screenshot path, passed through unchanged to the backend.
        source_app: App the screenshot was classified as coming from.
        page_type: Page type the screenshot was classified as.

    Returns:
        Whatever the selected backend returns for this screenshot.
    """
    api_configured = settings.LLM_API_KEY and settings.LLM_API_URL
    if not api_configured:
        return _extract_mock(image_path, source_app, page_type)
    return await _extract_via_api(image_path, source_app, page_type)
# ── real API implementation ──────────────────────────────────────────────
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot by sending it to the configured chat-style API.

    The image is read from ``settings.upload_path / image_path``, embedded as
    a base64 data URL and posted to ``settings.LLM_API_URL`` together with a
    prompt asking for a JSON object with ``source_app`` and ``page_type``.

    Args:
        image_path: Screenshot path relative to ``settings.upload_path``.

    Returns:
        ``(source_app, page_type)``. Each element falls back independently to
        ``SourceApp.other`` / ``PageType.unknown`` when the reply's value is
        missing or not a recognised member. Both fall back when the file is
        missing or unreadable, or when the request or reply parsing fails.
        Failures are logged at WARNING level and never raised.
    """
    import base64
    import mimetypes

    fallback = (SourceApp.other, PageType.unknown)
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return fallback
    try:
        image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    except OSError as e:
        # exists() can succeed for a directory or an unreadable file; keep the
        # "never raises" contract of the rest of this function.
        logger.warning("classify_page could not read image %s: %s", image_path, e)
        return fallback

    # Declare the real image type in the data URL instead of always claiming
    # PNG; keep PNG as the default when the extension is unknown or non-image.
    guessed_type, _ = mimetypes.guess_type(full_path.name)
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/png"

    prompt = (
        "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
        "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown"
        "只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}"
    )
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 200,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]
        # Replies are often wrapped in a ```json ... ``` fence; peel it off.
        data = json.loads(text.strip().strip("`").removeprefix("json").strip())
    except Exception as e:
        # Best-effort boundary: any transport, HTTP or parsing failure
        # degrades to the neutral classification.
        logger.warning("classify_page API failed: %s", e)
        return fallback

    if not isinstance(data, dict):
        logger.warning("classify_page API failed: expected a JSON object, got %s", type(data).__name__)
        return fallback

    # Coerce each field on its own so one unrecognised value does not throw
    # away the other, valid one.
    try:
        source_app = SourceApp(data.get("source_app", "other"))
    except ValueError:
        source_app = SourceApp.other
    try:
        page_type = PageType(data.get("page_type", "unknown"))
    except ValueError:
        page_type = PageType.unknown
    return source_app, page_type
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Extract transaction fields from a screenshot via the configured chat-style API.

    The image is read from ``settings.upload_path / image_path``, embedded as
    a base64 data URL and posted to ``settings.LLM_API_URL`` with a prompt that
    names the transaction fields to return as JSON.

    Args:
        image_path: Screenshot path relative to ``settings.upload_path``.
        source_app: Classified source app; its ``.value`` is put in the prompt.
        page_type: Classified page type; its ``.value`` is put in the prompt.

    Returns:
        The parsed JSON reply. The prompt asks for a single object, or an
        array of objects when the screenshot shows several transactions, so
        callers should be prepared for a ``list`` despite the ``dict``
        annotation. Returns ``{}`` when the file is missing or unreadable,
        when the request or parsing fails, or when the reply is neither an
        object nor an array. Failures are logged at WARNING level and never
        raised.
    """
    import base64
    import mimetypes

    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}
    try:
        image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    except OSError as e:
        # exists() can succeed for a directory or an unreadable file; keep the
        # "never raises" contract of the rest of this function.
        logger.warning("extract_transaction_fields could not read image %s: %s", image_path, e)
        return {}

    # Declare the real image type in the data URL instead of always claiming
    # PNG; keep PNG as the default when the extension is unknown or non-image.
    guessed_type, _ = mimetypes.guess_type(full_path.name)
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/png"

    prompt = (
        f"这是一张来自{source_app.value}{page_type.value}截图。"
        "请提取其中的交易信息返回JSON格式字段包括"
        "trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), "
        "direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), "
        "self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。"
        "如果截图包含多笔交易返回JSON数组。否则返回单个JSON对象。"
    )
    try:
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 2000,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]
        # Replies are often wrapped in a ```json ... ``` fence; peel it off.
        parsed = json.loads(text.strip().strip("`").removeprefix("json").strip())
    except Exception as e:
        # Best-effort boundary: any transport, HTTP or parsing failure
        # degrades to an empty result.
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}

    # Valid JSON can still be a bare string, number or null; hand callers
    # only the shapes the prompt asked for.
    if not isinstance(parsed, (dict, list)):
        logger.warning(
            "extract_transaction_fields API failed: expected a JSON object or array, got %s",
            type(parsed).__name__,
        )
        return {}
    return parsed
# ── mock fallback ────────────────────────────────────────────────────────
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Offline stand-in for the classifier: guess from substrings of the path.

    The lower-cased path is checked against keyword groups in order; the first
    group with a hit decides the result:

    * ``wechat`` / ``wx``  -> WeChat, bill detail
    * ``alipay`` / ``ali`` -> Alipay, bill list
    * ``bank``             -> bank, bill detail

    Anything else is reported as ``SourceApp.other`` / ``PageType.unknown``.
    """
    haystack = image_path.lower()

    def mentions(*keywords: str) -> bool:
        # True when at least one keyword occurs anywhere in the path.
        return any(keyword in haystack for keyword in keywords)

    if mentions("wechat", "wx"):
        return SourceApp.wechat, PageType.bill_detail
    if mentions("alipay", "ali"):
        return SourceApp.alipay, PageType.bill_list
    if mentions("bank"):
        return SourceApp.bank, PageType.bill_detail
    return SourceApp.other, PageType.unknown
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
return {
"trade_time": "2026-03-08 10:00:00",
"amount": 1000.00,
"direction": "out",
"counterparty_name": "模拟对手方",
"counterparty_account": "",
"self_account_tail_no": "",
"order_no": f"MOCK-{hash(image_path) % 100000:05d}",
"remark": "模拟交易",
"confidence": 0.80,
}