"""OCR and multimodal extraction service.

Both classify_page and extract_transaction_fields use OpenAI-compatible
multimodal chat completion APIs. OCR and LLM can point to different
providers / models via separate env vars:

OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)

When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
"""
import ast
import base64
import json
import logging
import re

import httpx

from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType

logger = logging.getLogger(__name__)

ENABLE_LLM_REPAIR = False  # temporary: disabled per debugging request
def _ocr_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for OCR, falling back to LLM config."""
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
def _llm_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for text-only LLM repair."""
    cfg = settings
    return cfg.LLM_API_URL, cfg.LLM_API_KEY, cfg.LLM_MODEL
def _ocr_available() -> bool:
    """True when url, key, and model are all configured for OCR."""
    return all(_ocr_config())
def _missing_ocr_fields() -> list[str]:
    """Name every OCR setting that is unset even after the LLM fallback."""
    checks = (
        ("OCR_API_URL(or LLM_API_URL)", settings.OCR_API_URL or settings.LLM_API_URL),
        ("OCR_API_KEY(or LLM_API_KEY)", settings.OCR_API_KEY or settings.LLM_API_KEY),
        ("OCR_MODEL(or LLM_MODEL)", settings.OCR_MODEL or settings.LLM_MODEL),
    )
    return [label for label, value in checks if not value]
def _llm_available() -> bool:
    """True when url, key, and model are all configured for the text LLM."""
    return all(_llm_config())
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot."""
    if not _ocr_available():
        # No usable OCR endpoint: either fail loudly or degrade to mock data.
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock classification for image: %s", image_path)
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot."""
    if not _ocr_available():
        # No usable OCR endpoint: either fail loudly or degrade to mock data.
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock extraction for image: %s", image_path)
        mock = _extract_mock(image_path, source_app, page_type)
        return mock, json.dumps(mock, ensure_ascii=False)
    return await _extract_via_api(image_path, source_app, page_type)
|
||
|
||
# ── OpenAI-compatible API implementation ─────────────────────────────────
|
||
|
||
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a vision request to the OCR model endpoint.

    Posts an OpenAI-compatible chat completion containing one user message:
    the text prompt plus the image as a base64 PNG data URL.

    Args:
        prompt: Instruction text for the vision model.
        image_b64: Base64-encoded PNG bytes (no data-URL prefix).
        max_tokens: Completion budget forwarded to the API.

    Returns:
        The assistant message content as a string ("" when absent).

    Raises:
        httpx.HTTPStatusError: on non-2xx responses (via raise_for_status).
        httpx.HTTPError: on connection/timeout failures.
    """
    url, key, model = _ocr_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                        ],
                    }
                ],
                "max_tokens": max_tokens,
                # temperature 0: callers json.loads the reply, so output must be stable.
                "temperature": 0,
            },
        )
        resp.raise_for_status()
        body = resp.json()
        # Defensive extraction: tolerate providers that omit choices/message.
        choice0 = (body.get("choices") or [{}])[0]
        msg = choice0.get("message") or {}
        content = msg.get("content", "")
        # Some providers return structured (list) content; coerce to str.
        return content if isinstance(content, str) else str(content)
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only request to the reasoning LLM endpoint."""
    url, key, model = _llm_config()
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=payload,
        )
        resp.raise_for_status()
        body = resp.json()
    return body["choices"][0]["message"]["content"]
