146 lines
6.1 KiB
Python
146 lines
6.1 KiB
Python
"""OCR and multimodal extraction service.
|
||
|
||
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
|
||
When API keys are not configured, falls back to a mock implementation that
|
||
returns placeholder data (sufficient for demo / competition).
|
||
"""
|
||
import json
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
|
||
from app.core.config import settings
|
||
from app.models.evidence_image import SourceApp, PageType
|
||
|
||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── provider-agnostic interface ──────────────────────────────────────────
|
||
|
||
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Determine which app produced a screenshot and what kind of page it shows.

    Delegates to the API-backed classifier when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise the filename-based mock
    classifier is used.

    Args:
        image_path: Screenshot path, passed through unchanged to the chosen
            classifier.

    Returns:
        A ``(SourceApp, PageType)`` pair.
    """
    api_configured = bool(settings.LLM_API_KEY and settings.LLM_API_URL)
    if not api_configured:
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
|
||
|
||
|
||
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Pull structured transaction data out of a screenshot.

    Delegates to the API-backed extractor when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise the mock extractor is
    used.

    Args:
        image_path: Screenshot path, passed through unchanged.
        source_app: App the screenshot was classified as coming from.
        page_type: Page type the screenshot was classified as.

    Returns:
        Whatever the selected extractor returns.
    """
    api_configured = bool(settings.LLM_API_KEY and settings.LLM_API_URL)
    if not api_configured:
        return _extract_mock(image_path, source_app, page_type)
    return await _extract_via_api(image_path, source_app, page_type)
