fix ocr
This commit is contained in:
@@ -1,12 +1,20 @@
|
||||
"""OCR and multimodal extraction service.
|
||||
|
||||
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
|
||||
When API keys are not configured, falls back to a mock implementation that
|
||||
returns placeholder data (sufficient for demo / competition).
|
||||
Both classify_page and extract_transaction_fields use OpenAI-compatible
|
||||
multimodal chat completion APIs. OCR and LLM can point to different
|
||||
providers / models via separate env vars:
|
||||
|
||||
OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
|
||||
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)
|
||||
|
||||
When OCR keys are not set, falls back to LLM keys.
|
||||
When neither is set, returns mock data (sufficient for demo).
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import ast
|
||||
import re
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -14,28 +22,163 @@ from app.core.config import settings
|
||||
from app.models.evidence_image import SourceApp, PageType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Feature flag read by _extract_via_api: when True (and the LLM settings are
# configured), OCR output that fails JSON parsing is sent to the text LLM for
# a syntax-only repair before giving up. Currently switched off.
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
|
||||
|
||||
|
||||
def _ocr_config() -> tuple[str, str, str]:
    """Resolve the OCR endpoint settings as ``(api_url, api_key, model)``.

    Each element independently falls back to the matching ``LLM_*`` setting
    when the ``OCR_*`` one is falsy (unset or empty).
    """
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
|
||||
|
||||
|
||||
def _llm_config() -> tuple[str, str, str]:
    """Collect the text-only LLM settings as ``(api_url, api_key, model)``.

    Unlike the OCR configuration there is no fallback: the ``LLM_*`` values
    are returned exactly as configured.
    """
    endpoint = settings.LLM_API_URL
    token = settings.LLM_API_KEY
    model_name = settings.LLM_MODEL
    return endpoint, token, model_name
|
||||
|
||||
|
||||
def _ocr_available() -> bool:
    """Report whether URL, key and model for OCR are all configured (truthy)."""
    return all(_ocr_config())
|
||||
|
||||
|
||||
def _llm_available() -> bool:
    """Report whether URL, key and model for the text LLM are all configured."""
    endpoint, token, model_name = _llm_config()
    # Guard clauses instead of one chained boolean expression.
    if not endpoint:
        return False
    if not token:
        return False
    return bool(model_name)
|
||||
|
||||
|
||||
# ── provider-agnostic interface ──────────────────────────────────────────
|
||||
|
||||
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
    """Identify the source app and page type of a screenshot.

    Args:
        image_path: Path of the uploaded image, forwarded unchanged to the
            API or mock implementation.

    Returns:
        ``(source_app, page_type)``. The vision API is used when the OCR
        configuration (with its LLM fallback) is complete; otherwise the
        mock classifier answers.
    """
    # Availability is decided by _ocr_available() only, so a deployment that
    # sets just the OCR_* variables is not silently routed to the mock.
    if _ocr_available():
        return await _classify_via_api(image_path)
    return _classify_mock(image_path)
|
||||
|
||||
|
||||
async def extract_transaction_fields(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot.

    Args:
        image_path: Path of the uploaded image, forwarded unchanged.
        source_app: App the screenshot was classified as coming from.
        page_type: Page type the screenshot was classified as.

    Returns:
        ``(data, raw_text)``. ``data`` is the parsed payload and ``raw_text``
        the text it was derived from. On the mock path ``raw_text`` is the
        JSON serialisation of the mock data, so both paths return the same
        tuple shape.
    """
    if _ocr_available():
        return await _extract_via_api(image_path, source_app, page_type)

    mock_data = _extract_mock(image_path, source_app, page_type)
    return mock_data, json.dumps(mock_data, ensure_ascii=False)
|
||||
|
||||
|
||||
# ── real API implementation ──────────────────────────────────────────────
|
||||
# ── OpenAI-compatible API implementation ─────────────────────────────────
|
||||
|
||||
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a vision request to the OCR model endpoint.

    Posts one OpenAI-compatible chat-completion request containing the text
    prompt and the image as a base64 data URL, with ``temperature`` 0.

    Args:
        prompt: Instruction text sent alongside the image.
        image_b64: Base64-encoded image bytes (sent with a PNG data-URL prefix).
        max_tokens: Completion token limit passed to the endpoint.

    Returns:
        The first choice's message content as a string. An absent or null
        content yields ``""``; non-string content is converted with ``str``.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status.
    """
    url, key, model = _ocr_config()
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "temperature": 0,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json=payload,
        )
        resp.raise_for_status()
        body = resp.json()

    # Tolerate a missing/empty "choices" list or "message" object.
    choice0 = (body.get("choices") or [{}])[0]
    msg = choice0.get("message") or {}

    if choice0.get("finish_reason") == "length":
        # A truncated completion is the usual cause of unparsable JSON later.
        logger.warning("vision response truncated at max_tokens=%s", max_tokens)

    content = msg.get("content")
    if content is None:
        # A null content must not become the literal string "None".
        return ""
    return content if isinstance(content, str) else str(content)
|
||||
|
||||
|
||||
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only request to the reasoning LLM endpoint.

    Args:
        prompt: The single user message to send.
        max_tokens: Completion token limit passed to the endpoint.

    Returns:
        The first choice's message content as a string. A missing or null
        content yields ``""`` so the declared ``str`` return type always
        holds; non-string content is converted with ``str``.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status.
    """
    url, key, model = _llm_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
            },
        )
        resp.raise_for_status()
        body = resp.json()

    # Same defensive extraction as the vision call: never index blindly into
    # a response whose "choices"/"message"/"content" may be absent or null.
    first_choice = (body.get("choices") or [{}])[0]
    message = first_choice.get("message") or {}
    content = message.get("content")
    if content is None:
        return ""
    return content if isinstance(content, str) else str(content)