def _parse_json_response(text: str):
|
||
"""Strip markdown fences and parse JSON from LLM output."""
|
||
cleaned = text.strip()
|
||
if cleaned.startswith("```"):
|
||
cleaned = cleaned.split("\n", 1)[-1]
|
||
if cleaned.endswith("```"):
|
||
cleaned = cleaned.rsplit("```", 1)[0]
|
||
cleaned = cleaned.strip().removeprefix("json").strip()
|
||
# Try parsing the full body first.
|
||
try:
|
||
return json.loads(cleaned)
|
||
except Exception:
|
||
pass
|
||
|
||
# Extract likely JSON block when model wraps with extra text.
|
||
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
|
||
candidate = match.group(1) if match else cleaned
|
||
|
||
# Normalize common model output issues:
|
||
# - smart quotes
|
||
# - trailing commas
|
||
normalized = (
|
||
candidate
|
||
.replace("“", "\"")
|
||
.replace("”", "\"")
|
||
.replace("‘", "'")
|
||
.replace("’", "'")
|
||
)
|
||
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
|
||
try:
|
||
return json.loads(normalized)
|
||
except Exception:
|
||
pass
|
||
|
||
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
|
||
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
|
||
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
|
||
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
|
||
parsed = ast.literal_eval(py_like)
|
||
if isinstance(parsed, (dict, list)):
|
||
return parsed
|
||
raise ValueError("Parsed payload is neither dict nor list")
|
||
|
||
|
||
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics.

    The prompt (in Chinese, matching the OCR pipeline's language) instructs the
    model to act purely as a JSON fixer: output only a json.loads-parseable
    object or array, preserve the original fields and meaning, and at most
    minimally close truncated brackets/quotes. Returns the model's raw reply;
    the caller re-parses it.
    """
    # Runtime string — must stay exactly as-is; the wording encodes the contract.
    prompt = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON,不要Markdown,不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
        f"{broken_text}"
    )
    return await _call_text_llm(prompt, max_tokens=1200)
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot's source app and page type via the vision API.

    Returns (SourceApp.other, PageType.unknown) when the file is missing or
    when the API call / JSON parsing / enum coercion fails — this function
    never raises.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    # Runtime prompt (Chinese): asks for a JSON object naming the app and page type.
    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n"
        '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
    )

    try:
        text = await _call_vision(prompt, image_b64, max_tokens=600)
        data = _parse_json_response(text)
        # Enum construction raises ValueError on unexpected labels; the broad
        # except below turns that (and any API error) into the unknown fallback.
        return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
    except Exception as e:
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Run vision OCR on a screenshot and parse the model's JSON reply.

    Returns (parsed, raw_text): parsed is a dict for a single transaction or
    a list of dicts for several; raw_text is the model's unmodified reply
    (kept for auditing/debugging). On any failure this degrades to
    ({}, raw-text-or-"") instead of raising.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    # Runtime prompt (Chinese): demands strict json.loads-parseable output —
    # a single object for one transaction, an array for two or more — with a
    # fixed 9-field schema and type constraints per field. Must stay as-is.
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析,任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}。\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值:对象或数组;禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易,输出对象;若>=2笔交易,输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS;无法识别时填空字符串\"\"。\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out;无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"。\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON(可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型,direction 仅为 in/out。\n"
        "现在只输出最终JSON,不要输出任何额外文本。"
    )

    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception as parse_err:
            # Optional second chance: have the text LLM repair broken JSON.
            # Currently gated off by ENABLE_LLM_REPAIR (see module constant).
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        # `text` exists only if _call_vision succeeded; guard via locals().
        return {}, text if "text" in locals() else ""
|
||
|
||
# ── mock fallback ────────────────────────────────────────────────────────
|
||
|
||
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Guess the app/page type from keywords in the file name (demo fallback)."""
    name = image_path.lower()
    keyword_map = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    for keywords, result in keyword_map:
        if any(k in name for k in keywords):
            return result
    return SourceApp.other, PageType.unknown
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
|
||
return {
|
||
"trade_time": "2026-03-08 10:00:00",
|
||
"amount": 1000.00,
|
||
"direction": "out",
|
||
"counterparty_name": "模拟对手方",
|
||
"counterparty_account": "",
|
||
"self_account_tail_no": "",
|
||
"order_no": f"MOCK-{hash(image_path) % 100000:05d}",
|
||
"remark": "模拟交易",
|
||
"confidence": 0.80,
|
||
}
|