"""OCR and multimodal extraction service.

Both classify_page and extract_transaction_fields use OpenAI-compatible
multimodal chat completion APIs. OCR and LLM can point to different
providers / models via separate env vars:

OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)

When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
"""
|
import base64
|
2026-03-11 16:28:04 +08:00
|
|
|
|
import json
|
|
|
|
|
|
import logging
|
2026-03-12 12:32:29 +08:00
|
|
|
|
import ast
|
|
|
|
|
|
import re
|
2026-03-11 16:28:04 +08:00
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
from app.models.evidence_image import SourceApp, PageType
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# Kill switch for the text-LLM JSON-repair fallback used during field
# extraction; left off while the OCR pipeline is being debugged.
ENABLE_LLM_REPAIR = False  # temporary: disabled per debugging request
|
|
def _ocr_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for OCR, falling back to LLM config."""
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
|
|
def _llm_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for text-only LLM repair."""
    cfg = settings
    return cfg.LLM_API_URL, cfg.LLM_API_KEY, cfg.LLM_MODEL
|
def _ocr_available() -> bool:
    """Return True when the OCR endpoint (url, key, model) is fully configured."""
    # all() over the 3-tuple is equivalent to bool(url and key and model).
    return all(_ocr_config())
2026-03-13 23:29:55 +08:00
|
|
|
|
def _missing_ocr_fields() -> list[str]:
    """List the OCR config fields that are unset, after LLM fallbacks.

    Each entry names both the OCR variable and its LLM fallback so the
    resulting error message tells the operator exactly what to set.
    """
    checks = (
        (settings.OCR_API_URL or settings.LLM_API_URL, "OCR_API_URL(or LLM_API_URL)"),
        (settings.OCR_API_KEY or settings.LLM_API_KEY, "OCR_API_KEY(or LLM_API_KEY)"),
        (settings.OCR_MODEL or settings.LLM_MODEL, "OCR_MODEL(or LLM_MODEL)"),
    )
    return [label for value, label in checks if not value]
|
def _llm_available() -> bool:
    """Return True when the reasoning LLM endpoint is fully configured."""
    # all() over the 3-tuple is equivalent to bool(url and key and model).
    return all(_llm_config())
2026-03-11 16:28:04 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── provider-agnostic interface ──────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot.

    Uses the vision API when OCR is configured. Otherwise, either raises
    (mock fallback disabled) or returns a filename-heuristic mock result.

    Raises:
        RuntimeError: When OCR is unconfigured and OCR_ALLOW_MOCK_FALLBACK
            is off; the message lists the missing config fields.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock classification for image: %s", image_path)
        return _classify_mock(image_path)
    return await _classify_via_api(image_path)
2026-03-12 12:32:29 +08:00
|
|
|
|
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot.

    Returns:
        (parsed_fields, raw_text): the structured payload plus the raw
        model output (for the mock path, its JSON serialization).

    Raises:
        RuntimeError: When OCR is unconfigured and OCR_ALLOW_MOCK_FALLBACK
            is off; the message lists the missing config fields.
    """
    if not _ocr_available():
        if not settings.OCR_ALLOW_MOCK_FALLBACK:
            missing = ", ".join(_missing_ocr_fields()) or "unknown"
            raise RuntimeError(f"OCR configuration missing: {missing}")
        logger.warning("OCR unavailable, falling back to mock extraction for image: %s", image_path)
        fields = _extract_mock(image_path, source_app, page_type)
        return fields, json.dumps(fields, ensure_ascii=False)
    return await _extract_via_api(image_path, source_app, page_type)
|
|
# ── OpenAI-compatible API implementation ─────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a vision request to the OCR model endpoint.

    Args:
        prompt: Instruction text for the multimodal model.
        image_b64: Base64-encoded PNG, embedded as a data URL.
        max_tokens: Completion token budget.

    Returns:
        The assistant message content as a string ("" when absent). Some
        providers return structured (non-str) content; it is coerced via str().

    Raises:
        httpx.HTTPStatusError: On non-2xx responses (via raise_for_status).
    """
    url, key, model = _ocr_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                        ],
                    }
                ],
                "max_tokens": max_tokens,
                # Deterministic output: structured extraction, not creative text.
                "temperature": 0,
            },
        )
        resp.raise_for_status()
        body = resp.json()
    # Defensive access: tolerate providers that omit "choices"/"message".
    # (Previously also read finish_reason/usage here; both were unused.)
    choice0 = (body.get("choices") or [{}])[0]
    content = (choice0.get("message") or {}).get("content", "")
    return content if isinstance(content, str) else str(content)
