Files
fund-tracer/backend/app/services/ocr_service.py

293 lines
12 KiB
Python
Raw Normal View History

2026-03-11 16:28:04 +08:00
"""OCR and multimodal extraction service.
2026-03-12 12:32:29 +08:00
Both classify_page and extract_transaction_fields use OpenAI-compatible
multimodal chat completion APIs. OCR and LLM can point to different
providers / models via separate env vars:
OCR_API_URL / OCR_API_KEY / OCR_MODEL for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL for reasoning tasks (assessment, suggestions)
When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
2026-03-11 16:28:04 +08:00
"""
2026-03-12 12:32:29 +08:00
import base64
2026-03-11 16:28:04 +08:00
import json
import logging
2026-03-12 12:32:29 +08:00
import ast
import re
2026-03-11 16:28:04 +08:00
import httpx
from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__)
2026-03-12 12:32:29 +08:00
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
def _ocr_config() -> tuple[str, str, str]:
    """Resolve the (api_url, api_key, model) triple for OCR calls.

    Each element prefers the OCR-specific setting and falls back to the
    corresponding LLM setting when the OCR one is unset.
    """
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
def _llm_config() -> tuple[str, str, str]:
    """Resolve the (api_url, api_key, model) triple for the text-only LLM."""
    cfg = settings
    return cfg.LLM_API_URL, cfg.LLM_API_KEY, cfg.LLM_MODEL
def _ocr_available() -> bool:
    """Return True only when url, key and model are all configured for OCR."""
    return all(_ocr_config())
2026-03-13 23:29:55 +08:00
def _missing_ocr_fields() -> list[str]:
    """Name every OCR config field that is unset (after LLM fallback)."""
    checks = (
        (settings.OCR_API_URL or settings.LLM_API_URL, "OCR_API_URL(or LLM_API_URL)"),
        (settings.OCR_API_KEY or settings.LLM_API_KEY, "OCR_API_KEY(or LLM_API_KEY)"),
        (settings.OCR_MODEL or settings.LLM_MODEL, "OCR_MODEL(or LLM_MODEL)"),
    )
    return [label for value, label in checks if not value]
2026-03-12 12:32:29 +08:00
def _llm_available() -> bool:
    """Return True only when the text-only LLM endpoint is fully configured."""
    url, key, model = _llm_config()
    return bool(url) and bool(key) and bool(model)
2026-03-11 16:28:04 +08:00
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot.

    Uses the multimodal OCR endpoint when configured; otherwise either
    raises (when mock fallback is disabled) or returns a filename-based
    mock classification.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            fields = _missing_ocr_fields()
            raise RuntimeError(f"OCR configuration missing: {', '.join(fields) or 'unknown'}")
        logger.warning(
            "OCR unavailable, falling back to mock classification for image: %s",
            image_path,
        )
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
2026-03-12 12:32:29 +08:00
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot.

    Returns (parsed, raw_text). Delegates to the OCR endpoint when
    configured; otherwise raises (fallback disabled) or returns mock data.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            fields = _missing_ocr_fields()
            raise RuntimeError(f"OCR configuration missing: {', '.join(fields) or 'unknown'}")
        logger.warning(
            "OCR unavailable, falling back to mock extraction for image: %s",
            image_path,
        )
        fake = _extract_mock(image_path, source_app, page_type)
        return fake, json.dumps(fake, ensure_ascii=False)
    return await _extract_via_api(image_path, source_app, page_type)
