update: upload fix
This commit is contained in:
11
.env.example
11
.env.example
@@ -1,7 +1,14 @@
|
|||||||
DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
|
DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
|
||||||
LLM_PROVIDER=openai
|
|
||||||
|
|
||||||
# Optional: choose model names
|
# --- OCR model (screenshot -> transactions) ---
|
||||||
|
OCR_PROVIDER=openai
|
||||||
|
OCR_MODEL=
|
||||||
|
|
||||||
|
# --- Inference model (report generation, reasoning) ---
|
||||||
|
INFERENCE_PROVIDER=openai
|
||||||
|
INFERENCE_MODEL=
|
||||||
|
|
||||||
|
# Provider default model names
|
||||||
OPENAI_MODEL=gpt-4o
|
OPENAI_MODEL=gpt-4o
|
||||||
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
||||||
DEEPSEEK_MODEL=deepseek-chat
|
DEEPSEEK_MODEL=deepseek-chat
|
||||||
|
|||||||
17
backend/.env.example
Normal file
17
backend/.env.example
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
|
||||||
|
LLM_PROVIDER=openai
|
||||||
|
|
||||||
|
# Optional: choose model names
|
||||||
|
OPENAI_MODEL=gpt-4o
|
||||||
|
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
||||||
|
DEEPSEEK_MODEL=deepseek-chat
|
||||||
|
CUSTOM_OPENAI_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Custom OpenAI-compatible provider
|
||||||
|
CUSTOM_OPENAI_BASE_URL=
|
||||||
|
CUSTOM_OPENAI_API_KEY=
|
||||||
|
|
||||||
|
# API keys
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Screenshot upload and extraction API."""
|
"""Screenshot upload and extraction API."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||||
@@ -84,18 +85,105 @@ async def extract_transactions(
|
|||||||
if not full_path.exists():
|
if not full_path.exists():
|
||||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||||
image_bytes = full_path.read_bytes()
|
image_bytes = full_path.read_bytes()
|
||||||
|
started_at = datetime.utcnow()
|
||||||
|
# 每次开始新一轮识别都重置计时,确保耗时是“本次分析”而不是历史累计
|
||||||
|
screenshot.started_at = started_at
|
||||||
|
screenshot.finished_at = None
|
||||||
|
screenshot.duration_ms = None
|
||||||
|
screenshot.error_message = None
|
||||||
|
screenshot.progress_step = "starting"
|
||||||
|
screenshot.progress_percent = 0
|
||||||
|
screenshot.progress_detail = "准备开始识别"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def update_progress(step: str, percent: int, detail: str):
|
||||||
|
screenshot.status = "processing"
|
||||||
|
screenshot.progress_step = step
|
||||||
|
screenshot.progress_percent = percent
|
||||||
|
screenshot.progress_detail = detail
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transactions = await extract_and_save(case_id, screenshot_id, image_bytes)
|
await update_progress("file_loaded", 10, "截图读取完成")
|
||||||
|
transactions = await extract_and_save(
|
||||||
|
case_id,
|
||||||
|
screenshot_id,
|
||||||
|
image_bytes,
|
||||||
|
progress_hook=update_progress,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
error_detail = _classify_error(e)
|
||||||
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
||||||
sc = r.scalar_one_or_none()
|
sc = r.scalar_one_or_none()
|
||||||
if sc:
|
if sc:
|
||||||
sc.status = "failed"
|
sc.status = "failed"
|
||||||
|
sc.progress_step = "failed"
|
||||||
|
sc.progress_percent = 100
|
||||||
|
sc.progress_detail = "识别失败"
|
||||||
|
sc.finished_at = datetime.utcnow()
|
||||||
|
if sc.started_at:
|
||||||
|
sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000)
|
||||||
|
sc.error_message = error_detail
|
||||||
await db.commit()
|
await db.commit()
|
||||||
raise HTTPException(status_code=502, detail=f"Extraction failed: {e!s}")
|
raise HTTPException(status_code=502, detail=error_detail)
|
||||||
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
|
||||||
sc = r.scalar_one_or_none()
|
sc = r.scalar_one_or_none()
|
||||||
if sc:
|
if sc:
|
||||||
sc.status = "extracted"
|
sc.status = "extracted"
|
||||||
|
sc.progress_step = "completed"
|
||||||
|
sc.progress_percent = 100
|
||||||
|
sc.progress_detail = "识别完成"
|
||||||
|
sc.finished_at = datetime.utcnow()
|
||||||
|
if sc.started_at:
|
||||||
|
sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000)
|
||||||
|
sc.error_message = None
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return TransactionListResponse(items=transactions)
|
return TransactionListResponse(items=transactions)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_error(e: Exception) -> str:
|
||||||
|
"""Produce a human-readable, categorized error message."""
|
||||||
|
name = type(e).__name__
|
||||||
|
msg = str(e)
|
||||||
|
|
||||||
|
if isinstance(e, ValueError):
|
||||||
|
return f"配置错误: {msg}"
|
||||||
|
|
||||||
|
# OpenAI SDK errors
|
||||||
|
try:
|
||||||
|
from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError
|
||||||
|
if isinstance(e, AuthenticationError):
|
||||||
|
return f"API Key 无效或已过期 ({name}): {msg}"
|
||||||
|
if isinstance(e, RateLimitError):
|
||||||
|
return f"API 调用频率超限,请稍后重试 ({name}): {msg}"
|
||||||
|
if isinstance(e, APITimeoutError):
|
||||||
|
return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}"
|
||||||
|
if isinstance(e, APIConnectionError):
|
||||||
|
return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}"
|
||||||
|
if isinstance(e, BadRequestError):
|
||||||
|
return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}"
|
||||||
|
if isinstance(e, APIStatusError):
|
||||||
|
return f"模型服务返回错误 (HTTP {e.status_code}): {msg}"
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Anthropic SDK errors
|
||||||
|
try:
|
||||||
|
from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError
|
||||||
|
from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq
|
||||||
|
if isinstance(e, AnthAuthError):
|
||||||
|
return f"Anthropic API Key 无效或已过期: {msg}"
|
||||||
|
if isinstance(e, AnthRateError):
|
||||||
|
return f"Anthropic API 调用频率超限: {msg}"
|
||||||
|
if isinstance(e, AnthConnError):
|
||||||
|
return f"无法连接到 Anthropic 服务: {msg}"
|
||||||
|
if isinstance(e, AnthBadReq):
|
||||||
|
return f"Anthropic 请求被拒绝: {msg}"
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Connection / network
|
||||||
|
if "connect" in msg.lower() or "timeout" in msg.lower():
|
||||||
|
return f"网络连接失败或超时: {msg}"
|
||||||
|
|
||||||
|
return f"识别失败 ({name}): {msg}"
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
class SettingsUpdate(BaseModel):
|
class SettingsUpdate(BaseModel):
|
||||||
llm_provider: str | None = None
|
ocr_provider: str | None = None
|
||||||
|
ocr_model: str | None = None
|
||||||
|
inference_provider: str | None = None
|
||||||
|
inference_model: str | None = None
|
||||||
openai_api_key: str | None = None
|
openai_api_key: str | None = None
|
||||||
anthropic_api_key: str | None = None
|
anthropic_api_key: str | None = None
|
||||||
deepseek_api_key: str | None = None
|
deepseek_api_key: str | None = None
|
||||||
|
|||||||
@@ -20,18 +20,30 @@ class Settings(BaseSettings):
|
|||||||
max_upload_size_mb: int = 20
|
max_upload_size_mb: int = 20
|
||||||
allowed_extensions: set[str] = {"png", "jpg", "jpeg", "webp"}
|
allowed_extensions: set[str] = {"png", "jpg", "jpeg", "webp"}
|
||||||
|
|
||||||
# LLM
|
# --- OCR (vision) model ---
|
||||||
llm_provider: str = "openai" # openai | anthropic | deepseek | custom_openai
|
ocr_provider: str = "openai" # openai | anthropic | deepseek | custom_openai
|
||||||
|
ocr_model: str | None = None # if None, falls back to provider default
|
||||||
|
|
||||||
|
# --- Inference (reasoning) model ---
|
||||||
|
inference_provider: str = "openai"
|
||||||
|
inference_model: str | None = None
|
||||||
|
|
||||||
|
# --- Provider credentials (shared between OCR and inference) ---
|
||||||
openai_api_key: str | None = None
|
openai_api_key: str | None = None
|
||||||
anthropic_api_key: str | None = None
|
anthropic_api_key: str | None = None
|
||||||
deepseek_api_key: str | None = None
|
deepseek_api_key: str | None = None
|
||||||
custom_openai_api_key: str | None = None
|
custom_openai_api_key: str | None = None
|
||||||
custom_openai_base_url: str | None = None
|
custom_openai_base_url: str | None = None
|
||||||
|
|
||||||
|
# Provider default model names (used when ocr_model / inference_model is None)
|
||||||
openai_model: str = "gpt-4o"
|
openai_model: str = "gpt-4o"
|
||||||
anthropic_model: str = "claude-3-5-sonnet-20241022"
|
anthropic_model: str = "claude-3-5-sonnet-20241022"
|
||||||
deepseek_model: str = "deepseek-chat"
|
deepseek_model: str = "deepseek-chat"
|
||||||
custom_openai_model: str = "gpt-4o-mini"
|
custom_openai_model: str = "gpt-4o-mini"
|
||||||
|
|
||||||
|
# Legacy compat: llm_provider maps to ocr_provider on load
|
||||||
|
llm_provider: str | None = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
env_file_encoding = "utf-8"
|
env_file_encoding = "utf-8"
|
||||||
@@ -40,8 +52,24 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
_runtime_overrides: dict[str, str | None] = {}
|
_runtime_overrides: dict[str, str | None] = {}
|
||||||
|
|
||||||
|
_ALLOWED_RUNTIME_KEYS = {
|
||||||
|
"ocr_provider",
|
||||||
|
"ocr_model",
|
||||||
|
"inference_provider",
|
||||||
|
"inference_model",
|
||||||
|
"openai_api_key",
|
||||||
|
"anthropic_api_key",
|
||||||
|
"deepseek_api_key",
|
||||||
|
"custom_openai_api_key",
|
||||||
|
"custom_openai_base_url",
|
||||||
|
"custom_openai_model",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _apply_overrides(settings: Settings) -> Settings:
|
def _apply_overrides(settings: Settings) -> Settings:
|
||||||
|
# Legacy: if llm_provider is set but ocr_provider is default, use it
|
||||||
|
if settings.llm_provider and settings.ocr_provider == "openai":
|
||||||
|
settings.ocr_provider = settings.llm_provider
|
||||||
for key, value in _runtime_overrides.items():
|
for key, value in _runtime_overrides.items():
|
||||||
if hasattr(settings, key):
|
if hasattr(settings, key):
|
||||||
setattr(settings, key, value)
|
setattr(settings, key, value)
|
||||||
@@ -53,19 +81,32 @@ def get_settings() -> Settings:
|
|||||||
return _apply_overrides(Settings())
|
return _apply_overrides(Settings())
|
||||||
|
|
||||||
|
|
||||||
def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
def _resolve_model(provider: str, explicit_model: str | None, settings: Settings) -> str:
|
||||||
"""Update runtime settings and refresh cached Settings object."""
|
"""Return the model name to use for a given provider."""
|
||||||
allowed = {
|
if explicit_model:
|
||||||
"llm_provider",
|
return explicit_model
|
||||||
"openai_api_key",
|
defaults = {
|
||||||
"anthropic_api_key",
|
"openai": settings.openai_model,
|
||||||
"deepseek_api_key",
|
"anthropic": settings.anthropic_model,
|
||||||
"custom_openai_api_key",
|
"deepseek": settings.deepseek_model,
|
||||||
"custom_openai_base_url",
|
"custom_openai": settings.custom_openai_model,
|
||||||
"custom_openai_model",
|
|
||||||
}
|
}
|
||||||
|
return defaults.get(provider, settings.openai_model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ocr_model() -> str:
|
||||||
|
s = get_settings()
|
||||||
|
return _resolve_model(s.ocr_provider, s.ocr_model, s)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_model() -> str:
|
||||||
|
s = get_settings()
|
||||||
|
return _resolve_model(s.inference_provider, s.inference_model, s)
|
||||||
|
|
||||||
|
|
||||||
|
def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
||||||
for key, value in payload.items():
|
for key, value in payload.items():
|
||||||
if key in allowed:
|
if key in _ALLOWED_RUNTIME_KEYS:
|
||||||
_runtime_overrides[key] = value
|
_runtime_overrides[key] = value
|
||||||
get_settings.cache_clear()
|
get_settings.cache_clear()
|
||||||
return get_settings()
|
return get_settings()
|
||||||
@@ -74,9 +115,12 @@ def update_runtime_settings(payload: dict[str, str | None]) -> Settings:
|
|||||||
def public_settings() -> dict:
|
def public_settings() -> dict:
|
||||||
s = get_settings()
|
s = get_settings()
|
||||||
return {
|
return {
|
||||||
"llm_provider": s.llm_provider,
|
"ocr_provider": s.ocr_provider,
|
||||||
|
"ocr_model": get_ocr_model(),
|
||||||
|
"inference_provider": s.inference_provider,
|
||||||
|
"inference_model": get_inference_model(),
|
||||||
"providers": ["openai", "anthropic", "deepseek", "custom_openai"],
|
"providers": ["openai", "anthropic", "deepseek", "custom_openai"],
|
||||||
"models": {
|
"provider_defaults": {
|
||||||
"openai": s.openai_model,
|
"openai": s.openai_model,
|
||||||
"anthropic": s.anthropic_model,
|
"anthropic": s.anthropic_model,
|
||||||
"deepseek": s.deepseek_model,
|
"deepseek": s.deepseek_model,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import String, DateTime, ForeignKey
|
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.models.database import Base
|
from app.models.database import Base
|
||||||
@@ -15,7 +15,14 @@ class Screenshot(Base):
|
|||||||
case_id: Mapped[int] = mapped_column(ForeignKey("cases.id", ondelete="CASCADE"), index=True)
|
case_id: Mapped[int] = mapped_column(ForeignKey("cases.id", ondelete="CASCADE"), index=True)
|
||||||
filename: Mapped[str] = mapped_column(String(255))
|
filename: Mapped[str] = mapped_column(String(255))
|
||||||
file_path: Mapped[str] = mapped_column(String(512))
|
file_path: Mapped[str] = mapped_column(String(512))
|
||||||
status: Mapped[str] = mapped_column(String(32), default="pending") # pending | extracted | failed
|
status: Mapped[str] = mapped_column(String(32), default="pending") # pending | processing | extracted | failed
|
||||||
|
progress_step: Mapped[str | None] = mapped_column(String(64), nullable=True, default=None)
|
||||||
|
progress_percent: Mapped[int] = mapped_column(default=0)
|
||||||
|
progress_detail: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||||
|
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||||
|
finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||||
|
duration_ms: Mapped[int | None] = mapped_column(nullable=True, default=None)
|
||||||
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True, default=None)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
case: Mapped["Case"] = relationship("Case", back_populates="screenshots")
|
case: Mapped["Case"] = relationship("Case", back_populates="screenshots")
|
||||||
|
|||||||
@@ -11,6 +11,13 @@ class ScreenshotResponse(BaseModel):
|
|||||||
filename: str
|
filename: str
|
||||||
file_path: str
|
file_path: str
|
||||||
status: str
|
status: str
|
||||||
|
progress_step: str | None = None
|
||||||
|
progress_percent: int = 0
|
||||||
|
progress_detail: str | None = None
|
||||||
|
started_at: datetime | None = None
|
||||||
|
finished_at: datetime | None = None
|
||||||
|
duration_ms: int | None = None
|
||||||
|
error_message: str | None = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Transaction data extraction: LLM Vision + persistence."""
|
"""Transaction data extraction: LLM Vision + persistence."""
|
||||||
|
|
||||||
from app.models import Transaction
|
from app.models import Transaction
|
||||||
from app.models.database import async_session_maker
|
import app.models.database as db_module
|
||||||
from app.schemas.transaction import TransactionExtractItem, TransactionResponse
|
from app.schemas.transaction import TransactionExtractItem, TransactionResponse
|
||||||
from app.services.llm import get_llm_provider
|
from app.services.llm import get_llm_provider
|
||||||
|
|
||||||
@@ -10,15 +10,26 @@ async def extract_and_save(
|
|||||||
case_id: int,
|
case_id: int,
|
||||||
screenshot_id: int,
|
screenshot_id: int,
|
||||||
image_bytes: bytes,
|
image_bytes: bytes,
|
||||||
|
progress_hook=None,
|
||||||
) -> list[TransactionResponse]:
|
) -> list[TransactionResponse]:
|
||||||
"""
|
"""
|
||||||
Run vision extraction on image and persist transactions to DB.
|
Run vision extraction on image and persist transactions to DB.
|
||||||
Returns list of created transactions; low-confidence items are still saved but flagged.
|
Returns list of created transactions; low-confidence items are still saved but flagged.
|
||||||
"""
|
"""
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("init", 5, "初始化识别上下文")
|
||||||
provider = get_llm_provider()
|
provider = get_llm_provider()
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("provider_ready", 15, f"已加载模型提供商: {type(provider).__name__}")
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("calling_model", 35, "调用视觉模型识别截图中交易")
|
||||||
items: list[TransactionExtractItem] = await provider.extract_from_image(image_bytes)
|
items: list[TransactionExtractItem] = await provider.extract_from_image(image_bytes)
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("model_returned", 70, f"模型返回 {len(items)} 条交易")
|
||||||
results: list[TransactionResponse] = []
|
results: list[TransactionResponse] = []
|
||||||
async with async_session_maker() as session:
|
async with db_module.async_session_maker() as session:
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("db_writing", 85, "写入交易记录到数据库")
|
||||||
for it in items:
|
for it in items:
|
||||||
t = Transaction(
|
t = Transaction(
|
||||||
case_id=case_id,
|
case_id=case_id,
|
||||||
@@ -39,4 +50,6 @@ async def extract_and_save(
|
|||||||
await session.flush()
|
await session.flush()
|
||||||
results.append(TransactionResponse.model_validate(t))
|
results.append(TransactionResponse.model_validate(t))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
if progress_hook:
|
||||||
|
await progress_hook("completed", 100, "识别完成")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ from app.schemas.transaction import TransactionExtractItem
|
|||||||
|
|
||||||
|
|
||||||
class BaseLLMProvider(ABC):
|
class BaseLLMProvider(ABC):
|
||||||
"""Abstract base for LLM vision providers. Each provider implements extract_from_image."""
|
"""Abstract base for LLM providers. Supports optional model override."""
|
||||||
|
|
||||||
|
def __init__(self, model_override: str | None = None):
|
||||||
|
self._model_override = model_override
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||||
"""
|
pass
|
||||||
Analyze a billing screenshot and return structured transaction list.
|
|
||||||
:param image_bytes: Raw image file content (PNG/JPEG)
|
@abstractmethod
|
||||||
:return: List of extracted transactions (may be empty or partial on failure)
|
async def chat(self, system: str, user: str) -> str:
|
||||||
"""
|
"""Plain text chat (for inference/reasoning tasks like report generation)."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
"""Anthropic Claude Vision provider."""
|
"""Anthropic Claude Vision provider."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
@@ -13,6 +11,9 @@ from app.services.llm.openai_vision import _parse_json_array
|
|||||||
|
|
||||||
|
|
||||||
class ClaudeVisionProvider(BaseLLMProvider):
|
class ClaudeVisionProvider(BaseLLMProvider):
|
||||||
|
def _get_model(self) -> str:
|
||||||
|
return self._model_override or get_settings().anthropic_model
|
||||||
|
|
||||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if not settings.anthropic_api_key:
|
if not settings.anthropic_api_key:
|
||||||
@@ -20,14 +21,12 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
|||||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||||
messages = get_extract_messages(b64)
|
messages = get_extract_messages(b64)
|
||||||
# Claude API: user message with content block list
|
|
||||||
user_content = messages[1]["content"]
|
user_content = messages[1]["content"]
|
||||||
content_blocks = []
|
content_blocks = []
|
||||||
for block in user_content:
|
for block in user_content:
|
||||||
if block["type"] == "text":
|
if block["type"] == "text":
|
||||||
content_blocks.append({"type": "text", "text": block["text"]})
|
content_blocks.append({"type": "text", "text": block["text"]})
|
||||||
elif block["type"] == "image_url":
|
elif block["type"] == "image_url":
|
||||||
# Claude expects base64 without data URL prefix
|
|
||||||
content_blocks.append({
|
content_blocks.append({
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"source": {
|
"source": {
|
||||||
@@ -37,7 +36,7 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
response = await client.messages.create(
|
response = await client.messages.create(
|
||||||
model=settings.anthropic_model,
|
model=self._get_model(),
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
system=messages[0]["content"],
|
system=messages[0]["content"],
|
||||||
messages=[{"role": "user", "content": content_blocks}],
|
messages=[{"role": "user", "content": content_blocks}],
|
||||||
@@ -47,3 +46,20 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
|||||||
if hasattr(block, "text"):
|
if hasattr(block, "text"):
|
||||||
text += block.text
|
text += block.text
|
||||||
return _parse_json_array(text or "[]")
|
return _parse_json_array(text or "[]")
|
||||||
|
|
||||||
|
async def chat(self, system: str, user: str) -> str:
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.anthropic_api_key:
|
||||||
|
raise ValueError("ANTHROPIC_API_KEY is not set")
|
||||||
|
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||||
|
response = await client.messages.create(
|
||||||
|
model=self._get_model(),
|
||||||
|
max_tokens=4096,
|
||||||
|
system=system,
|
||||||
|
messages=[{"role": "user", "content": user}],
|
||||||
|
)
|
||||||
|
text = ""
|
||||||
|
for block in response.content:
|
||||||
|
if hasattr(block, "text"):
|
||||||
|
text += block.text
|
||||||
|
return text or ""
|
||||||
|
|||||||
@@ -11,22 +11,44 @@ from app.services.llm.openai_vision import _parse_json_array
|
|||||||
|
|
||||||
|
|
||||||
class CustomOpenAICompatibleProvider(BaseLLMProvider):
|
class CustomOpenAICompatibleProvider(BaseLLMProvider):
|
||||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
def _get_client(self) -> AsyncOpenAI:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if not settings.custom_openai_api_key:
|
if not settings.custom_openai_api_key:
|
||||||
raise ValueError("CUSTOM_OPENAI_API_KEY is not set")
|
raise ValueError("CUSTOM_OPENAI_API_KEY is not set")
|
||||||
if not settings.custom_openai_base_url:
|
if not settings.custom_openai_base_url:
|
||||||
raise ValueError("CUSTOM_OPENAI_BASE_URL is not set")
|
raise ValueError("CUSTOM_OPENAI_BASE_URL is not set")
|
||||||
client = AsyncOpenAI(
|
return AsyncOpenAI(
|
||||||
api_key=settings.custom_openai_api_key,
|
api_key=settings.custom_openai_api_key,
|
||||||
base_url=settings.custom_openai_base_url,
|
base_url=settings.custom_openai_base_url,
|
||||||
|
timeout=45.0,
|
||||||
|
max_retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_model(self) -> str:
|
||||||
|
return self._model_override or get_settings().custom_openai_model
|
||||||
|
|
||||||
|
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||||
|
client = self._get_client()
|
||||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||||
messages = get_extract_messages(b64)
|
messages = get_extract_messages(b64)
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create(
|
||||||
model=settings.custom_openai_model,
|
model=self._get_model(),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
|
timeout=45.0,
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content or "[]"
|
text = response.choices[0].message.content or "[]"
|
||||||
return _parse_json_array(text)
|
return _parse_json_array(text)
|
||||||
|
|
||||||
|
async def chat(self, system: str, user: str) -> str:
|
||||||
|
client = self._get_client()
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=self._get_model(),
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
max_tokens=4096,
|
||||||
|
timeout=45.0,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|||||||
@@ -9,26 +9,39 @@ from app.services.llm.base import BaseLLMProvider
|
|||||||
from app.prompts.extract_transaction import get_extract_messages
|
from app.prompts.extract_transaction import get_extract_messages
|
||||||
from app.services.llm.openai_vision import _parse_json_array
|
from app.services.llm.openai_vision import _parse_json_array
|
||||||
|
|
||||||
|
|
||||||
# DeepSeek vision endpoint (OpenAI-compatible)
|
|
||||||
DEEPSEEK_BASE = "https://api.deepseek.com"
|
DEEPSEEK_BASE = "https://api.deepseek.com"
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekVisionProvider(BaseLLMProvider):
|
class DeepSeekVisionProvider(BaseLLMProvider):
|
||||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
def _get_client(self) -> AsyncOpenAI:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if not settings.deepseek_api_key:
|
if not settings.deepseek_api_key:
|
||||||
raise ValueError("DEEPSEEK_API_KEY is not set")
|
raise ValueError("DEEPSEEK_API_KEY is not set")
|
||||||
client = AsyncOpenAI(
|
return AsyncOpenAI(api_key=settings.deepseek_api_key, base_url=DEEPSEEK_BASE)
|
||||||
api_key=settings.deepseek_api_key,
|
|
||||||
base_url=DEEPSEEK_BASE,
|
def _get_model(self) -> str:
|
||||||
)
|
return self._model_override or get_settings().deepseek_model
|
||||||
|
|
||||||
|
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||||
|
client = self._get_client()
|
||||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||||
messages = get_extract_messages(b64)
|
messages = get_extract_messages(b64)
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create(
|
||||||
model=settings.deepseek_model,
|
model=self._get_model(),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content or "[]"
|
text = response.choices[0].message.content or "[]"
|
||||||
return _parse_json_array(text)
|
return _parse_json_array(text)
|
||||||
|
|
||||||
|
async def chat(self, system: str, user: str) -> str:
|
||||||
|
client = self._get_client()
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=self._get_model(),
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
max_tokens=4096,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ from app.prompts.extract_transaction import get_extract_messages
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIVisionProvider(BaseLLMProvider):
|
class OpenAIVisionProvider(BaseLLMProvider):
|
||||||
|
def _get_model(self) -> str:
|
||||||
|
return self._model_override or get_settings().openai_model
|
||||||
|
|
||||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if not settings.openai_api_key:
|
if not settings.openai_api_key:
|
||||||
@@ -20,18 +23,32 @@ class OpenAIVisionProvider(BaseLLMProvider):
|
|||||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||||
messages = get_extract_messages(b64)
|
messages = get_extract_messages(b64)
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create(
|
||||||
model=settings.openai_model,
|
model=self._get_model(),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content or "[]"
|
text = response.choices[0].message.content or "[]"
|
||||||
return _parse_json_array(text)
|
return _parse_json_array(text)
|
||||||
|
|
||||||
|
async def chat(self, system: str, user: str) -> str:
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.openai_api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY is not set")
|
||||||
|
client = AsyncOpenAI(api_key=settings.openai_api_key)
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=self._get_model(),
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
max_tokens=4096,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
||||||
"""Parse LLM response into list of TransactionExtractItem. Tolerates markdown and extra text."""
|
"""Parse LLM response into list of TransactionExtractItem. Tolerates markdown and extra text."""
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
# Remove optional markdown code block
|
|
||||||
if text.startswith("```"):
|
if text.startswith("```"):
|
||||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||||
text = re.sub(r"\s*```\s*$", "", text)
|
text = re.sub(r"\s*```\s*$", "", text)
|
||||||
@@ -46,7 +63,6 @@ def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
|||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
# Normalize transaction_time: allow string -> datetime
|
|
||||||
if isinstance(item.get("transaction_time"), str) and item["transaction_time"]:
|
if isinstance(item.get("transaction_time"), str) and item["transaction_time"]:
|
||||||
from dateutil import parser as date_parser
|
from dateutil import parser as date_parser
|
||||||
item["transaction_time"] = date_parser.isoparse(item["transaction_time"])
|
item["transaction_time"] = date_parser.isoparse(item["transaction_time"])
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM provider factory - returns provider by config."""
|
"""LLM provider factory - returns provider by config, split by role (ocr / inference)."""
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.config import get_settings, get_ocr_model, get_inference_model
|
||||||
from app.services.llm.base import BaseLLMProvider
|
from app.services.llm.base import BaseLLMProvider
|
||||||
from app.services.llm.openai_vision import OpenAIVisionProvider
|
from app.services.llm.openai_vision import OpenAIVisionProvider
|
||||||
from app.services.llm.claude_vision import ClaudeVisionProvider
|
from app.services.llm.claude_vision import ClaudeVisionProvider
|
||||||
@@ -8,15 +8,25 @@ from app.services.llm.deepseek_vision import DeepSeekVisionProvider
|
|||||||
from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
|
from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
|
||||||
|
|
||||||
|
|
||||||
def get_llm_provider() -> BaseLLMProvider:
|
def _make_provider(provider_name: str, model_override: str | None = None) -> BaseLLMProvider:
|
||||||
|
name = (provider_name or "openai").lower()
|
||||||
|
if name == "openai":
|
||||||
|
return OpenAIVisionProvider(model_override=model_override)
|
||||||
|
if name == "anthropic":
|
||||||
|
return ClaudeVisionProvider(model_override=model_override)
|
||||||
|
if name == "deepseek":
|
||||||
|
return DeepSeekVisionProvider(model_override=model_override)
|
||||||
|
if name == "custom_openai":
|
||||||
|
return CustomOpenAICompatibleProvider(model_override=model_override)
|
||||||
|
return OpenAIVisionProvider(model_override=model_override)
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_provider(role: str = "ocr") -> BaseLLMProvider:
|
||||||
|
"""
|
||||||
|
role="ocr" -> uses ocr_provider + ocr_model
|
||||||
|
role="inference" -> uses inference_provider + inference_model
|
||||||
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
provider = (settings.llm_provider or "openai").lower()
|
if role == "inference":
|
||||||
if provider == "openai":
|
return _make_provider(settings.inference_provider, get_inference_model())
|
||||||
return OpenAIVisionProvider()
|
return _make_provider(settings.ocr_provider, get_ocr_model())
|
||||||
if provider == "anthropic":
|
|
||||||
return ClaudeVisionProvider()
|
|
||||||
if provider == "deepseek":
|
|
||||||
return DeepSeekVisionProvider()
|
|
||||||
if provider == "custom_openai":
|
|
||||||
return CustomOpenAICompatibleProvider()
|
|
||||||
return OpenAIVisionProvider()
|
|
||||||
|
|||||||
@@ -8,7 +8,17 @@ import Settings from "./pages/Settings";
|
|||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
return (
|
return (
|
||||||
<ConfigProvider locale={zhCN}>
|
<ConfigProvider
|
||||||
|
locale={zhCN}
|
||||||
|
theme={{
|
||||||
|
token: {
|
||||||
|
fontSize: 16,
|
||||||
|
fontSizeSM: 14,
|
||||||
|
fontSizeLG: 18,
|
||||||
|
lineHeight: 1.6,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
<Routes>
|
<Routes>
|
||||||
<Route element={<AppLayout />}>
|
<Route element={<AppLayout />}>
|
||||||
<Route path="/" element={<CaseList />} />
|
<Route path="/" element={<CaseList />} />
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { useState, useEffect } from "react";
|
import { useEffect, useMemo, useState } from "react";
|
||||||
import { Upload, List, Button, Card, Tag, message } from "antd";
|
import type { Key } from "react";
|
||||||
import { InboxOutlined, ThunderboltOutlined } from "@ant-design/icons";
|
import { Upload, Table, Button, Tag, Alert, message, Space, Progress } from "antd";
|
||||||
|
import type { ColumnsType } from "antd/es/table";
|
||||||
|
import { InboxOutlined, ThunderboltOutlined, ReloadOutlined } from "@ant-design/icons";
|
||||||
import { api, type ScreenshotItem } from "../services/api";
|
import { api, type ScreenshotItem } from "../services/api";
|
||||||
|
|
||||||
const { Dragger } = Upload;
|
const { Dragger } = Upload;
|
||||||
@@ -10,10 +12,24 @@ interface Props {
|
|||||||
onExtracted?: () => void;
|
onExtracted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function statusTag(item: ScreenshotItem) {
|
||||||
|
if (item.status === "extracted") return <Tag color="green">已识别</Tag>;
|
||||||
|
if (item.status === "processing") return <Tag color="blue">识别中</Tag>;
|
||||||
|
if (item.status === "failed") return <Tag color="red">失败</Tag>;
|
||||||
|
return <Tag>待识别</Tag>;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDuration(ms: number | null) {
|
||||||
|
if (ms == null) return "-";
|
||||||
|
if (ms < 1000) return `${ms} ms`;
|
||||||
|
return `${(ms / 1000).toFixed(2)} s`;
|
||||||
|
}
|
||||||
|
|
||||||
export default function ScreenshotUploader({ caseId, onExtracted }: Props) {
|
export default function ScreenshotUploader({ caseId, onExtracted }: Props) {
|
||||||
const [screenshots, setScreenshots] = useState<ScreenshotItem[]>([]);
|
const [screenshots, setScreenshots] = useState<ScreenshotItem[]>([]);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [extractingId, setExtractingId] = useState<number | null>(null);
|
const [extractingIds, setExtractingIds] = useState<number[]>([]);
|
||||||
|
const [selectedRowKeys, setSelectedRowKeys] = useState<Key[]>([]);
|
||||||
|
|
||||||
const loadScreenshots = async () => {
|
const loadScreenshots = async () => {
|
||||||
try {
|
try {
|
||||||
@@ -35,73 +51,168 @@ export default function ScreenshotUploader({ caseId, onExtracted }: Props) {
|
|||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
return false; // prevent default upload
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleExtract = async (screenshotId: number) => {
|
const handleExtractSingle = async (screenshotId: number) => {
|
||||||
setExtractingId(screenshotId);
|
setExtractingIds((prev) => Array.from(new Set([...prev, screenshotId])));
|
||||||
try {
|
try {
|
||||||
await api.screenshots.extract(caseId, screenshotId);
|
await api.screenshots.extract(caseId, screenshotId);
|
||||||
message.success("识别完成");
|
|
||||||
await loadScreenshots();
|
await loadScreenshots();
|
||||||
onExtracted?.();
|
onExtracted?.();
|
||||||
|
message.success(`截图 ${screenshotId} 识别完成`);
|
||||||
} catch (e: unknown) {
|
} catch (e: unknown) {
|
||||||
const msg = e && typeof e === "object" && "response" in e
|
const detail =
|
||||||
? (e as { response?: { data?: { detail?: string } } }).response?.data?.detail
|
e && typeof e === "object" && "response" in e
|
||||||
: "识别失败";
|
? (e as { response?: { data?: { detail?: string } } }).response?.data?.detail
|
||||||
message.error(msg || "识别失败");
|
: undefined;
|
||||||
|
message.error(detail || `截图 ${screenshotId} 识别失败`);
|
||||||
|
await loadScreenshots();
|
||||||
} finally {
|
} finally {
|
||||||
setExtractingId(null);
|
setExtractingIds((prev) => prev.filter((id) => id !== screenshotId));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleBatchExtract = async () => {
|
||||||
|
const ids = selectedRowKeys.map((k) => Number(k)).filter((id) => Number.isFinite(id));
|
||||||
|
if (!ids.length) {
|
||||||
|
message.warning("请先勾选要识别的截图");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setExtractingIds((prev) => Array.from(new Set([...prev, ...ids])));
|
||||||
|
const started = Date.now();
|
||||||
|
let ok = 0;
|
||||||
|
let fail = 0;
|
||||||
|
for (const id of ids) {
|
||||||
|
try {
|
||||||
|
await api.screenshots.extract(caseId, id);
|
||||||
|
ok += 1;
|
||||||
|
} catch {
|
||||||
|
fail += 1;
|
||||||
|
}
|
||||||
|
await loadScreenshots();
|
||||||
|
}
|
||||||
|
setExtractingIds((prev) => prev.filter((id) => !ids.includes(id)));
|
||||||
|
onExtracted?.();
|
||||||
|
const elapsed = Date.now() - started;
|
||||||
|
message.info(`批量识别结束:成功 ${ok},失败 ${fail},总耗时 ${(elapsed / 1000).toFixed(2)} s`);
|
||||||
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (caseId) loadScreenshots();
|
if (!caseId) return;
|
||||||
|
loadScreenshots();
|
||||||
|
const timer = window.setInterval(loadScreenshots, 1500);
|
||||||
|
return () => window.clearInterval(timer);
|
||||||
}, [caseId]);
|
}, [caseId]);
|
||||||
|
|
||||||
|
const rowSelection = {
|
||||||
|
selectedRowKeys,
|
||||||
|
onChange: (keys: Key[]) => setSelectedRowKeys(keys),
|
||||||
|
getCheckboxProps: (record: ScreenshotItem) => ({
|
||||||
|
disabled: record.status === "processing",
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
const columns: ColumnsType<ScreenshotItem> = useMemo(() => [
|
||||||
|
{
|
||||||
|
title: "截图",
|
||||||
|
dataIndex: "filename",
|
||||||
|
key: "filename",
|
||||||
|
render: (v: string) => (
|
||||||
|
<span style={{ display: "block", maxWidth: 260, overflow: "hidden", textOverflow: "ellipsis", whiteSpace: "nowrap" }} title={v}>
|
||||||
|
{v}
|
||||||
|
</span>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "状态",
|
||||||
|
key: "status",
|
||||||
|
width: 110,
|
||||||
|
render: (_, r) => statusTag(r),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "识别进度",
|
||||||
|
key: "progress",
|
||||||
|
width: 280,
|
||||||
|
render: (_, r) => (
|
||||||
|
<div>
|
||||||
|
<Progress percent={r.progress_percent || 0} size="small" status={r.status === "failed" ? "exception" : r.status === "extracted" ? "success" : "active"} />
|
||||||
|
<div style={{ fontSize: 12, color: "#666" }}>{r.progress_detail || "-"}</div>
|
||||||
|
{r.progress_step && <div style={{ fontSize: 12, color: "#999" }}>步骤: {r.progress_step}</div>}
|
||||||
|
</div>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "耗时",
|
||||||
|
key: "duration",
|
||||||
|
width: 120,
|
||||||
|
render: (_, r) => formatDuration(r.duration_ms),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "错误信息",
|
||||||
|
key: "error_message",
|
||||||
|
render: (_, r) =>
|
||||||
|
r.status === "failed" && r.error_message ? (
|
||||||
|
<Alert type="error" showIcon message={r.error_message} />
|
||||||
|
) : (
|
||||||
|
"-"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "操作",
|
||||||
|
key: "action",
|
||||||
|
width: 120,
|
||||||
|
render: (_, r) => (
|
||||||
|
(r.status === "pending" || r.status === "failed") && (
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
size="small"
|
||||||
|
icon={r.status === "failed" ? <ReloadOutlined /> : <ThunderboltOutlined />}
|
||||||
|
loading={extractingIds.includes(r.id)}
|
||||||
|
onClick={() => handleExtractSingle(r.id)}
|
||||||
|
>
|
||||||
|
{r.status === "failed" ? "重试" : "识别"}
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
),
|
||||||
|
},
|
||||||
|
], [extractingIds]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
<Dragger
|
<Dragger
|
||||||
multiple
|
multiple
|
||||||
accept=".png,.jpg,.jpeg,.webp"
|
accept=".png,.jpg,.jpeg,.webp"
|
||||||
showUploadList={false}
|
showUploadList={false}
|
||||||
beforeUpload={(file) => { handleUpload(file as File); return false; }}
|
beforeUpload={(file) => {
|
||||||
|
handleUpload(file as File);
|
||||||
|
return false;
|
||||||
|
}}
|
||||||
disabled={loading}
|
disabled={loading}
|
||||||
>
|
>
|
||||||
<p className="ant-upload-drag-icon"><InboxOutlined /></p>
|
<p className="ant-upload-drag-icon"><InboxOutlined /></p>
|
||||||
<p className="ant-upload-text">点击或拖拽账单截图到此处上传</p>
|
<p className="ant-upload-text">点击或拖拽账单截图到此处上传</p>
|
||||||
<p className="ant-upload-hint">支持 png / jpg / webp,单次可多选</p>
|
<p className="ant-upload-hint">支持 png / jpg / webp,单次可多选</p>
|
||||||
</Dragger>
|
</Dragger>
|
||||||
<div style={{ marginTop: 16 }}>
|
|
||||||
<Button type="link" onClick={loadScreenshots} style={{ padding: 0 }}>刷新截图列表</Button>
|
<div style={{ marginTop: 12 }}>
|
||||||
<List
|
<Space>
|
||||||
style={{ marginTop: 8 }}
|
<Button type="link" onClick={loadScreenshots} style={{ padding: 0 }}>刷新截图列表</Button>
|
||||||
grid={{ gutter: 16, column: 4 }}
|
<Button type="primary" onClick={handleBatchExtract} disabled={!selectedRowKeys.length} loading={extractingIds.length > 0}>
|
||||||
dataSource={screenshots}
|
一键识别(已勾选)
|
||||||
renderItem={(item) => (
|
</Button>
|
||||||
<List.Item>
|
</Space>
|
||||||
<Card size="small" title={item.filename}>
|
|
||||||
<div style={{ marginBottom: 8 }}>
|
|
||||||
<Tag color={item.status === "extracted" ? "green" : item.status === "failed" ? "red" : "default"}>
|
|
||||||
{item.status === "extracted" ? "已识别" : item.status === "failed" ? "失败" : "待识别"}
|
|
||||||
</Tag>
|
|
||||||
</div>
|
|
||||||
{item.status === "pending" && (
|
|
||||||
<Button
|
|
||||||
type="primary"
|
|
||||||
size="small"
|
|
||||||
icon={<ThunderboltOutlined />}
|
|
||||||
loading={extractingId === item.id}
|
|
||||||
onClick={() => handleExtract(item.id)}
|
|
||||||
>
|
|
||||||
识别交易
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</Card>
|
|
||||||
</List.Item>
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<Table
|
||||||
|
style={{ marginTop: 8 }}
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={screenshots}
|
||||||
|
columns={columns}
|
||||||
|
rowSelection={rowSelection}
|
||||||
|
pagination={false}
|
||||||
|
size="small"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,7 @@
|
|||||||
}
|
}
|
||||||
body {
|
body {
|
||||||
margin: 0;
|
margin: 0;
|
||||||
|
font-size: 16px;
|
||||||
|
line-height: 1.6;
|
||||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { Card, Form, Input, Select, Button, Alert, Space, message } from "antd";
|
import { Card, Form, Input, Select, Button, Alert, Space, Divider, Descriptions, message } from "antd";
|
||||||
import {
|
import {
|
||||||
api,
|
api,
|
||||||
type RuntimeSettings,
|
type RuntimeSettings,
|
||||||
|
type ProviderKey,
|
||||||
getApiBaseUrl,
|
getApiBaseUrl,
|
||||||
setApiBaseUrl,
|
setApiBaseUrl,
|
||||||
} from "../services/api";
|
} from "../services/api";
|
||||||
|
|
||||||
|
const PROVIDER_OPTIONS = [
|
||||||
|
{ label: "OpenAI", value: "openai" as ProviderKey },
|
||||||
|
{ label: "Anthropic", value: "anthropic" as ProviderKey },
|
||||||
|
{ label: "DeepSeek", value: "deepseek" as ProviderKey },
|
||||||
|
{ label: "自定义(OpenAI兼容)", value: "custom_openai" as ProviderKey },
|
||||||
|
];
|
||||||
|
|
||||||
export default function Settings() {
|
export default function Settings() {
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
@@ -20,9 +28,11 @@ export default function Settings() {
|
|||||||
setRuntime(data);
|
setRuntime(data);
|
||||||
form.setFieldsValue({
|
form.setFieldsValue({
|
||||||
system_api_base_url: getApiBaseUrl(),
|
system_api_base_url: getApiBaseUrl(),
|
||||||
llm_provider: data.llm_provider,
|
ocr_provider: data.ocr_provider,
|
||||||
|
ocr_model: data.ocr_model,
|
||||||
|
inference_provider: data.inference_provider,
|
||||||
|
inference_model: data.inference_model,
|
||||||
custom_openai_base_url: data.base_urls?.custom_openai || "",
|
custom_openai_base_url: data.base_urls?.custom_openai || "",
|
||||||
custom_openai_model: data.models?.custom_openai || "gpt-4o-mini",
|
|
||||||
});
|
});
|
||||||
} catch {
|
} catch {
|
||||||
message.error("加载设置失败");
|
message.error("加载设置失败");
|
||||||
@@ -35,21 +45,15 @@ export default function Settings() {
|
|||||||
loadSettings();
|
loadSettings();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onFinish = async (values: {
|
const onFinish = async (values: Record<string, string | undefined>) => {
|
||||||
system_api_base_url?: string;
|
|
||||||
llm_provider: "openai" | "anthropic" | "deepseek" | "custom_openai";
|
|
||||||
openai_api_key?: string;
|
|
||||||
anthropic_api_key?: string;
|
|
||||||
deepseek_api_key?: string;
|
|
||||||
custom_openai_api_key?: string;
|
|
||||||
custom_openai_base_url?: string;
|
|
||||||
custom_openai_model?: string;
|
|
||||||
}) => {
|
|
||||||
setSaving(true);
|
setSaving(true);
|
||||||
try {
|
try {
|
||||||
setApiBaseUrl(values.system_api_base_url || "");
|
setApiBaseUrl(values.system_api_base_url || "");
|
||||||
const payload = {
|
const payload: Record<string, string | undefined> = {
|
||||||
llm_provider: values.llm_provider,
|
ocr_provider: values.ocr_provider,
|
||||||
|
ocr_model: values.ocr_model?.trim() || undefined,
|
||||||
|
inference_provider: values.inference_provider,
|
||||||
|
inference_model: values.inference_model?.trim() || undefined,
|
||||||
openai_api_key: values.openai_api_key?.trim() || undefined,
|
openai_api_key: values.openai_api_key?.trim() || undefined,
|
||||||
anthropic_api_key: values.anthropic_api_key?.trim() || undefined,
|
anthropic_api_key: values.anthropic_api_key?.trim() || undefined,
|
||||||
deepseek_api_key: values.deepseek_api_key?.trim() || undefined,
|
deepseek_api_key: values.deepseek_api_key?.trim() || undefined,
|
||||||
@@ -59,7 +63,7 @@ export default function Settings() {
|
|||||||
};
|
};
|
||||||
const data = await api.settings.update(payload);
|
const data = await api.settings.update(payload);
|
||||||
setRuntime(data);
|
setRuntime(data);
|
||||||
message.success("设置已保存并生效(含系统 API BaseURL)");
|
message.success("设置已保存");
|
||||||
} catch {
|
} catch {
|
||||||
message.error("保存失败");
|
message.error("保存失败");
|
||||||
} finally {
|
} finally {
|
||||||
@@ -68,35 +72,42 @@ export default function Settings() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card title="LLM 设置" loading={loading}>
|
<Card title="模型与接口设置" loading={loading}>
|
||||||
<Alert
|
<Alert
|
||||||
type="info"
|
type="info"
|
||||||
showIcon
|
showIcon
|
||||||
style={{ marginBottom: 16 }}
|
style={{ marginBottom: 16 }}
|
||||||
message="LLM API Key 仅在当前服务进程运行期内生效,不会自动写入磁盘。"
|
message="API Key 仅在当前服务进程运行期内生效,不写入磁盘。OCR 模型用于从截图中提取交易,推理模型用于生成报告等文本推理任务。"
|
||||||
/>
|
/>
|
||||||
<Form form={form} layout="vertical" onFinish={onFinish}>
|
<Form form={form} layout="vertical" onFinish={onFinish}>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
label="系统 API BaseURL(前端请求后端)"
|
label="系统 API BaseURL(前端请求后端)"
|
||||||
name="system_api_base_url"
|
name="system_api_base_url"
|
||||||
extra="默认 /api;若前后端分离部署,可填如 http://127.0.0.1:8000/api"
|
extra="默认 /api;前后端分离时填 http://127.0.0.1:8000/api"
|
||||||
>
|
>
|
||||||
<Input placeholder="/api 或 http://127.0.0.1:8000/api" />
|
<Input placeholder="/api" />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item
|
|
||||||
label="默认模型提供商"
|
<Divider orientation="left">OCR 视觉模型(截图识别交易)</Divider>
|
||||||
name="llm_provider"
|
|
||||||
rules={[{ required: true, message: "请选择提供商" }]}
|
<Form.Item label="OCR 提供商" name="ocr_provider" rules={[{ required: true }]}>
|
||||||
>
|
<Select options={PROVIDER_OPTIONS} />
|
||||||
<Select
|
|
||||||
options={[
|
|
||||||
{ label: "OpenAI", value: "openai" },
|
|
||||||
{ label: "Anthropic", value: "anthropic" },
|
|
||||||
{ label: "DeepSeek", value: "deepseek" },
|
|
||||||
{ label: "自定义(OpenAI兼容)", value: "custom_openai" },
|
|
||||||
]}
|
|
||||||
/>
|
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
<Form.Item label="OCR 模型名" name="ocr_model" extra="留空则使用该提供商的默认模型">
|
||||||
|
<Input placeholder="如 gpt-4o / qwen-vl-max / ..." />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Divider orientation="left">推理模型(报告生成等文本推理)</Divider>
|
||||||
|
|
||||||
|
<Form.Item label="推理提供商" name="inference_provider" rules={[{ required: true }]}>
|
||||||
|
<Select options={PROVIDER_OPTIONS} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item label="推理模型名" name="inference_model" extra="留空则使用该提供商的默认模型">
|
||||||
|
<Input placeholder="如 gpt-4o-mini / deepseek-chat / ..." />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Divider orientation="left">API Key 与厂商配置</Divider>
|
||||||
|
|
||||||
<Form.Item label="OpenAI API Key" name="openai_api_key">
|
<Form.Item label="OpenAI API Key" name="openai_api_key">
|
||||||
<Input.Password placeholder="sk-..." />
|
<Input.Password placeholder="sk-..." />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
@@ -113,29 +124,33 @@ export default function Settings() {
|
|||||||
>
|
>
|
||||||
<Input placeholder="https://api.xxx.com/v1" />
|
<Input placeholder="https://api.xxx.com/v1" />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item label="自定义厂商 Model" name="custom_openai_model">
|
<Form.Item label="自定义厂商默认模型" name="custom_openai_model">
|
||||||
<Input placeholder="gpt-4o-mini / qwen-vl-plus / ..." />
|
<Input placeholder="gpt-4o-mini / qwen-vl-plus / ..." />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item label="自定义厂商 API Key" name="custom_openai_api_key">
|
<Form.Item label="自定义厂商 API Key" name="custom_openai_api_key">
|
||||||
<Input.Password placeholder="sk-..." />
|
<Input.Password placeholder="sk-..." />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
<Space>
|
<Space>
|
||||||
<Button type="primary" htmlType="submit" loading={saving}>
|
<Button type="primary" htmlType="submit" loading={saving}>保存设置</Button>
|
||||||
保存设置
|
|
||||||
</Button>
|
|
||||||
<Button onClick={loadSettings}>刷新</Button>
|
<Button onClick={loadSettings}>刷新</Button>
|
||||||
</Space>
|
</Space>
|
||||||
</Form>
|
</Form>
|
||||||
|
|
||||||
{runtime && (
|
{runtime && (
|
||||||
<Card title="当前状态" size="small" style={{ marginTop: 16 }}>
|
<Card title="当前生效配置" size="small" style={{ marginTop: 16 }}>
|
||||||
<div>系统 API BaseURL: {getApiBaseUrl()}</div>
|
<Descriptions column={2} size="small" bordered>
|
||||||
<div>当前提供商: {runtime.llm_provider}</div>
|
<Descriptions.Item label="系统 API BaseURL">{getApiBaseUrl()}</Descriptions.Item>
|
||||||
<div>OpenAI Key: {runtime.has_keys.openai ? "已配置" : "未配置"}</div>
|
<Descriptions.Item label="自定义厂商 BaseURL">{runtime.base_urls.custom_openai || "-"}</Descriptions.Item>
|
||||||
<div>Anthropic Key: {runtime.has_keys.anthropic ? "已配置" : "未配置"}</div>
|
<Descriptions.Item label="OCR 提供商">{runtime.ocr_provider}</Descriptions.Item>
|
||||||
<div>DeepSeek Key: {runtime.has_keys.deepseek ? "已配置" : "未配置"}</div>
|
<Descriptions.Item label="OCR 模型">{runtime.ocr_model}</Descriptions.Item>
|
||||||
<div>自定义厂商 Key: {runtime.has_keys.custom_openai ? "已配置" : "未配置"}</div>
|
<Descriptions.Item label="推理提供商">{runtime.inference_provider}</Descriptions.Item>
|
||||||
<div>自定义厂商 BaseURL: {runtime.base_urls.custom_openai || "-"}</div>
|
<Descriptions.Item label="推理模型">{runtime.inference_model}</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="OpenAI Key">{runtime.has_keys.openai ? "已配置" : "未配置"}</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="Anthropic Key">{runtime.has_keys.anthropic ? "已配置" : "未配置"}</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="DeepSeek Key">{runtime.has_keys.deepseek ? "已配置" : "未配置"}</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="自定义厂商 Key">{runtime.has_keys.custom_openai ? "已配置" : "未配置"}</Descriptions.Item>
|
||||||
|
</Descriptions>
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
|
|||||||
@@ -56,6 +56,13 @@ export interface ScreenshotItem {
|
|||||||
filename: string;
|
filename: string;
|
||||||
file_path: string;
|
file_path: string;
|
||||||
status: string;
|
status: string;
|
||||||
|
progress_step: string | null;
|
||||||
|
progress_percent: number;
|
||||||
|
progress_detail: string | null;
|
||||||
|
started_at: string | null;
|
||||||
|
finished_at: string | null;
|
||||||
|
duration_ms: number | null;
|
||||||
|
error_message: string | null;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,16 +79,24 @@ export interface FlowGraph {
|
|||||||
edges: Array<{ source: string; target: string; amount: number; count?: number }>;
|
edges: Array<{ source: string; target: string; amount: number; count?: number }>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type ProviderKey = "openai" | "anthropic" | "deepseek" | "custom_openai";
|
||||||
|
|
||||||
export interface RuntimeSettings {
|
export interface RuntimeSettings {
|
||||||
llm_provider: "openai" | "anthropic" | "deepseek" | "custom_openai";
|
ocr_provider: ProviderKey;
|
||||||
providers: Array<"openai" | "anthropic" | "deepseek" | "custom_openai">;
|
ocr_model: string;
|
||||||
models: Record<string, string>;
|
inference_provider: ProviderKey;
|
||||||
|
inference_model: string;
|
||||||
|
providers: ProviderKey[];
|
||||||
|
provider_defaults: Record<string, string>;
|
||||||
base_urls: Record<string, string>;
|
base_urls: Record<string, string>;
|
||||||
has_keys: Record<string, boolean>;
|
has_keys: Record<string, boolean>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RuntimeSettingsUpdate {
|
export interface RuntimeSettingsUpdate {
|
||||||
llm_provider?: "openai" | "anthropic" | "deepseek" | "custom_openai";
|
ocr_provider?: ProviderKey;
|
||||||
|
ocr_model?: string;
|
||||||
|
inference_provider?: ProviderKey;
|
||||||
|
inference_model?: string;
|
||||||
openai_api_key?: string;
|
openai_api_key?: string;
|
||||||
anthropic_api_key?: string;
|
anthropic_api_key?: string;
|
||||||
deepseek_api_key?: string;
|
deepseek_api_key?: string;
|
||||||
|
|||||||
Reference in New Issue
Block a user