update: upload fix
This commit is contained in:
17
backend/.env.example
Normal file
17
backend/.env.example
Normal file
@@ -0,0 +1,17 @@
|
||||
DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# Optional: choose model names
|
||||
OPENAI_MODEL=gpt-4o
|
||||
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
||||
DEEPSEEK_MODEL=deepseek-chat
|
||||
CUSTOM_OPENAI_MODEL=gpt-4o-mini
|
||||
|
||||
# Custom OpenAI-compatible provider
|
||||
CUSTOM_OPENAI_BASE_URL=
|
||||
CUSTOM_OPENAI_API_KEY=
|
||||
|
||||
# API keys
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Screenshot upload and extraction API."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
@@ -84,18 +85,105 @@ async def extract_transactions(
|
||||
if not full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
image_bytes = full_path.read_bytes()
|
||||
started_at = datetime.utcnow()
|
||||
# 每次开始新一轮识别都重置计时,确保耗时是“本次分析”而不是历史累计
|
||||
screenshot.started_at = started_at
|
||||
screenshot.finished_at = None
|
||||
screenshot.duration_ms = None
|
||||
screenshot.error_message = None
|
||||
screenshot.progress_step = "starting"
|
||||
screenshot.progress_percent = 0
|
||||
screenshot.progress_detail = "准备开始识别"
|
||||
await db.commit()
|
||||
|
||||
async def update_progress(step: str, percent: int, detail: str):
|
||||
screenshot.status = "processing"
|
||||
screenshot.progress_step = step
|
||||
screenshot.progress_percent = percent
|
||||
screenshot.progress_detail = detail
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
transactions = await extract_and_save(case_id, screenshot_id, image_bytes)
|
||||
await update_progress("file_loaded", 10, "截图读取完成")
|
||||
transactions = await extract_and_save(
|
||||
case_id,
|
||||
screenshot_id,
|
||||
image_bytes,
|
||||
progress_hook=update_progress,
|
||||
)
|
||||
except Exception as e:
|
||||
error_detail = _classify_error(e)
|
||||
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
||||
sc = r.scalar_one_or_none()
|
||||
if sc:
|
||||
sc.status = "failed"
|
||||
sc.progress_step = "failed"
|
||||
sc.progress_percent = 100
|
||||
sc.progress_detail = "识别失败"
|
||||
sc.finished_at = datetime.utcnow()
|
||||
if sc.started_at:
|
||||
sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000)
|
||||
sc.error_message = error_detail
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=502, detail=f"Extraction failed: {e!s}")
|
||||
raise HTTPException(status_code=502, detail=error_detail)
|
||||
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
||||
sc = r.scalar_one_or_none()
|
||||
if sc:
|
||||
sc.status = "extracted"
|
||||
sc.progress_step = "completed"
|
||||
sc.progress_percent = 100
|
||||
sc.progress_detail = "识别完成"
|
||||
sc.finished_at = datetime.utcnow()
|
||||
if sc.started_at:
|
||||
sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000)
|
||||
sc.error_message = None
|
||||
await db.commit()
|
||||
return TransactionListResponse(items=transactions)
|
||||
|
||||
|
||||
def _classify_error(e: Exception) -> str:
|
||||
"""Produce a human-readable, categorized error message."""
|
||||
name = type(e).__name__
|
||||
msg = str(e)
|
||||
|
||||
if isinstance(e, ValueError):
|
||||
return f"配置错误: {msg}"
|
||||
|
||||
# OpenAI SDK errors
|
||||
try:
|
||||
from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError
|
||||
if isinstance(e, AuthenticationError):
|
||||
return f"API Key 无效或已过期 ({name}): {msg}"
|
||||
if isinstance(e, RateLimitError):
|
||||
return f"API 调用频率超限,请稍后重试 ({name}): {msg}"
|
||||
if isinstance(e, APITimeoutError):
|
||||
return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}"
|
||||
if isinstance(e, APIConnectionError):
|
||||
return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}"
|
||||
if isinstance(e, BadRequestError):
|
||||
return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}"
|
||||
if isinstance(e, APIStatusError):
|
||||
return f"模型服务返回错误 (HTTP {e.status_code}): {msg}"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Anthropic SDK errors
|
||||
try:
|
||||
from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError
|
||||
from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq
|
||||
if isinstance(e, AnthAuthError):
|
||||
return f"Anthropic API Key 无效或已过期: {msg}"
|
||||
if isinstance(e, AnthRateError):
|
||||
return f"Anthropic API 调用频率超限: {msg}"
|
||||
if isinstance(e, AnthConnError):
|
||||
return f"无法连接到 Anthropic 服务: {msg}"
|
||||
if isinstance(e, AnthBadReq):
|
||||
return f"Anthropic 请求被拒绝: {msg}"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Connection / network
|
||||
if "connect" in msg.lower() or "timeout" in msg.lower():
|
||||
return f"网络连接失败或超时: {msg}"
|
||||
|
||||
return f"识别失败 ({name}): {msg}"
|
||||
|
||||
@@ -9,7 +9,10 @@ router = APIRouter()
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
llm_provider: str | None = None
|
||||
ocr_provider: str | None = None
|
||||
ocr_model: str | None = None
|
||||
inference_provider: str | None = None
|
||||
inference_model: str | None = None
|
||||
openai_api_key: str | None = None
|
||||
anthropic_api_key: str | None = None
|
||||
deepseek_api_key: str | None = None
|
||||
|
||||
@@ -20,18 +20,30 @@ class Settings(BaseSettings):
|
||||
max_upload_size_mb: int = 20
|
||||
allowed_extensions: set[str] = {"png", "jpg", "jpeg", "webp"}
|
||||
|
||||
# LLM
|
||||
llm_provider: str = "openai" # openai | anthropic | deepseek | custom_openai
|
||||
# --- OCR (vision) model ---
|
||||
ocr_provider: str = "openai" # openai | anthropic | deepseek | custom_openai
|
||||
ocr_model: str | None = None # if None, falls back to provider default
|
||||
|
||||
# --- Inference (reasoning) model ---
|
||||
inference_provider: str = "openai"
|
||||
inference_model: str | None = None
|
||||
|
||||
# --- Provider credentials (shared between OCR and inference) ---
|
||||
openai_api_key: str | None = None
|
||||
anthropic_api_key: str | None = None
|
||||
deepseek_api_key: str | None = None
|
||||
custom_openai_api_key: str | None = None
|
||||
custom_openai_base_url: str | None = None
|
||||
|
||||
# Provider default model names (used when ocr_model / inference_model is None)
|
||||
openai_model: str = "gpt-4o"
|
||||
anthropic_model: str = "claude-3-5-sonnet-20241022"
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
custom_openai_model: str = "gpt-4o-mini"
|
||||
|
||||
# Legacy compat: llm_provider maps to ocr_provider on load
|
||||
llm_provider: str | None = None
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
@@ -40,8 +52,24 @@ class Settings(BaseSettings):
|
||||
|
||||
_runtime_overrides: dict[str, str | None] = {}
|
||||
|
||||
_ALLOWED_RUNTIME_KEYS = {
|
||||
"ocr_provider",
|
||||
"ocr_model",
|
||||
"inference_provider",
|
||||
"inference_model",
|
||||
"openai_api_key",
|
||||
"anthropic_api_key",
|
||||
"deepseek_api_key",
|
||||
"custom_openai_api_key",
|
||||
"custom_openai_base_url",
|
||||
"custom_openai_model",
|
||||
}
|
||||
|
||||
|
||||
def _apply_overrides(settings: Settings) -> Settings:
|
||||
# Legacy: if llm_provider is set but ocr_provider is default, use it
|
||||
if settings.llm_provider and settings.ocr_provider == "openai":
|
||||
settings.ocr_provider = settings.llm_provider
|
||||
for key, value in _runtime_overrides.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
@@ -53,19 +81,32 @@ def get_settings() -> Settings:
|
||||
return _apply_overrides(Settings())
|
||||
|
||||
|
||||
def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
||||
"""Update runtime settings and refresh cached Settings object."""
|
||||
allowed = {
|
||||
"llm_provider",
|
||||
"openai_api_key",
|
||||
"anthropic_api_key",
|
||||
"deepseek_api_key",
|
||||
"custom_openai_api_key",
|
||||
"custom_openai_base_url",
|
||||
"custom_openai_model",
|
||||
def _resolve_model(provider: str, explicit_model: str | None, settings: Settings) -> str:
|
||||
"""Return the model name to use for a given provider."""
|
||||
if explicit_model:
|
||||
return explicit_model
|
||||
defaults = {
|
||||
"openai": settings.openai_model,
|
||||
"anthropic": settings.anthropic_model,
|
||||
"deepseek": settings.deepseek_model,
|
||||
"custom_openai": settings.custom_openai_model,
|
||||
}
|
||||
return defaults.get(provider, settings.openai_model)
|
||||
|
||||
|
||||
def get_ocr_model() -> str:
|
||||
s = get_settings()
|
||||
return _resolve_model(s.ocr_provider, s.ocr_model, s)
|
||||
|
||||
|
||||
def get_inference_model() -> str:
|
||||
s = get_settings()
|
||||
return _resolve_model(s.inference_provider, s.inference_model, s)
|
||||
|
||||
|
||||
def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
||||
for key, value in payload.items():
|
||||
if key in allowed:
|
||||
if key in _ALLOWED_RUNTIME_KEYS:
|
||||
_runtime_overrides[key] = value
|
||||
get_settings.cache_clear()
|
||||
return get_settings()
|
||||
@@ -74,9 +115,12 @@ def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
||||
def public_settings() -> dict:
|
||||
s = get_settings()
|
||||
return {
|
||||
"llm_provider": s.llm_provider,
|
||||
"ocr_provider": s.ocr_provider,
|
||||
"ocr_model": get_ocr_model(),
|
||||
"inference_provider": s.inference_provider,
|
||||
"inference_model": get_inference_model(),
|
||||
"providers": ["openai", "anthropic", "deepseek", "custom_openai"],
|
||||
"models": {
|
||||
"provider_defaults": {
|
||||
"openai": s.openai_model,
|
||||
"anthropic": s.anthropic_model,
|
||||
"deepseek": s.deepseek_model,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.database import Base
|
||||
@@ -15,7 +15,14 @@ class Screenshot(Base):
|
||||
case_id: Mapped[int] = mapped_column(ForeignKey("cases.id", ondelete="CASCADE"), index=True)
|
||||
filename: Mapped[str] = mapped_column(String(255))
|
||||
file_path: Mapped[str] = mapped_column(String(512))
|
||||
status: Mapped[str] = mapped_column(String(32), default="pending") # pending | extracted | failed
|
||||
status: Mapped[str] = mapped_column(String(32), default="pending") # pending | processing | extracted | failed
|
||||
progress_step: Mapped[str | None] = mapped_column(String(64), nullable=True, default=None)
|
||||
progress_percent: Mapped[int] = mapped_column(default=0)
|
||||
progress_detail: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
duration_ms: Mapped[int | None] = mapped_column(nullable=True, default=None)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
case: Mapped["Case"] = relationship("Case", back_populates="screenshots")
|
||||
|
||||
@@ -11,6 +11,13 @@ class ScreenshotResponse(BaseModel):
|
||||
filename: str
|
||||
file_path: str
|
||||
status: str
|
||||
progress_step: str | None = None
|
||||
progress_percent: int = 0
|
||||
progress_detail: str | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
duration_ms: int | None = None
|
||||
error_message: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Transaction data extraction: LLM Vision + persistence."""
|
||||
|
||||
from app.models import Transaction
|
||||
from app.models.database import async_session_maker
|
||||
import app.models.database as db_module
|
||||
from app.schemas.transaction import TransactionExtractItem, TransactionResponse
|
||||
from app.services.llm import get_llm_provider
|
||||
|
||||
@@ -10,15 +10,26 @@ async def extract_and_save(
|
||||
case_id: int,
|
||||
screenshot_id: int,
|
||||
image_bytes: bytes,
|
||||
progress_hook=None,
|
||||
) -> list[TransactionResponse]:
|
||||
"""
|
||||
Run vision extraction on image and persist transactions to DB.
|
||||
Returns list of created transactions; low-confidence items are still saved but flagged.
|
||||
"""
|
||||
if progress_hook:
|
||||
await progress_hook("init", 5, "初始化识别上下文")
|
||||
provider = get_llm_provider()
|
||||
if progress_hook:
|
||||
await progress_hook("provider_ready", 15, f"已加载模型提供商: {type(provider).__name__}")
|
||||
if progress_hook:
|
||||
await progress_hook("calling_model", 35, "调用视觉模型识别截图中交易")
|
||||
items: list[TransactionExtractItem] = await provider.extract_from_image(image_bytes)
|
||||
if progress_hook:
|
||||
await progress_hook("model_returned", 70, f"模型返回 {len(items)} 条交易")
|
||||
results: list[TransactionResponse] = []
|
||||
async with async_session_maker() as session:
|
||||
async with db_module.async_session_maker() as session:
|
||||
if progress_hook:
|
||||
await progress_hook("db_writing", 85, "写入交易记录到数据库")
|
||||
for it in items:
|
||||
t = Transaction(
|
||||
case_id=case_id,
|
||||
@@ -39,4 +50,6 @@ async def extract_and_save(
|
||||
await session.flush()
|
||||
results.append(TransactionResponse.model_validate(t))
|
||||
await session.commit()
|
||||
if progress_hook:
|
||||
await progress_hook("completed", 100, "识别完成")
|
||||
return results
|
||||
|
||||
@@ -6,13 +6,16 @@ from app.schemas.transaction import TransactionExtractItem
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Abstract base for LLM vision providers. Each provider implements extract_from_image."""
|
||||
"""Abstract base for LLM providers. Supports optional model override."""
|
||||
|
||||
def __init__(self, model_override: str | None = None):
|
||||
self._model_override = model_override
|
||||
|
||||
@abstractmethod
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
"""
|
||||
Analyze a billing screenshot and return structured transaction list.
|
||||
:param image_bytes: Raw image file content (PNG/JPEG)
|
||||
:return: List of extracted transactions (may be empty or partial on failure)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
"""Plain text chat (for inference/reasoning tasks like report generation)."""
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Anthropic Claude Vision provider."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.config import get_settings
|
||||
@@ -13,6 +11,9 @@ from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
class ClaudeVisionProvider(BaseLLMProvider):
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().anthropic_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
settings = get_settings()
|
||||
if not settings.anthropic_api_key:
|
||||
@@ -20,14 +21,12 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
# Claude API: user message with content block list
|
||||
user_content = messages[1]["content"]
|
||||
content_blocks = []
|
||||
for block in user_content:
|
||||
if block["type"] == "text":
|
||||
content_blocks.append({"type": "text", "text": block["text"]})
|
||||
elif block["type"] == "image_url":
|
||||
# Claude expects base64 without data URL prefix
|
||||
content_blocks.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
@@ -37,7 +36,7 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
},
|
||||
})
|
||||
response = await client.messages.create(
|
||||
model=settings.anthropic_model,
|
||||
model=self._get_model(),
|
||||
max_tokens=4096,
|
||||
system=messages[0]["content"],
|
||||
messages=[{"role": "user", "content": content_blocks}],
|
||||
@@ -47,3 +46,20 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
if hasattr(block, "text"):
|
||||
text += block.text
|
||||
return _parse_json_array(text or "[]")
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
settings = get_settings()
|
||||
if not settings.anthropic_api_key:
|
||||
raise ValueError("ANTHROPIC_API_KEY is not set")
|
||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
response = await client.messages.create(
|
||||
model=self._get_model(),
|
||||
max_tokens=4096,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": user}],
|
||||
)
|
||||
text = ""
|
||||
for block in response.content:
|
||||
if hasattr(block, "text"):
|
||||
text += block.text
|
||||
return text or ""
|
||||
|
||||
@@ -11,22 +11,44 @@ from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
class CustomOpenAICompatibleProvider(BaseLLMProvider):
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
settings = get_settings()
|
||||
if not settings.custom_openai_api_key:
|
||||
raise ValueError("CUSTOM_OPENAI_API_KEY is not set")
|
||||
if not settings.custom_openai_base_url:
|
||||
raise ValueError("CUSTOM_OPENAI_BASE_URL is not set")
|
||||
client = AsyncOpenAI(
|
||||
return AsyncOpenAI(
|
||||
api_key=settings.custom_openai_api_key,
|
||||
base_url=settings.custom_openai_base_url,
|
||||
timeout=45.0,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().custom_openai_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
client = self._get_client()
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.custom_openai_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
timeout=45.0,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
client = self._get_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
timeout=45.0,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
@@ -9,26 +9,39 @@ from app.services.llm.base import BaseLLMProvider
|
||||
from app.prompts.extract_transaction import get_extract_messages
|
||||
from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
# DeepSeek vision endpoint (OpenAI-compatible)
|
||||
DEEPSEEK_BASE = "https://api.deepseek.com"
|
||||
|
||||
|
||||
class DeepSeekVisionProvider(BaseLLMProvider):
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
settings = get_settings()
|
||||
if not settings.deepseek_api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY is not set")
|
||||
client = AsyncOpenAI(
|
||||
api_key=settings.deepseek_api_key,
|
||||
base_url=DEEPSEEK_BASE,
|
||||
)
|
||||
return AsyncOpenAI(api_key=settings.deepseek_api_key, base_url=DEEPSEEK_BASE)
|
||||
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().deepseek_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
client = self._get_client()
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.deepseek_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
client = self._get_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
@@ -12,6 +12,9 @@ from app.prompts.extract_transaction import get_extract_messages
|
||||
|
||||
|
||||
class OpenAIVisionProvider(BaseLLMProvider):
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().openai_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
settings = get_settings()
|
||||
if not settings.openai_api_key:
|
||||
@@ -20,18 +23,32 @@ class OpenAIVisionProvider(BaseLLMProvider):
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.openai_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
settings = get_settings()
|
||||
if not settings.openai_api_key:
|
||||
raise ValueError("OPENAI_API_KEY is not set")
|
||||
client = AsyncOpenAI(api_key=settings.openai_api_key)
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
|
||||
def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
||||
"""Parse LLM response into list of TransactionExtractItem. Tolerates markdown and extra text."""
|
||||
text = text.strip()
|
||||
# Remove optional markdown code block
|
||||
if text.startswith("```"):
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"\s*```\s*$", "", text)
|
||||
@@ -46,7 +63,6 @@ def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
try:
|
||||
# Normalize transaction_time: allow string -> datetime
|
||||
if isinstance(item.get("transaction_time"), str) and item["transaction_time"]:
|
||||
from dateutil import parser as date_parser
|
||||
item["transaction_time"] = date_parser.isoparse(item["transaction_time"])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""LLM provider factory - returns provider by config."""
|
||||
"""LLM provider factory - returns provider by config, split by role (ocr / inference)."""
|
||||
|
||||
from app.config import get_settings
|
||||
from app.config import get_settings, get_ocr_model, get_inference_model
|
||||
from app.services.llm.base import BaseLLMProvider
|
||||
from app.services.llm.openai_vision import OpenAIVisionProvider
|
||||
from app.services.llm.claude_vision import ClaudeVisionProvider
|
||||
@@ -8,15 +8,25 @@ from app.services.llm.deepseek_vision import DeepSeekVisionProvider
|
||||
from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
|
||||
|
||||
|
||||
def get_llm_provider() -> BaseLLMProvider:
|
||||
def _make_provider(provider_name: str, model_override: str | None = None) -> BaseLLMProvider:
|
||||
name = (provider_name or "openai").lower()
|
||||
if name == "openai":
|
||||
return OpenAIVisionProvider(model_override=model_override)
|
||||
if name == "anthropic":
|
||||
return ClaudeVisionProvider(model_override=model_override)
|
||||
if name == "deepseek":
|
||||
return DeepSeekVisionProvider(model_override=model_override)
|
||||
if name == "custom_openai":
|
||||
return CustomOpenAICompatibleProvider(model_override=model_override)
|
||||
return OpenAIVisionProvider(model_override=model_override)
|
||||
|
||||
|
||||
def get_llm_provider(role: str = "ocr") -> BaseLLMProvider:
|
||||
"""
|
||||
role="ocr" -> uses ocr_provider + ocr_model
|
||||
role="inference" -> uses inference_provider + inference_model
|
||||
"""
|
||||
settings = get_settings()
|
||||
provider = (settings.llm_provider or "openai").lower()
|
||||
if provider == "openai":
|
||||
return OpenAIVisionProvider()
|
||||
if provider == "anthropic":
|
||||
return ClaudeVisionProvider()
|
||||
if provider == "deepseek":
|
||||
return DeepSeekVisionProvider()
|
||||
if provider == "custom_openai":
|
||||
return CustomOpenAICompatibleProvider()
|
||||
return OpenAIVisionProvider()
|
||||
if role == "inference":
|
||||
return _make_provider(settings.inference_provider, get_inference_model())
|
||||
return _make_provider(settings.ocr_provider, get_ocr_model())
|
||||
|
||||
Reference in New Issue
Block a user