146 lines
6.1 KiB
Python
146 lines
6.1 KiB
Python
|
|
"""OCR and multimodal extraction service.
|
|||
|
|
|
|||
|
|
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
|
|||
|
|
When API keys are not configured, falls back to a mock implementation that
|
|||
|
|
returns placeholder data (sufficient for demo / competition).
|
|||
|
|
"""
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
import httpx
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
from app.models.evidence_image import SourceApp, PageType
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── provider-agnostic interface ──────────────────────────────────────────
|
|||
|
|
|
|||
|
|
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Determine which app a screenshot came from and what kind of page it shows.

    The remote API backend is used only when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise the mock classifier
    handles the request.

    Args:
        image_path: Screenshot path, forwarded unchanged to the chosen backend.

    Returns:
        A ``(SourceApp, PageType)`` pair.
    """
    api_configured = bool(settings.LLM_API_KEY and settings.LLM_API_URL)
    if not api_configured:
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Pull structured transaction data out of a screenshot.

    The remote API backend is used only when both ``settings.LLM_API_KEY``
    and ``settings.LLM_API_URL`` are truthy; otherwise the mock extractor
    handles the request.

    Args:
        image_path: Screenshot path, forwarded unchanged to the chosen backend.
        source_app: App the screenshot was classified as coming from.
        page_type: Page type the screenshot was classified as.

    Returns:
        Whatever the selected backend returns for this screenshot.
    """
    api_configured = bool(settings.LLM_API_KEY and settings.LLM_API_URL)
    if not api_configured:
        return _extract_mock(image_path, source_app, page_type)
    return await _extract_via_api(image_path, source_app, page_type)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── real API implementation ──────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot by sending it to the configured multimodal API.

    The image is read from ``settings.upload_path / image_path``, base64
    encoded into a data URL and POSTed to ``settings.LLM_API_URL`` together
    with a prompt asking for a JSON object with ``source_app`` and
    ``page_type`` keys. The reply text is read from
    ``choices[0].message.content`` of the response JSON.

    Args:
        image_path: Path of the screenshot relative to ``settings.upload_path``.

    Returns:
        The parsed ``(SourceApp, PageType)`` pair. Each element independently
        falls back to ``SourceApp.other`` / ``PageType.unknown`` when the
        model returns a value the enum does not accept. The full fallback
        pair is returned when the file does not exist or when the request
        or parsing fails (the failure is logged as a warning).
    """
    import base64
    import mimetypes

    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown

    # Label the data URL with the file's real image type instead of always
    # claiming PNG; keep PNG as the default when the extension is unknown
    # or does not map to an image type.
    mime_type = mimetypes.guess_type(full_path.name)[0] or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/png"

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。"
        "只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}"
    )

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 200,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]

        # Decode the first JSON object in the reply. Starting at the first
        # "{" skips a leading Markdown fence or short preamble, and
        # raw_decode ignores whatever follows the object (closing fence,
        # trailing commentary).
        start = text.find("{")
        if start == -1:
            raise ValueError("no JSON object found in model reply")
        data, _ = json.JSONDecoder().raw_decode(text[start:])

        # Validate the two fields separately so one unrecognised value does
        # not throw away the other, valid one.
        try:
            source_app = SourceApp(data.get("source_app", "other"))
        except ValueError:
            source_app = SourceApp.other
        try:
            page_type = PageType(data.get("page_type", "unknown"))
        except ValueError:
            page_type = PageType.unknown
        return source_app, page_type
    except Exception as e:
        # Best-effort by design: any transport or parsing failure degrades
        # to the "unknown" classification instead of propagating.
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Extract transaction fields from a screenshot via the multimodal API.

    The image is read from ``settings.upload_path / image_path``, base64
    encoded into a data URL and POSTed to ``settings.LLM_API_URL`` together
    with a prompt that names the source app and page type and lists the
    fields to extract. The reply text is read from
    ``choices[0].message.content`` of the response JSON.

    Args:
        image_path: Path of the screenshot relative to ``settings.upload_path``.
        source_app: Source app; its ``.value`` is interpolated into the prompt.
        page_type: Page type; its ``.value`` is interpolated into the prompt.

    Returns:
        The decoded JSON from the model reply. The prompt asks for a single
        object, or an array of objects when the screenshot holds several
        transactions, so the result may be a list despite the ``dict``
        annotation. An empty dict is returned when the file does not exist
        or when the request or parsing fails (the failure is logged as a
        warning).
    """
    import base64
    import mimetypes

    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}

    # Label the data URL with the file's real image type instead of always
    # claiming PNG; keep PNG as the default when the extension is unknown
    # or does not map to an image type.
    mime_type = mimetypes.guess_type(full_path.name)[0] or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/png"

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        f"这是一张来自{source_app.value}的{page_type.value}截图。"
        "请提取其中的交易信息,返回JSON格式,字段包括:"
        "trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), "
        "direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), "
        "self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。"
        "如果截图包含多笔交易,返回JSON数组。否则返回单个JSON对象。"
    )

    try:
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                settings.LLM_API_URL,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                            ],
                        }
                    ],
                    "max_tokens": 2000,
                },
            )
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]

        # Decode the first JSON object or array in the reply. Starting at
        # the first "{" or "[" skips a leading Markdown fence, and
        # raw_decode ignores whatever follows the document. This also
        # guarantees the result is a dict or a list, never a bare scalar.
        starts = [pos for pos in (text.find("{"), text.find("[")) if pos != -1]
        if not starts:
            raise ValueError("no JSON object or array found in model reply")
        parsed, _ = json.JSONDecoder().raw_decode(text[min(starts):])
        return parsed
    except Exception as e:
        # Best-effort by design: any transport or parsing failure degrades
        # to an empty result instead of propagating.
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── mock fallback ────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Guess the source app and page type from keywords in the path.

    Matching is a case-insensitive substring search over the whole
    ``image_path`` string. Rules are tried in order and the first hit wins:
    ``wechat``/``wx``, then ``alipay``/``ali``, then ``bank``.

    Args:
        image_path: Path (or name) of the screenshot; the file is not opened.

    Returns:
        The ``(SourceApp, PageType)`` pair for the first matching rule, or
        ``(SourceApp.other, PageType.unknown)`` when nothing matches.
    """
    name = image_path.lower()

    if any(tag in name for tag in ("wechat", "wx")):
        verdict = (SourceApp.wechat, PageType.bill_detail)
    elif any(tag in name for tag in ("alipay", "ali")):
        verdict = (SourceApp.alipay, PageType.bill_list)
    elif "bank" in name:
        verdict = (SourceApp.bank, PageType.bill_detail)
    else:
        verdict = (SourceApp.other, PageType.unknown)
    return verdict
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Return a fixed placeholder transaction for use without an API key.

    All fields are constants except ``order_no``, which is derived from
    ``image_path`` so that different screenshots get different mock order
    numbers.

    Args:
        image_path: Screenshot path; used only to derive ``order_no``.
        source_app: Unused; accepted for signature parity with the API backend.
        page_type: Unused; accepted for signature parity with the API backend.

    Returns:
        A dict with the keys ``trade_time``, ``amount``, ``direction``,
        ``counterparty_name``, ``counterparty_account``,
        ``self_account_tail_no``, ``order_no``, ``remark`` and ``confidence``.
    """
    import zlib

    # Built-in hash() of a str is salted per interpreter process, so it gave
    # the same screenshot a different order number on every restart. CRC32
    # is stable across runs and platforms. "surrogatepass" keeps paths that
    # contain undecodable filesystem bytes from raising during encoding.
    stable_id = zlib.crc32(image_path.encode("utf-8", "surrogatepass")) % 100000

    return {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": f"MOCK-{stable_id:05d}",
        "remark": "模拟交易",
        "confidence": 0.80,
    }
|