|
||
|
||
|
||
# ── real API implementation ──────────────────────────────────────────────
|
||
|
||
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot by sending it to the configured LLM endpoint.

    The image at ``settings.upload_path / image_path`` is base64-encoded into
    a data URL and POSTed to ``settings.LLM_API_URL`` together with a prompt
    asking for a JSON object holding ``source_app`` and ``page_type``.

    Args:
        image_path: Screenshot path relative to ``settings.upload_path``.

    Returns:
        The ``(SourceApp, PageType)`` pair parsed from the reply. A field that
        is missing or holds an unrecognised value falls back to
        ``SourceApp.other`` / ``PageType.unknown`` on its own, so one bad field
        no longer discards the other. An unreadable image, a failed request or
        an unparsable reply yields ``(SourceApp.other, PageType.unknown)``.
        This function does not raise for those failures; they are logged.
    """
    import base64
    import mimetypes

    fallback = (SourceApp.other, PageType.unknown)
    full_path = settings.upload_path / image_path

    # Read first and handle the failure instead of pre-checking exists():
    # a missing file, a directory (e.g. an empty image_path) or a permission
    # problem all surface as OSError and must not escape this helper.
    try:
        image_bytes = full_path.read_bytes()
    except OSError as e:
        logger.warning("classify_page could not read %s: %s", full_path, e)
        return fallback

    # Label the data URL with the file's real image type; keep PNG as the
    # default when the extension is unknown or does not map to an image type.
    mime_type = mimetypes.guess_type(full_path.name)[0] or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/png"
    image_b64 = base64.b64encode(image_bytes).decode()

    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。"
        "只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}"
    )

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 200,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]

        # Strip a surrounding Markdown code fence (```json ... ```), if any.
        cleaned = text.strip().strip("`").removeprefix("json").strip()
        try:
            data = json.loads(cleaned)
        except json.JSONDecodeError:
            # The reply may wrap the JSON in prose; salvage the outermost
            # {...} span before giving up.
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if start == -1 or end < start:
                raise
            data = json.loads(cleaned[start:end + 1])

        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        # Coerce each field independently so a single unrecognised value does
        # not throw away the other, valid one.
        try:
            source_app = SourceApp(data.get("source_app", "other"))
        except (ValueError, TypeError):
            logger.warning("classify_page API returned unknown source_app: %r", data.get("source_app"))
            source_app = SourceApp.other
        try:
            page_type = PageType(data.get("page_type", "unknown"))
        except (ValueError, TypeError):
            logger.warning("classify_page API returned unknown page_type: %r", data.get("page_type"))
            page_type = PageType.unknown
        return source_app, page_type
    except Exception as e:
        # Deliberate best-effort boundary: classification must never crash the caller.
        logger.warning("classify_page API failed: %s", e)
        return fallback
|
||
|
||
|
||
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Extract transaction fields from a screenshot via the configured LLM endpoint.

    The image at ``settings.upload_path / image_path`` is base64-encoded into
    a data URL and POSTed to ``settings.LLM_API_URL`` with a prompt that names
    the source app and page type and lists the fields to extract.

    Args:
        image_path: Screenshot path relative to ``settings.upload_path``.
        source_app: Source app; its ``.value`` is interpolated into the prompt.
        page_type: Page type; its ``.value`` is interpolated into the prompt.

    Returns:
        The decoded JSON reply. The prompt asks for a single object, or an
        array when the screenshot shows several transactions, so callers may
        receive a ``list`` as well as a ``dict``. An unreadable image, a failed
        request, or a reply that is not a JSON object/array yields ``{}``.
        This function does not raise for those failures; they are logged.
    """
    import base64
    import mimetypes

    full_path = settings.upload_path / image_path

    # Read first and handle the failure instead of pre-checking exists():
    # a missing file, a directory (e.g. an empty image_path) or a permission
    # problem all surface as OSError and must not escape this helper.
    try:
        image_bytes = full_path.read_bytes()
    except OSError as e:
        logger.warning("extract_transaction_fields could not read %s: %s", full_path, e)
        return {}

    # Label the data URL with the file's real image type; keep PNG as the
    # default when the extension is unknown or does not map to an image type.
    mime_type = mimetypes.guess_type(full_path.name)[0] or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/png"
    image_b64 = base64.b64encode(image_bytes).decode()

    prompt = (
        f"这是一张来自{source_app.value}的{page_type.value}截图。"
        "请提取其中的交易信息,返回JSON格式,字段包括:"
        "trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), "
        "direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), "
        "self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。"
        "如果截图包含多笔交易,返回JSON数组。否则返回单个JSON对象。"
    )

    try:
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 2000,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]

        # Strip a surrounding Markdown code fence (```json ... ```), if any.
        cleaned = text.strip().strip("`").removeprefix("json").strip()
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            # The reply may wrap the JSON in prose; salvage the outermost
            # object or array, whichever opens first.
            openers = [pos for pos in (cleaned.find("{"), cleaned.find("[")) if pos != -1]
            if not openers:
                raise
            start = min(openers)
            closer = "}" if cleaned[start] == "{" else "]"
            end = cleaned.rfind(closer)
            if end < start:
                raise
            parsed = json.loads(cleaned[start:end + 1])

        # A bare string/number/null is valid JSON but not a usable extraction.
        if not isinstance(parsed, (dict, list)):
            raise ValueError(f"expected a JSON object or array, got {type(parsed).__name__}")
        return parsed
    except Exception as e:
        # Deliberate best-effort boundary: extraction must never crash the caller.
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}
|
||
|
||
|
||
# ── mock fallback ────────────────────────────────────────────────────────
|
||
|
||
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Guess source app and page type from keywords in the path (mock fallback).

    The path is lower-cased and checked, in order, for the substrings
    ``wechat``/``wx``, ``alipay``/``ali`` and ``bank``; the first group that
    matches decides the result. With no match the result is
    ``(SourceApp.other, PageType.unknown)``.
    """
    keyword_table = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    haystack = image_path.lower()
    for keywords, verdict in keyword_table:
        if any(keyword in haystack for keyword in keywords):
            return verdict
    return SourceApp.other, PageType.unknown
|
||
|
||
|
||
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Return a fixed placeholder transaction (mock fallback).

    Every field is a constant except ``order_no``, which is derived from
    ``image_path`` so that different screenshots get different mock order
    numbers while the same screenshot always gets the same one.

    Args:
        image_path: Screenshot path; only used to derive ``order_no``.
        source_app: Unused; accepted to mirror the API-backed extractor.
        page_type: Unused; accepted to mirror the API-backed extractor.

    Returns:
        A dict with the keys ``trade_time``, ``amount``, ``direction``,
        ``counterparty_name``, ``counterparty_account``,
        ``self_account_tail_no``, ``order_no``, ``remark`` and ``confidence``.
    """
    import zlib

    # The built-in hash() of a str is salted per process, so it would give the
    # same screenshot a different order number after every restart. CRC32 is
    # stable across runs. surrogatepass keeps odd filesystem names encodable.
    order_suffix = zlib.crc32(image_path.encode("utf-8", "surrogatepass")) % 100000
    return {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": f"MOCK-{order_suffix:05d}",
        "remark": "模拟交易",
        "confidence": 0.80,
    }
|