This commit is contained in:
2026-03-12 12:32:29 +08:00
parent c0f9ddabbf
commit 470446fa6f
18 changed files with 591 additions and 142 deletions

View File

@@ -1,12 +1,20 @@
"""OCR and multimodal extraction service.
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
When API keys are not configured, falls back to a mock implementation that
returns placeholder data (sufficient for demo / competition).
Both classify_page and extract_transaction_fields use OpenAI-compatible
multimodal chat completion APIs. OCR and LLM can point to different
providers / models via separate env vars:
OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)
When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
"""
import base64
import json
import logging
from pathlib import Path
import ast
import re
import httpx
@@ -14,28 +22,163 @@ from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__)
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
def _ocr_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for OCR, falling back to LLM config.

    Each value independently falls back to its LLM counterpart when the
    OCR-specific env var is unset.
    """
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
def _llm_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for the text-only reasoning LLM."""
    cfg = settings
    return cfg.LLM_API_URL, cfg.LLM_API_KEY, cfg.LLM_MODEL
def _ocr_available() -> bool:
    """True when every OCR endpoint value (url, key, model) is non-empty."""
    return all(_ocr_config())
def _llm_available() -> bool:
    """True when every text-LLM endpoint value (url, key, model) is non-empty."""
    return all(_llm_config())
# ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot.

    Dispatches to the multimodal OCR endpoint when it is fully configured;
    otherwise falls back to the mock classifier.

    Args:
        image_path: Path relative to the upload directory.

    Returns:
        (source_app, page_type) enum pair.
    """
    # Fix: a leftover pre-refactor condition line (checking LLM_API_KEY/URL
    # directly) duplicated this guard and broke the function body.
    if _ocr_available():
        return await _classify_via_api(image_path)
    return _classify_mock(image_path)
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot.

    Args:
        image_path: Path relative to the upload directory.
        source_app: App the screenshot came from (hint for the prompt).
        page_type: Page kind of the screenshot (hint for the prompt).

    Returns:
        (parsed, raw_text): ``parsed`` is a dict (single transaction) or a
        list (multiple transactions); ``raw_text`` is the raw model output.
        On the mock path, ``raw_text`` is the JSON serialization of the mock
        payload so downstream auditing always has text to store.
    """
    # Fix: removed interleaved pre-refactor lines (old single-return
    # signature and old LLM_API_KEY guard) left over from a bad merge.
    if _ocr_available():
        return await _extract_via_api(image_path, source_app, page_type)
    mock_data = _extract_mock(image_path, source_app, page_type)
    return mock_data, json.dumps(mock_data, ensure_ascii=False)
# ── real API implementation ──────────────────────────────────────────────
# ── OpenAI-compatible API implementation ─────────────────────────────────
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a text + single-image request to the OCR model endpoint.

    Args:
        prompt: Instruction text for the multimodal model.
        image_b64: Base64-encoded image payload (no data-URI prefix).
        max_tokens: Completion token budget.

    Returns:
        The assistant message content; coerced via ``str()`` when the
        provider returns a structured (non-string) content payload.

    Raises:
        httpx.HTTPStatusError: On non-2xx responses.
    """
    url, key, model = _ocr_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                            },
                        ],
                    }
                ],
                "max_tokens": max_tokens,
                # temperature 0: we want deterministic, parseable output.
                "temperature": 0,
            },
        )
        resp.raise_for_status()
        body = resp.json()
    # Defensive extraction: tolerate empty choices / missing message keys.
    # (Fix: dropped unused locals `finish_reason` and `usage`.)
    choice0 = (body.get("choices") or [{}])[0]
    content = (choice0.get("message") or {}).get("content", "")
    return content if isinstance(content, str) else str(content)
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only chat request to the reasoning LLM endpoint.

    Raises on non-2xx responses and on malformed response bodies
    (direct indexing into choices/message/content).
    """
    url, key, model = _llm_config()
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=payload,
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]
def _parse_json_response(text: str):
"""Strip markdown fences and parse JSON from LLM output."""
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[-1]
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip().removeprefix("json").strip()
# Try parsing the full body first.
try:
return json.loads(cleaned)
except Exception:
pass
# Extract likely JSON block when model wraps with extra text.
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
candidate = match.group(1) if match else cleaned
# Normalize common model output issues:
# - smart quotes
# - trailing commas
normalized = (
candidate
.replace("", "\"")
.replace("", "\"")
.replace("", "'")
.replace("", "'")
)
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
try:
return json.loads(normalized)
except Exception:
pass
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
parsed = ast.literal_eval(py_like)
if isinstance(parsed, (dict, list)):
return parsed
raise ValueError("Parsed payload is neither dict nor list")
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Ask LLM_MODEL to repair malformed OCR JSON without adding new semantics.

    The instruction block constrains the model to syntax-only fixes and a
    bare JSON object/array as output; the broken payload is appended verbatim.
    """
    instructions = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON不要Markdown不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
    )
    return await _call_text_llm(instructions + broken_text, max_tokens=1200)
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot via the multimodal OCR endpoint.

    Never raises: on a missing file, API failure, or unparseable output it
    logs a warning and returns (SourceApp.other, PageType.unknown).
    """
    # Fix: removed interleaved pre-refactor lines (inline httpx call and the
    # old prompt variant) left over from a bad merge, and dropped the
    # function-local `import base64` — the module imports it at the top.
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
        "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown\n"
        '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
    )
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=600)
        data = _parse_json_response(text)
        return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
    except Exception as e:
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract transaction fields from a screenshot via the OCR endpoint.

    Returns:
        (parsed, raw_text): ``parsed`` is a dict for a single transaction or
        a list for several; ``raw_text`` is the raw model output kept for
        auditing. On any failure returns ``({}, raw_text_so_far)`` — the raw
        text is preserved when the vision call succeeded but parsing failed.
    """
    # Fix: removed interleaved pre-refactor lines (old inline httpx call and
    # old short prompt) left over from a bad merge; pre-bind `text` instead
    # of the fragile `"text" in locals()` check in the except path.
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值对象或数组禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易输出对象若>=2笔交易输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS无法识别时填空字符串\"\"\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型direction 仅为 in/out。\n"
        "现在只输出最终JSON不要输出任何额外文本。"
    )
    text = ""  # pre-bound so the outer except always has the raw output
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception:
            # Optional second pass: ask the reasoning LLM to repair the JSON.
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
# ── mock fallback ────────────────────────────────────────────────────────