# ── OpenAI-compatible API implementation ─────────────────────────────────
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a single-turn vision request to the OCR model endpoint.

    Args:
        prompt: Text instruction sent alongside the image.
        image_b64: Base64-encoded image payload (embedded as a PNG data URL).
        max_tokens: Completion budget for the model.

    Returns:
        The assistant message content; non-string content is coerced via str().

    Raises:
        httpx.HTTPStatusError: When the endpoint replies with a 4xx/5xx.
    """
    url, key, model = _ocr_config()
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            }
        ],
        "max_tokens": max_tokens,
        # temperature 0 keeps extraction output as deterministic as possible
        "temperature": 0,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=payload,
        )
    resp.raise_for_status()
    body = resp.json()
    # Defensive access: some providers omit "choices" on errors that still 200.
    message = (body.get("choices") or [{}])[0].get("message") or {}
    content = message.get("content", "")
    # NOTE: removed dead reads of finish_reason / usage — they were assigned
    # but never used anywhere in this function.
    return content if isinstance(content, str) else str(content)
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """POST a plain-text chat completion to the reasoning LLM endpoint."""
    url, key, model = _llm_config()
    request_body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=request_body,
        )
        response.raise_for_status()
        data = response.json()
    return data["choices"][0]["message"]["content"]
def _parse_json_response(text: str):
"""Strip markdown fences and parse JSON from LLM output."""
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[-1]
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip().removeprefix("json").strip()
# Try parsing the full body first.
try:
return json.loads(cleaned)
except Exception:
pass
# Extract likely JSON block when model wraps with extra text.
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
candidate = match.group(1) if match else cleaned
# Normalize common model output issues:
# - smart quotes
# - trailing commas
normalized = (
candidate
.replace("", "\"")
.replace("", "\"")
.replace("", "'")
.replace("", "'")
)
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
try:
return json.loads(normalized)
except Exception:
pass
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
parsed = ast.literal_eval(py_like)
if isinstance(parsed, (dict, list)):
return parsed
raise ValueError("Parsed payload is neither dict nor list")
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Ask LLM_MODEL to repair malformed OCR JSON without adding semantics.

    The instructions demand a bare, json.loads-parseable object or array;
    the broken payload is appended verbatim after the rules.
    """
    instructions = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON不要Markdown不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
    )
    return await _call_text_llm(instructions + broken_text, max_tokens=1200)
2026-03-11 16:28:04 +08:00
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify one screenshot via the vision endpoint.

    Degrades to (other, unknown) when the file is missing or any step —
    API call, JSON parse, enum conversion — fails.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown
    encoded = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
        "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown\n"
        '只返回JSON格式: {"source_app": "...", "page_type": "..."}'
    )
    try:
        raw = await _call_vision(prompt, encoded, max_tokens=600)
        payload = _parse_json_response(raw)
        return (
            SourceApp(payload.get("source_app", "other")),
            PageType(payload.get("page_type", "unknown")),
        )
    except Exception as e:
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
2026-03-12 12:32:29 +08:00
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Run field extraction for one screenshot via the vision endpoint.

    Args:
        image_path: Path relative to settings.upload_path.
        source_app: Classified source app, embedded in the prompt.
        page_type: Classified page type, embedded in the prompt.

    Returns:
        (parsed, raw_text); on any failure returns ({}, <raw text so far>)
        after logging a warning.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值对象或数组禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易输出对象若>=2笔交易输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS无法识别时填空字符串\"\"\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型direction 仅为 in/out。\n"
        "现在只输出最终JSON不要输出任何额外文本。"
    )
    # Initialize up front so the error path can always return the raw text;
    # the previous version used a `"text" in locals()` hack for this.
    text = ""
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception:
            # Optional second chance: let the reasoning LLM repair the JSON.
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
2026-03-11 16:28:04 +08:00
# ── mock fallback ────────────────────────────────────────────────────────
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Guess app/page purely from filename keywords (demo-only fallback)."""
    name = image_path.lower()
    # Ordered keyword table: first match wins, mirroring the original chain.
    keyword_table = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    for keywords, verdict in keyword_table:
        if any(kw in name for kw in keywords):
            return verdict
    return SourceApp.other, PageType.unknown
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Produce one fixed fake transaction for demo runs (no OCR configured)."""
    # Order number varies with the path so repeated uploads look distinct.
    fake_order_no = f"MOCK-{hash(image_path) % 100000:05d}"
    record = {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": fake_order_no,
        "remark": "模拟交易",
        "confidence": 0.80,
    }
    return record