|
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only request to the reasoning LLM endpoint."""
    url, key, model = _llm_config()
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=payload,
        )
        resp.raise_for_status()
        data = resp.json()
    return data["choices"][0]["message"]["content"]
|
def _parse_json_response(text: str):
|
|
|
|
|
|
"""Strip markdown fences and parse JSON from LLM output."""
|
|
|
|
|
|
cleaned = text.strip()
|
|
|
|
|
|
if cleaned.startswith("```"):
|
|
|
|
|
|
cleaned = cleaned.split("\n", 1)[-1]
|
|
|
|
|
|
if cleaned.endswith("```"):
|
|
|
|
|
|
cleaned = cleaned.rsplit("```", 1)[0]
|
|
|
|
|
|
cleaned = cleaned.strip().removeprefix("json").strip()
|
|
|
|
|
|
# Try parsing the full body first.
|
|
|
|
|
|
try:
|
|
|
|
|
|
return json.loads(cleaned)
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
# Extract likely JSON block when model wraps with extra text.
|
|
|
|
|
|
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
|
|
|
|
|
|
candidate = match.group(1) if match else cleaned
|
|
|
|
|
|
|
|
|
|
|
|
# Normalize common model output issues:
|
|
|
|
|
|
# - smart quotes
|
|
|
|
|
|
# - trailing commas
|
|
|
|
|
|
normalized = (
|
|
|
|
|
|
candidate
|
|
|
|
|
|
.replace("“", "\"")
|
|
|
|
|
|
.replace("”", "\"")
|
|
|
|
|
|
.replace("‘", "'")
|
|
|
|
|
|
.replace("’", "'")
|
|
|
|
|
|
)
|
|
|
|
|
|
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
|
|
|
|
|
|
try:
|
|
|
|
|
|
return json.loads(normalized)
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
|
|
|
|
|
|
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
|
|
|
|
|
|
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
|
|
|
|
|
|
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
|
|
|
|
|
|
parsed = ast.literal_eval(py_like)
|
|
|
|
|
|
if isinstance(parsed, (dict, list)):
|
|
|
|
|
|
return parsed
|
|
|
|
|
|
raise ValueError("Parsed payload is neither dict nor list")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics."""
    instructions = "".join(
        (
            "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n",
            "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n",
            "硬性要求:\n",
            "1) 只能输出JSON,不要Markdown,不要解释。\n",
            "2) 保留原有字段和语义,不新增未出现的信息。\n",
            "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n",
            "4) 结果必须是对象或数组。\n\n",
            "损坏文本如下:\n",
        )
    )
    return await _call_text_llm(instructions + broken_text, max_tokens=1200)
|
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify one screenshot via the vision API.

    Returns (other, unknown) when the file is missing or any step of the
    API call / parsing / enum conversion fails.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n"
        '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
    )

    try:
        reply = await _call_vision(prompt, image_b64, max_tokens=600)
        parsed = _parse_json_response(reply)
        # Enum conversion stays inside the try: an out-of-vocabulary value
        # degrades to (other, unknown) instead of propagating.
        return (
            SourceApp(parsed.get("source_app", "other")),
            PageType(parsed.get("page_type", "unknown")),
        )
    except Exception as exc:
        logger.warning("classify_page API failed: %s", exc)
        return SourceApp.other, PageType.unknown
2026-03-12 12:32:29 +08:00
|
|
|
|
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract transaction fields from a screenshot via the vision API.

    Returns:
        (parsed_fields, raw_text): parsed payload plus the raw model output.
        Degrades to ({}, "") for a missing file, and to ({}, raw_text) when
        the API call or JSON parsing fails.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析,任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}。\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值:对象或数组;禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易,输出对象;若>=2笔交易,输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS;无法识别时填空字符串\"\"。\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out;无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"。\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON(可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型,direction 仅为 in/out。\n"
        "现在只输出最终JSON,不要输出任何额外文本。"
    )

    # Initialize before the try so the error path can return whatever raw
    # text was received (replaces the fragile `"text" in locals()` check).
    text = ""
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception:
            # Optional second chance: ask the text LLM to repair the JSON.
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
|
# ── mock fallback ────────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
def _classify_mock(image_path: str) -> tuple[SourceApp, PageType]:
    """Heuristic filename-based classification used when no OCR is configured."""
    name = image_path.lower()
    # Keyword table, checked in order; first match wins.
    rules = (
        (("wechat", "wx"), (SourceApp.wechat, PageType.bill_detail)),
        (("alipay", "ali"), (SourceApp.alipay, PageType.bill_list)),
        (("bank",), (SourceApp.bank, PageType.bill_detail)),
    )
    for keywords, verdict in rules:
        if any(keyword in name for keyword in keywords):
            return verdict
    return SourceApp.other, PageType.unknown
|
def _extract_mock(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
    """Build a fake transaction record for demo runs without OCR.

    source_app / page_type are accepted for interface parity with the real
    extractor but do not influence the mock payload. The order_no is
    derived from hash(image_path), so it is stable within one process only.
    """
    mock_order_no = f"MOCK-{hash(image_path) % 100000:05d}"
    record = {
        "trade_time": "2026-03-08 10:00:00",
        "amount": 1000.00,
        "direction": "out",
        "counterparty_name": "模拟对手方",
        "counterparty_account": "",
        "self_account_tail_no": "",
        "order_no": mock_order_no,
        "remark": "模拟交易",
        "confidence": 0.80,
    }
    return record
|