|
||||
|
||||
|
||||
def _parse_json_response(text: str):
|
||||
"""Strip markdown fences and parse JSON from LLM output."""
|
||||
cleaned = text.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[-1]
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip().removeprefix("json").strip()
|
||||
# Try parsing the full body first.
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract likely JSON block when model wraps with extra text.
|
||||
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
|
||||
candidate = match.group(1) if match else cleaned
|
||||
|
||||
# Normalize common model output issues:
|
||||
# - smart quotes
|
||||
# - trailing commas
|
||||
normalized = (
|
||||
candidate
|
||||
.replace("“", "\"")
|
||||
.replace("”", "\"")
|
||||
.replace("‘", "'")
|
||||
.replace("’", "'")
|
||||
)
|
||||
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
|
||||
try:
|
||||
return json.loads(normalized)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
|
||||
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
|
||||
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
|
||||
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
|
||||
parsed = ast.literal_eval(py_like)
|
||||
if isinstance(parsed, (dict, list)):
|
||||
return parsed
|
||||
raise ValueError("Parsed payload is neither dict nor list")
|
||||
|
||||
|
||||
async def _repair_broken_json_with_llm(broken_text: str, max_tokens: int = 6000) -> str:
    """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics.

    Args:
        broken_text: Raw OCR model output that failed JSON parsing.
        max_tokens: Completion limit for the repair call. The default matches
            the 6000-token budget of the extraction request, because the
            repaired JSON has to reproduce the whole payload; a smaller cap
            would truncate the repair of a long transaction list.

    Returns:
        The text returned by the LLM; the caller is expected to parse it.
    """
    prompt = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON,不要Markdown,不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
        f"{broken_text}"
    )
    return await _call_text_llm(prompt, max_tokens=max_tokens)
|
||||
|
||||
|
||||
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot through the vision endpoint.

    Args:
        image_path: Path relative to ``settings.upload_path``.

    Returns:
        ``(source_app, page_type)`` parsed from the model's JSON answer.
        ``(SourceApp.other, PageType.unknown)`` is returned when the file
        does not exist or when the request, parsing, or enum conversion fails.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n"
        '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
    )

    try:
        # Single request path shared with field extraction.
        text = await _call_vision(prompt, image_b64, max_tokens=600)
        data = _parse_json_response(text)
        return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
    except Exception as e:
        # Best-effort boundary: any failure degrades to "unknown" rather than
        # failing the caller.
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
|
||||
|
||||
|
||||
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract transaction fields from a screenshot through the vision endpoint.

    Args:
        image_path: Path relative to ``settings.upload_path``.
        source_app: Source app, interpolated into the prompt via ``.value``.
        page_type: Page type, interpolated into the prompt via ``.value``.

    Returns:
        ``(parsed, raw_text)``. ``parsed`` is the parsed model payload and
        ``raw_text`` the unmodified model output. ``({}, "")`` is returned
        when the file does not exist. On any failure ``parsed`` is ``{}`` and
        ``raw_text`` holds whatever the model returned before the failure
        (``""`` if the request itself failed).
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""

    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析,任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}。\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值:对象或数组;禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易,输出对象;若>=2笔交易,输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS;无法识别时填空字符串\"\"。\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out;无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"。\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON(可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型,direction 仅为 in/out。\n"
        "现在只输出最终JSON,不要输出任何额外文本。"
    )

    # Initialised up front so the failure path can return the raw model output
    # without inspecting locals().
    text = ""
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception as parse_err:
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            logger.warning("OCR JSON parse failed, attempting LLM repair: %s", parse_err)
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        # The raw OCR text (not the repaired text) is returned for traceability.
        return parsed, text
    except Exception as e:
        # Best-effort boundary: degrade to an empty payload instead of raising.
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
|
||||
|
||||
|
||||
# ── mock fallback ────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user