Files
fund-tracer/backend/app/services/ocr_service.py
2026-03-13 23:29:55 +08:00

293 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""OCR and multimodal extraction service.
Both classify_page and extract_transaction_fields use OpenAI-compatible
multimodal chat completion APIs. OCR and LLM can point to different
providers / models via separate env vars:
OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)
When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
"""
import base64
import json
import logging
import ast
import re
import httpx
from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__)
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
def _ocr_config() -> tuple[str, str, str]:
    """Resolve the OCR endpoint as (api_url, api_key, model).

    Each field prefers the dedicated OCR_* setting and falls back to the
    corresponding LLM_* setting when unset.
    """
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
def _llm_config() -> tuple[str, str, str]:
    """Resolve the reasoning-LLM endpoint as (api_url, api_key, model)."""
    cfg = settings
    return cfg.LLM_API_URL, cfg.LLM_API_KEY, cfg.LLM_MODEL
def _ocr_available() -> bool:
    """True when URL, key and model are all configured for OCR calls."""
    # all() over the 3-tuple is equivalent to bool(url and key and model).
    return all(_ocr_config())
def _missing_ocr_fields() -> list[str]:
    """Name every OCR config field that is unset even after the LLM fallback."""
    checks = (
        ("OCR_API_URL(or LLM_API_URL)", settings.OCR_API_URL or settings.LLM_API_URL),
        ("OCR_API_KEY(or LLM_API_KEY)", settings.OCR_API_KEY or settings.LLM_API_KEY),
        ("OCR_MODEL(or LLM_MODEL)", settings.OCR_MODEL or settings.LLM_MODEL),
    )
    return [label for label, value in checks if not value]
def _llm_available() -> bool:
    """True when the text-only reasoning LLM endpoint is fully configured."""
    return all(_llm_config())
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot.

    Uses the vision API when OCR is configured; otherwise either falls back
    to the filename-based mock (when OCR_ALLOW_MOCK_FALLBACK is set) or
    raises RuntimeError naming the missing configuration fields.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock classification for image: %s", image_path)
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot.

    Returns (parsed_fields, raw_text). With no OCR configured, either
    serves canned mock data (when OCR_ALLOW_MOCK_FALLBACK is set) or
    raises RuntimeError naming the missing configuration fields.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock extraction for image: %s", image_path)
        fallback = _extract_mock(image_path, source_app, page_type)
        return fallback, json.dumps(fallback, ensure_ascii=False)
    return await _extract_via_api(image_path, source_app, page_type)
# ── OpenAI-compatible API implementation ─────────────────────────────────
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a single-image vision request to the OCR model endpoint.

    Args:
        prompt: Instruction text sent alongside the image.
        image_b64: Base64-encoded image payload, embedded as a PNG data URL.
        max_tokens: Completion token budget for the model.

    Returns:
        The assistant message content as a string ("" when absent).

    Raises:
        httpx.HTTPStatusError: On a non-2xx response.
    """
    url, key, model = _ocr_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                        ],
                    }
                ],
                "max_tokens": max_tokens,
                # temperature 0: keep structured extraction as deterministic as possible
                "temperature": 0,
            },
        )
        resp.raise_for_status()
        body = resp.json()
        # Defensive parsing: tolerate missing "choices"/"message" keys.
        choice0 = (body.get("choices") or [{}])[0]
        msg = choice0.get("message") or {}
        content = msg.get("content", "")
        finish_reason = choice0.get("finish_reason")
        usage = body.get("usage") or {}
        # A "length" finish means the completion was cut off by max_tokens;
        # the resulting JSON is almost certainly truncated and unparsable,
        # so surface that here rather than failing silently downstream.
        if finish_reason == "length":
            logger.warning(
                "vision completion truncated (finish_reason=length, max_tokens=%s, usage=%s)",
                max_tokens,
                usage,
            )
        return content if isinstance(content, str) else str(content)
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only request to the reasoning LLM endpoint.

    Body parsing is defensive, mirroring _call_vision: an unexpected
    response shape yields "" rather than raising KeyError/IndexError.

    Args:
        prompt: User-role message text.
        max_tokens: Completion token budget for the model.

    Returns:
        The assistant message content as a string ("" when absent).

    Raises:
        httpx.HTTPStatusError: On a non-2xx response.
    """
    url, key, model = _llm_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
            },
        )
        resp.raise_for_status()
        body = resp.json()
        choice0 = (body.get("choices") or [{}])[0]
        msg = choice0.get("message") or {}
        content = msg.get("content", "")
        return content if isinstance(content, str) else str(content)
def _parse_json_response(text: str):
"""Strip markdown fences and parse JSON from LLM output."""
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[-1]
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip().removeprefix("json").strip()
# Try parsing the full body first.
try:
return json.loads(cleaned)
except Exception:
pass
# Extract likely JSON block when model wraps with extra text.
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
candidate = match.group(1) if match else cleaned
# Normalize common model output issues:
# - smart quotes
# - trailing commas
normalized = (
candidate
.replace("", "\"")
.replace("", "\"")
.replace("", "'")
.replace("", "'")
)
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
try:
return json.loads(normalized)
except Exception:
pass
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
parsed = ast.literal_eval(py_like)
if isinstance(parsed, (dict, list)):
return parsed
raise ValueError("Parsed payload is neither dict nor list")
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics.

    The prompt (Chinese) instructs the model to act as a pure JSON fixer:
    emit exactly one json.loads-parsable value (object or array), with no
    Markdown or explanations, preserving the original fields and meaning,
    adding no new information, and minimally closing truncated
    brackets/quotes.

    Args:
        broken_text: Raw OCR model output that failed JSON parsing.

    Returns:
        The repair model's raw text response (not yet parsed or validated
        here — callers re-run it through _parse_json_response).
    """
    prompt = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON不要Markdown不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
        f"{broken_text}"
    )
    return await _call_text_llm(prompt, max_tokens=1200)
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot's source app and page type via the vision API.

    Any failure — missing file, HTTP error, unparsable response, or an
    enum value the models don't define — degrades to (other, unknown).
    """
    image_file = settings.upload_path / image_path
    if not image_file.exists():
        return SourceApp.other, PageType.unknown
    encoded = base64.b64encode(image_file.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
        "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown\n"
        '只返回JSON格式: {"source_app": "...", "page_type": "..."}'
    )
    try:
        raw = await _call_vision(prompt, encoded, max_tokens=600)
        payload = _parse_json_response(raw)
        app = SourceApp(payload.get("source_app", "other"))
        page = PageType(payload.get("page_type", "unknown"))
        return app, page
    except Exception as exc:
        logger.warning("classify_page API failed: %s", exc)
        return SourceApp.other, PageType.unknown
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract transaction fields from a screenshot via the vision API.

    Args:
        image_path: Path relative to settings.upload_path.
        source_app: Previously classified source app (embedded in prompt).
        page_type: Previously classified page type (embedded in prompt).

    Returns:
        (parsed, raw_text) — parsed is a dict (single transaction) or list
        (multiple). On any failure, ({}, raw_text_so_far); raw_text is ""
        when the file is missing or the vision call itself failed.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值对象或数组禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易输出对象若>=2笔交易输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS无法识别时填空字符串\"\"\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型direction 仅为 in/out。\n"
        "现在只输出最终JSON不要输出任何额外文本。"
    )
    # Pre-initialize so the failure path can return whatever raw text we got
    # without resorting to a fragile `"text" in locals()` check.
    text = ""
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception:
            # Optional second chance: ask the text LLM to repair the JSON.
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
# ── mock fallback ────────────────────────────────────────────────────────
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Heuristic filename-based classification used when OCR is unavailable."""
    name = image_path.lower()
    # Ordered keyword rules; first match wins.
    rules = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    for keywords, result in rules:
        if any(kw in name for kw in keywords):
            return result
    return SourceApp.other, PageType.unknown
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Return a canned single-transaction record for demo/mock mode.

    source_app and page_type are accepted for signature parity with the API
    path but are not used. NOTE(review): order_no derives from hash() of a
    str, which CPython randomizes per process — stable within a run only.
    """
    order_suffix = hash(image_path) % 100000
    return {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": f"MOCK-{order_suffix:05d}",
        "remark": "模拟交易",
        "confidence": 0.80,
    }