From fd2b574d5af4b0f9f9b334f7cfc1a3dbde06bc43 Mon Sep 17 00:00:00 2001
From: ntnt
Date: Tue, 10 Mar 2026 14:25:21 +0800
Subject: [PATCH] upload: extraction progress tracking, error classification,
 and per-role (OCR / inference) model config

---
 .env.example                                  |  11 +-
 backend/.env.example                          |  17 ++
 backend/app/api/screenshots.py                |  92 +++++++-
 backend/app/api/settings.py                   |   5 +-
 backend/app/config.py                         |  74 +++++--
 backend/app/models/screenshot.py              |  11 +-
 backend/app/schemas/screenshot.py             |   7 +
 backend/app/services/extractor.py             |  17 +-
 backend/app/services/llm/base.py              |  15 +-
 backend/app/services/llm/claude_vision.py     |  26 ++-
 .../app/services/llm/custom_openai_vision.py  |  28 ++-
 backend/app/services/llm/deepseek_vision.py   |  29 ++-
 backend/app/services/llm/openai_vision.py     |  22 +-
 backend/app/services/llm/router.py            |  36 ++--
 frontend/src/App.tsx                          |  12 +-
 .../src/components/ScreenshotUploader.tsx     | 199 ++++++++++++++----
 frontend/src/index.css                        |   2 +
 frontend/src/pages/Settings.tsx               | 105 +++++----
 frontend/src/services/api.ts                  |  23 +-
 19 files changed, 575 insertions(+), 156 deletions(-)
 create mode 100644 backend/.env.example

diff --git a/.env.example b/.env.example
index 33027f6..269d0f4 100644
--- a/.env.example
+++ b/.env.example
@@ -1,7 +1,14 @@
 DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
-LLM_PROVIDER=openai
 
-# Optional: choose model names
+# --- OCR model (screenshot -> transactions) ---
+OCR_PROVIDER=openai
+OCR_MODEL=
+
+# --- Inference model (report generation, reasoning) ---
+INFERENCE_PROVIDER=openai
+INFERENCE_MODEL=
+
+# Provider default model names
 OPENAI_MODEL=gpt-4o
 ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
 DEEPSEEK_MODEL=deepseek-chat

diff --git a/backend/.env.example b/backend/.env.example
new file mode 100644
index 0000000..33027f6
--- /dev/null
+++ b/backend/.env.example
@@ -0,0 +1,17 @@
+DATABASE_URL=sqlite+aiosqlite:///./fund_tracer.db
+LLM_PROVIDER=openai
+
+# Optional: choose model names
+OPENAI_MODEL=gpt-4o
+ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
+DEEPSEEK_MODEL=deepseek-chat
+CUSTOM_OPENAI_MODEL=gpt-4o-mini
+
+# Custom OpenAI-compatible provider
+CUSTOM_OPENAI_BASE_URL=
+CUSTOM_OPENAI_API_KEY=
+
+# API keys
+OPENAI_API_KEY=
+ANTHROPIC_API_KEY=
+DEEPSEEK_API_KEY=

diff --git a/backend/app/api/screenshots.py b/backend/app/api/screenshots.py
index b90f95a..cccdda3 100644
--- a/backend/app/api/screenshots.py
+++ b/backend/app/api/screenshots.py
@@ -1,6 +1,7 @@
 """Screenshot upload and extraction API."""
 
 import uuid
+from datetime import datetime
 from pathlib import Path
 
 from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
@@ -84,18 +85,105 @@ async def extract_transactions(
     if not full_path.exists():
         raise HTTPException(status_code=404, detail="File not found on disk")
     image_bytes = full_path.read_bytes()
+    started_at = datetime.utcnow()
+    # Reset the timer whenever a new extraction round starts, so the recorded
+    # duration covers this run only rather than the accumulated history.
+    screenshot.started_at = started_at
+    screenshot.finished_at = None
+    screenshot.duration_ms = None
+    screenshot.error_message = None
+    screenshot.progress_step = "starting"
+    screenshot.progress_percent = 0
+    screenshot.progress_detail = "准备开始识别"
+    await db.commit()
+
+    async def update_progress(step: str, percent: int, detail: str):
+        screenshot.status = "processing"
+        screenshot.progress_step = step
+        screenshot.progress_percent = percent
+        screenshot.progress_detail = detail
+        await db.commit()
+
     try:
-        transactions = await extract_and_save(case_id, screenshot_id, image_bytes)
+        await update_progress("file_loaded", 10, "截图读取完成")
+        transactions = await extract_and_save(
+            case_id,
+            screenshot_id,
+            image_bytes,
+            progress_hook=update_progress,
+        )
     except 
Exception as e: + error_detail = _classify_error(e) r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id)) sc = r.scalar_one_or_none() if sc: sc.status = "failed" + sc.progress_step = "failed" + sc.progress_percent = 100 + sc.progress_detail = "识别失败" + sc.finished_at = datetime.utcnow() + if sc.started_at: + sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000) + sc.error_message = error_detail await db.commit() - raise HTTPException(status_code=502, detail=f"Extraction failed: {e!s}") + raise HTTPException(status_code=502, detail=error_detail) r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id)) sc = r.scalar_one_or_none() if sc: sc.status = "extracted" + sc.progress_step = "completed" + sc.progress_percent = 100 + sc.progress_detail = "识别完成" + sc.finished_at = datetime.utcnow() + if sc.started_at: + sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000) + sc.error_message = None await db.commit() return TransactionListResponse(items=transactions) + + +def _classify_error(e: Exception) -> str: + """Produce a human-readable, categorized error message.""" + name = type(e).__name__ + msg = str(e) + + if isinstance(e, ValueError): + return f"配置错误: {msg}" + + # OpenAI SDK errors + try: + from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError + if isinstance(e, AuthenticationError): + return f"API Key 无效或已过期 ({name}): {msg}" + if isinstance(e, RateLimitError): + return f"API 调用频率超限,请稍后重试 ({name}): {msg}" + if isinstance(e, APITimeoutError): + return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}" + if isinstance(e, APIConnectionError): + return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}" + if isinstance(e, BadRequestError): + return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}" + if isinstance(e, APIStatusError): + return f"模型服务返回错误 (HTTP {e.status_code}): {msg}" + except ImportError: + pass + + # Anthropic SDK errors + try: + from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError + from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq + if isinstance(e, AnthAuthError): + return f"Anthropic API Key 无效或已过期: {msg}" + if isinstance(e, AnthRateError): + return f"Anthropic API 调用频率超限: {msg}" + if isinstance(e, AnthConnError): + return f"无法连接到 Anthropic 服务: {msg}" + if isinstance(e, AnthBadReq): + return f"Anthropic 请求被拒绝: {msg}" + except ImportError: + pass + + # Connection / network + if "connect" in msg.lower() or "timeout" in msg.lower(): + return f"网络连接失败或超时: {msg}" + + return f"识别失败 ({name}): {msg}" diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index 69b0fd3..f69c0e4 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -9,7 +9,10 @@ router = APIRouter() class SettingsUpdate(BaseModel): - llm_provider: str | None = None + ocr_provider: str | None = None + ocr_model: str | None = None + inference_provider: str | None = None + inference_model: str | None = None openai_api_key: str | None = None anthropic_api_key: str | None = None deepseek_api_key: str | None = None diff --git a/backend/app/config.py b/backend/app/config.py index aa915aa..f917d6d 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -20,18 +20,30 @@ class Settings(BaseSettings): max_upload_size_mb: int = 20 allowed_extensions: set[str] = {"png", "jpg", "jpeg", "webp"} - # LLM - llm_provider: str = "openai" # openai | anthropic | 
deepseek | custom_openai + # --- OCR (vision) model --- + ocr_provider: str = "openai" # openai | anthropic | deepseek | custom_openai + ocr_model: str | None = None # if None, falls back to provider default + + # --- Inference (reasoning) model --- + inference_provider: str = "openai" + inference_model: str | None = None + + # --- Provider credentials (shared between OCR and inference) --- openai_api_key: str | None = None anthropic_api_key: str | None = None deepseek_api_key: str | None = None custom_openai_api_key: str | None = None custom_openai_base_url: str | None = None + + # Provider default model names (used when ocr_model / inference_model is None) openai_model: str = "gpt-4o" anthropic_model: str = "claude-3-5-sonnet-20241022" deepseek_model: str = "deepseek-chat" custom_openai_model: str = "gpt-4o-mini" + # Legacy compat: llm_provider maps to ocr_provider on load + llm_provider: str | None = None + class Config: env_file = ".env" env_file_encoding = "utf-8" @@ -40,8 +52,24 @@ class Settings(BaseSettings): _runtime_overrides: dict[str, str | None] = {} +_ALLOWED_RUNTIME_KEYS = { + "ocr_provider", + "ocr_model", + "inference_provider", + "inference_model", + "openai_api_key", + "anthropic_api_key", + "deepseek_api_key", + "custom_openai_api_key", + "custom_openai_base_url", + "custom_openai_model", +} + def _apply_overrides(settings: Settings) -> Settings: + # Legacy: if llm_provider is set but ocr_provider is default, use it + if settings.llm_provider and settings.ocr_provider == "openai": + settings.ocr_provider = settings.llm_provider for key, value in _runtime_overrides.items(): if hasattr(settings, key): setattr(settings, key, value) @@ -53,19 +81,32 @@ def get_settings() -> Settings: return _apply_overrides(Settings()) -def update_runtime_settings(payload: dict[str, str | None]) -> Settings: - """Update runtime settings and refresh cached Settings object.""" - allowed = { - "llm_provider", - "openai_api_key", - "anthropic_api_key", - "deepseek_api_key", - "custom_openai_api_key", - "custom_openai_base_url", - "custom_openai_model", +def _resolve_model(provider: str, explicit_model: str | None, settings: Settings) -> str: + """Return the model name to use for a given provider.""" + if explicit_model: + return explicit_model + defaults = { + "openai": settings.openai_model, + "anthropic": settings.anthropic_model, + "deepseek": settings.deepseek_model, + "custom_openai": settings.custom_openai_model, } + return defaults.get(provider, settings.openai_model) + + +def get_ocr_model() -> str: + s = get_settings() + return _resolve_model(s.ocr_provider, s.ocr_model, s) + + +def get_inference_model() -> str: + s = get_settings() + return _resolve_model(s.inference_provider, s.inference_model, s) + + +def update_runtime_settings(payload: dict[str, str | None]) -> Settings: for key, value in payload.items(): - if key in allowed: + if key in _ALLOWED_RUNTIME_KEYS: _runtime_overrides[key] = value get_settings.cache_clear() return get_settings() @@ -74,9 +115,12 @@ def update_runtime_settings(payload: dict[str, str | None]) -> Settings: def public_settings() -> dict: s = get_settings() return { - "llm_provider": s.llm_provider, + "ocr_provider": s.ocr_provider, + "ocr_model": get_ocr_model(), + "inference_provider": s.inference_provider, + "inference_model": get_inference_model(), "providers": ["openai", "anthropic", "deepseek", "custom_openai"], - "models": { + "provider_defaults": { "openai": s.openai_model, "anthropic": s.anthropic_model, "deepseek": s.deepseek_model, diff --git 
a/backend/app/models/screenshot.py b/backend/app/models/screenshot.py index 2229742..f12720a 100644 --- a/backend/app/models/screenshot.py +++ b/backend/app/models/screenshot.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import datetime -from sqlalchemy import String, DateTime, ForeignKey +from sqlalchemy import String, Text, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from app.models.database import Base @@ -15,7 +15,14 @@ class Screenshot(Base): case_id: Mapped[int] = mapped_column(ForeignKey("cases.id", ondelete="CASCADE"), index=True) filename: Mapped[str] = mapped_column(String(255)) file_path: Mapped[str] = mapped_column(String(512)) - status: Mapped[str] = mapped_column(String(32), default="pending") # pending | extracted | failed + status: Mapped[str] = mapped_column(String(32), default="pending") # pending | processing | extracted | failed + progress_step: Mapped[str | None] = mapped_column(String(64), nullable=True, default=None) + progress_percent: Mapped[int] = mapped_column(default=0) + progress_detail: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + duration_ms: Mapped[int | None] = mapped_column(nullable=True, default=None) + error_message: Mapped[str | None] = mapped_column(Text, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) case: Mapped["Case"] = relationship("Case", back_populates="screenshots") diff --git a/backend/app/schemas/screenshot.py b/backend/app/schemas/screenshot.py index a10ab43..6697cef 100644 --- a/backend/app/schemas/screenshot.py +++ b/backend/app/schemas/screenshot.py @@ -11,6 +11,13 @@ class ScreenshotResponse(BaseModel): filename: str file_path: str status: str + progress_step: str | None = None + progress_percent: int = 0 + progress_detail: str | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + duration_ms: int | None = None + error_message: str | None = None created_at: datetime diff --git a/backend/app/services/extractor.py b/backend/app/services/extractor.py index 809fdde..428c483 100644 --- a/backend/app/services/extractor.py +++ b/backend/app/services/extractor.py @@ -1,7 +1,7 @@ """Transaction data extraction: LLM Vision + persistence.""" from app.models import Transaction -from app.models.database import async_session_maker +import app.models.database as db_module from app.schemas.transaction import TransactionExtractItem, TransactionResponse from app.services.llm import get_llm_provider @@ -10,15 +10,26 @@ async def extract_and_save( case_id: int, screenshot_id: int, image_bytes: bytes, + progress_hook=None, ) -> list[TransactionResponse]: """ Run vision extraction on image and persist transactions to DB. Returns list of created transactions; low-confidence items are still saved but flagged. 
""" + if progress_hook: + await progress_hook("init", 5, "初始化识别上下文") provider = get_llm_provider() + if progress_hook: + await progress_hook("provider_ready", 15, f"已加载模型提供商: {type(provider).__name__}") + if progress_hook: + await progress_hook("calling_model", 35, "调用视觉模型识别截图中交易") items: list[TransactionExtractItem] = await provider.extract_from_image(image_bytes) + if progress_hook: + await progress_hook("model_returned", 70, f"模型返回 {len(items)} 条交易") results: list[TransactionResponse] = [] - async with async_session_maker() as session: + async with db_module.async_session_maker() as session: + if progress_hook: + await progress_hook("db_writing", 85, "写入交易记录到数据库") for it in items: t = Transaction( case_id=case_id, @@ -39,4 +50,6 @@ async def extract_and_save( await session.flush() results.append(TransactionResponse.model_validate(t)) await session.commit() + if progress_hook: + await progress_hook("completed", 100, "识别完成") return results diff --git a/backend/app/services/llm/base.py b/backend/app/services/llm/base.py index e148b4a..857fe52 100644 --- a/backend/app/services/llm/base.py +++ b/backend/app/services/llm/base.py @@ -6,13 +6,16 @@ from app.schemas.transaction import TransactionExtractItem class BaseLLMProvider(ABC): - """Abstract base for LLM vision providers. Each provider implements extract_from_image.""" + """Abstract base for LLM providers. Supports optional model override.""" + + def __init__(self, model_override: str | None = None): + self._model_override = model_override @abstractmethod async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: - """ - Analyze a billing screenshot and return structured transaction list. - :param image_bytes: Raw image file content (PNG/JPEG) - :return: List of extracted transactions (may be empty or partial on failure) - """ + pass + + @abstractmethod + async def chat(self, system: str, user: str) -> str: + """Plain text chat (for inference/reasoning tasks like report generation).""" pass diff --git a/backend/app/services/llm/claude_vision.py b/backend/app/services/llm/claude_vision.py index 66e4487..fb278d7 100644 --- a/backend/app/services/llm/claude_vision.py +++ b/backend/app/services/llm/claude_vision.py @@ -1,8 +1,6 @@ """Anthropic Claude Vision provider.""" import base64 -import json -import re from anthropic import AsyncAnthropic from app.config import get_settings @@ -13,6 +11,9 @@ from app.services.llm.openai_vision import _parse_json_array class ClaudeVisionProvider(BaseLLMProvider): + def _get_model(self) -> str: + return self._model_override or get_settings().anthropic_model + async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: settings = get_settings() if not settings.anthropic_api_key: @@ -20,14 +21,12 @@ class ClaudeVisionProvider(BaseLLMProvider): client = AsyncAnthropic(api_key=settings.anthropic_api_key) b64 = base64.standard_b64encode(image_bytes).decode("ascii") messages = get_extract_messages(b64) - # Claude API: user message with content block list user_content = messages[1]["content"] content_blocks = [] for block in user_content: if block["type"] == "text": content_blocks.append({"type": "text", "text": block["text"]}) elif block["type"] == "image_url": - # Claude expects base64 without data URL prefix content_blocks.append({ "type": "image", "source": { @@ -37,7 +36,7 @@ class ClaudeVisionProvider(BaseLLMProvider): }, }) response = await client.messages.create( - model=settings.anthropic_model, + model=self._get_model(), max_tokens=4096, 
system=messages[0]["content"], messages=[{"role": "user", "content": content_blocks}], @@ -47,3 +46,20 @@ class ClaudeVisionProvider(BaseLLMProvider): if hasattr(block, "text"): text += block.text return _parse_json_array(text or "[]") + + async def chat(self, system: str, user: str) -> str: + settings = get_settings() + if not settings.anthropic_api_key: + raise ValueError("ANTHROPIC_API_KEY is not set") + client = AsyncAnthropic(api_key=settings.anthropic_api_key) + response = await client.messages.create( + model=self._get_model(), + max_tokens=4096, + system=system, + messages=[{"role": "user", "content": user}], + ) + text = "" + for block in response.content: + if hasattr(block, "text"): + text += block.text + return text or "" diff --git a/backend/app/services/llm/custom_openai_vision.py b/backend/app/services/llm/custom_openai_vision.py index 96077f4..545b412 100644 --- a/backend/app/services/llm/custom_openai_vision.py +++ b/backend/app/services/llm/custom_openai_vision.py @@ -11,22 +11,44 @@ from app.services.llm.openai_vision import _parse_json_array class CustomOpenAICompatibleProvider(BaseLLMProvider): - async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: + def _get_client(self) -> AsyncOpenAI: settings = get_settings() if not settings.custom_openai_api_key: raise ValueError("CUSTOM_OPENAI_API_KEY is not set") if not settings.custom_openai_base_url: raise ValueError("CUSTOM_OPENAI_BASE_URL is not set") - client = AsyncOpenAI( + return AsyncOpenAI( api_key=settings.custom_openai_api_key, base_url=settings.custom_openai_base_url, + timeout=45.0, + max_retries=1, ) + + def _get_model(self) -> str: + return self._model_override or get_settings().custom_openai_model + + async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: + client = self._get_client() b64 = base64.standard_b64encode(image_bytes).decode("ascii") messages = get_extract_messages(b64) response = await client.chat.completions.create( - model=settings.custom_openai_model, + model=self._get_model(), messages=messages, max_tokens=4096, + timeout=45.0, ) text = response.choices[0].message.content or "[]" return _parse_json_array(text) + + async def chat(self, system: str, user: str) -> str: + client = self._get_client() + response = await client.chat.completions.create( + model=self._get_model(), + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + max_tokens=4096, + timeout=45.0, + ) + return response.choices[0].message.content or "" diff --git a/backend/app/services/llm/deepseek_vision.py b/backend/app/services/llm/deepseek_vision.py index a567ecb..44c1859 100644 --- a/backend/app/services/llm/deepseek_vision.py +++ b/backend/app/services/llm/deepseek_vision.py @@ -9,26 +9,39 @@ from app.services.llm.base import BaseLLMProvider from app.prompts.extract_transaction import get_extract_messages from app.services.llm.openai_vision import _parse_json_array - -# DeepSeek vision endpoint (OpenAI-compatible) DEEPSEEK_BASE = "https://api.deepseek.com" class DeepSeekVisionProvider(BaseLLMProvider): - async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: + def _get_client(self) -> AsyncOpenAI: settings = get_settings() if not settings.deepseek_api_key: raise ValueError("DEEPSEEK_API_KEY is not set") - client = AsyncOpenAI( - api_key=settings.deepseek_api_key, - base_url=DEEPSEEK_BASE, - ) + return AsyncOpenAI(api_key=settings.deepseek_api_key, base_url=DEEPSEEK_BASE) + + def 
_get_model(self) -> str: + return self._model_override or get_settings().deepseek_model + + async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: + client = self._get_client() b64 = base64.standard_b64encode(image_bytes).decode("ascii") messages = get_extract_messages(b64) response = await client.chat.completions.create( - model=settings.deepseek_model, + model=self._get_model(), messages=messages, max_tokens=4096, ) text = response.choices[0].message.content or "[]" return _parse_json_array(text) + + async def chat(self, system: str, user: str) -> str: + client = self._get_client() + response = await client.chat.completions.create( + model=self._get_model(), + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + max_tokens=4096, + ) + return response.choices[0].message.content or "" diff --git a/backend/app/services/llm/openai_vision.py b/backend/app/services/llm/openai_vision.py index 1009c44..95ea76c 100644 --- a/backend/app/services/llm/openai_vision.py +++ b/backend/app/services/llm/openai_vision.py @@ -12,6 +12,9 @@ from app.prompts.extract_transaction import get_extract_messages class OpenAIVisionProvider(BaseLLMProvider): + def _get_model(self) -> str: + return self._model_override or get_settings().openai_model + async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]: settings = get_settings() if not settings.openai_api_key: @@ -20,18 +23,32 @@ class OpenAIVisionProvider(BaseLLMProvider): b64 = base64.standard_b64encode(image_bytes).decode("ascii") messages = get_extract_messages(b64) response = await client.chat.completions.create( - model=settings.openai_model, + model=self._get_model(), messages=messages, max_tokens=4096, ) text = response.choices[0].message.content or "[]" return _parse_json_array(text) + async def chat(self, system: str, user: str) -> str: + settings = get_settings() + if not settings.openai_api_key: + raise ValueError("OPENAI_API_KEY is not set") + client = AsyncOpenAI(api_key=settings.openai_api_key) + response = await client.chat.completions.create( + model=self._get_model(), + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + max_tokens=4096, + ) + return response.choices[0].message.content or "" + def _parse_json_array(text: str) -> list[TransactionExtractItem]: """Parse LLM response into list of TransactionExtractItem. 
Tolerates markdown and extra text."""
     text = text.strip()
-    # Remove optional markdown code block
     if text.startswith("```"):
         text = re.sub(r"^```(?:json)?\s*", "", text)
         text = re.sub(r"\s*```\s*$", "", text)
@@ -46,7 +63,6 @@ def _parse_json_array(text: str) -> list[TransactionExtractItem]:
         if not isinstance(item, dict):
             continue
         try:
-            # Normalize transaction_time: allow string -> datetime
             if isinstance(item.get("transaction_time"), str) and item["transaction_time"]:
                 from dateutil import parser as date_parser
                 item["transaction_time"] = date_parser.isoparse(item["transaction_time"])

diff --git a/backend/app/services/llm/router.py b/backend/app/services/llm/router.py
index 4e64d70..4c78eb5 100644
--- a/backend/app/services/llm/router.py
+++ b/backend/app/services/llm/router.py
@@ -1,6 +1,6 @@
-"""LLM provider factory - returns provider by config."""
+"""LLM provider factory - returns provider by config, split by role (ocr / inference)."""
 
-from app.config import get_settings
+from app.config import get_settings, get_ocr_model, get_inference_model
 from app.services.llm.base import BaseLLMProvider
 from app.services.llm.openai_vision import OpenAIVisionProvider
 from app.services.llm.claude_vision import ClaudeVisionProvider
@@ -8,15 +8,25 @@ from app.services.llm.deepseek_vision import DeepSeekVisionProvider
 from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
 
 
-def get_llm_provider() -> BaseLLMProvider:
+def _make_provider(provider_name: str, model_override: str | None = None) -> BaseLLMProvider:
+    name = (provider_name or "openai").lower()
+    if name == "openai":
+        return OpenAIVisionProvider(model_override=model_override)
+    if name == "anthropic":
+        return ClaudeVisionProvider(model_override=model_override)
+    if name == "deepseek":
+        return DeepSeekVisionProvider(model_override=model_override)
+    if name == "custom_openai":
+        return CustomOpenAICompatibleProvider(model_override=model_override)
+    return OpenAIVisionProvider(model_override=model_override)
+
+
+def get_llm_provider(role: str = "ocr") -> BaseLLMProvider:
+    """
+    role="ocr" -> uses ocr_provider + ocr_model
+    role="inference" -> uses inference_provider + inference_model
+    """
     settings = get_settings()
-    provider = (settings.llm_provider or "openai").lower()
-    if provider == "openai":
-        return OpenAIVisionProvider()
-    if provider == "anthropic":
-        return ClaudeVisionProvider()
-    if provider == "deepseek":
-        return DeepSeekVisionProvider()
-    if provider == "custom_openai":
-        return CustomOpenAICompatibleProvider()
-    return OpenAIVisionProvider()
+    if role == "inference":
+        return _make_provider(settings.inference_provider, get_inference_model())
+    return _make_provider(settings.ocr_provider, get_ocr_model())
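With the factory split by role, screenshot OCR and report generation can each be pointed at a
different provider. A minimal usage sketch of the two roles (illustrative only -- the prompt
strings below are placeholders, not part of this patch):

    # Sketch: resolve one provider per role, as wired by router.py above.
    from app.services.llm.router import get_llm_provider

    async def extract_then_summarize(image_bytes: bytes) -> str:
        # OCR role -> ocr_provider + ocr_model (vision extraction)
        ocr = get_llm_provider("ocr")
        items = await ocr.extract_from_image(image_bytes)

        # Inference role -> inference_provider + inference_model (plain-text chat)
        reasoner = get_llm_provider("inference")
        return await reasoner.chat(
            system="You are a financial analyst.",  # placeholder prompt
            user=f"Summarize {len(items)} extracted transactions.",  # placeholder
        )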
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 990ba6a..355a660 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -8,7 +8,17 @@ import Settings from "./pages/Settings";
 function App() {
   return (
-
+
       }>
         } />
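The uploader rewrite below renders per-screenshot progress that the backend publishes through
`extract_and_save`'s `progress_hook`. For reference, a hedged sketch of the hook contract on the
backend side (it logs instead of committing to the database, which is what screenshots.py does):

    # Sketch: the (step, percent, detail) hook contract from extractor.py above.
    from app.services.extractor import extract_and_save

    async def log_progress(step: str, percent: int, detail: str) -> None:
        # screenshots.py persists these into Screenshot.progress_*; here we just print.
        print(f"[{percent:3d}%] {step}: {detail}")

    async def run(case_id: int, screenshot_id: int, image_bytes: bytes):
        # Emits: init(5) -> provider_ready(15) -> calling_model(35)
        #        -> model_returned(70) -> db_writing(85) -> completed(100)
        return await extract_and_save(
            case_id, screenshot_id, image_bytes, progress_hook=log_progress
        )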
diff --git a/frontend/src/components/ScreenshotUploader.tsx b/frontend/src/components/ScreenshotUploader.tsx
index a5a178a..7db8b9e 100644
--- a/frontend/src/components/ScreenshotUploader.tsx
+++ b/frontend/src/components/ScreenshotUploader.tsx
@@ -1,6 +1,8 @@
-import { useState, useEffect } from "react";
-import { Upload, List, Button, Card, Tag, message } from "antd";
-import { InboxOutlined, ThunderboltOutlined } from "@ant-design/icons";
+import { useEffect, useMemo, useState } from "react";
+import type { Key } from "react";
+import { Upload, Table, Button, Tag, Alert, message, Space, Progress } from "antd";
+import type { ColumnsType } from "antd/es/table";
+import { InboxOutlined, ThunderboltOutlined, ReloadOutlined } from "@ant-design/icons";
 import { api, type ScreenshotItem } from "../services/api";
 
 const { Dragger } = Upload;
@@ -10,10 +12,24 @@ interface Props {
   onExtracted?: () => void;
 }
 
+function statusTag(item: ScreenshotItem) {
+  if (item.status === "extracted") return <Tag color="green">已识别</Tag>;
+  if (item.status === "processing") return <Tag color="blue">识别中</Tag>;
+  if (item.status === "failed") return <Tag color="red">失败</Tag>;
+  return <Tag>待识别</Tag>;
+}
+
+function formatDuration(ms: number | null) {
+  if (ms == null) return "-";
+  if (ms < 1000) return `${ms} ms`;
+  return `${(ms / 1000).toFixed(2)} s`;
+}
+
 export default function ScreenshotUploader({ caseId, onExtracted }: Props) {
   const [screenshots, setScreenshots] = useState<ScreenshotItem[]>([]);
   const [loading, setLoading] = useState(false);
-  const [extractingId, setExtractingId] = useState<number | null>(null);
+  const [extractingIds, setExtractingIds] = useState<number[]>([]);
+  const [selectedRowKeys, setSelectedRowKeys] = useState<Key[]>([]);
 
   const loadScreenshots = async () => {
     try {
@@ -35,73 +51,168 @@ export default function ScreenshotUploader({ caseId, onExtracted }: Props) {
     } finally {
       setLoading(false);
     }
-    return false; // prevent default upload
+    return false;
   };
 
-  const handleExtract = async (screenshotId: number) => {
-    setExtractingId(screenshotId);
+  const handleExtractSingle = async (screenshotId: number) => {
+    setExtractingIds((prev) => Array.from(new Set([...prev, screenshotId])));
     try {
       await api.screenshots.extract(caseId, screenshotId);
-      message.success("识别完成");
       await loadScreenshots();
       onExtracted?.();
+      message.success(`截图 ${screenshotId} 识别完成`);
     } catch (e: unknown) {
-      const msg = e && typeof e === "object" && "response" in e
-        ? (e as { response?: { data?: { detail?: string } } }).response?.data?.detail
-        : "识别失败";
-      message.error(msg || "识别失败");
+      const detail =
+        e && typeof e === "object" && "response" in e
+          ? (e as { response?: { data?: { detail?: string } } }).response?.data?.detail
+          : undefined;
+      message.error(detail || `截图 ${screenshotId} 识别失败`);
+      await loadScreenshots();
     } finally {
-      setExtractingId(null);
+      setExtractingIds((prev) => prev.filter((id) => id !== screenshotId));
     }
   };
 
+  const handleBatchExtract = async () => {
+    const ids = selectedRowKeys.map((k) => Number(k)).filter((id) => Number.isFinite(id));
+    if (!ids.length) {
+      message.warning("请先勾选要识别的截图");
+      return;
+    }
+    setExtractingIds((prev) => Array.from(new Set([...prev, ...ids])));
+    const started = Date.now();
+    let ok = 0;
+    let fail = 0;
+    for (const id of ids) {
+      try {
+        await api.screenshots.extract(caseId, id);
+        ok += 1;
+      } catch {
+        fail += 1;
+      }
+      await loadScreenshots();
+    }
+    setExtractingIds((prev) => prev.filter((id) => !ids.includes(id)));
+    onExtracted?.();
+    const elapsed = Date.now() - started;
+    message.info(`批量识别结束:成功 ${ok},失败 ${fail},总耗时 ${(elapsed / 1000).toFixed(2)} s`);
+  };
+
   useEffect(() => {
-    if (caseId) loadScreenshots();
+    if (!caseId) return;
+    loadScreenshots();
+    const timer = window.setInterval(loadScreenshots, 1500);
+    return () => window.clearInterval(timer);
   }, [caseId]);
 
+  const rowSelection = {
+    selectedRowKeys,
+    onChange: (keys: Key[]) => setSelectedRowKeys(keys),
+    getCheckboxProps: (record: ScreenshotItem) => ({
+      disabled: record.status === "processing",
+    }),
+  };
+
+  const columns: ColumnsType<ScreenshotItem> = useMemo(() => [
+    {
+      title: "截图",
+      dataIndex: "filename",
+      key: "filename",
+      render: (v: string) => (
+        <span title={v} style={{ wordBreak: "break-all" }}>
+          {v}
+        </span>
+      ),
+    },
+    {
+      title: "状态",
+      key: "status",
+      width: 110,
+      render: (_, r) => statusTag(r),
+    },
+    {
+      title: "识别进度",
+      key: "progress",
+      width: 280,
+      render: (_, r) => (
+        <div>
+          <Progress percent={r.progress_percent} size="small" />
+          <div>{r.progress_detail || "-"}</div>
+          {r.progress_step && <div>步骤: {r.progress_step}</div>}
+        </div>
+      ),
+    },
+    {
+      title: "耗时",
+      key: "duration",
+      width: 120,
+      render: (_, r) => formatDuration(r.duration_ms),
+    },
+    {
+      title: "错误信息",
+      key: "error_message",
+      render: (_, r) =>
+        r.status === "failed" && r.error_message ? (
+          <Alert type="error" showIcon message={r.error_message} />
+        ) : (
+          "-"
+        ),
+    },
+    {
+      title: "操作",
+      key: "action",
+      width: 120,
+      render: (_, r) => (
+        (r.status === "pending" || r.status === "failed") && (
+          <Button
+            size="small"
+            type="primary"
+            icon={<ThunderboltOutlined />}
+            loading={extractingIds.includes(r.id)}
+            onClick={() => handleExtractSingle(r.id)}
+          >
+            识别
+          </Button>
+        )
+      ),
+    },
+  ], [extractingIds]);
+
   return (
     <div>
       <Dragger
         multiple
         showUploadList={false}
-        beforeUpload={(file) => { handleUpload(file as File); return false; }}
+        beforeUpload={(file) => {
+          handleUpload(file as File);
+          return false;
+        }}
         disabled={loading}
       >
         <p className="ant-upload-drag-icon">
           <InboxOutlined />
         </p>
         <p className="ant-upload-text">点击或拖拽账单截图到此处上传</p>
         <p className="ant-upload-hint">支持 png / jpg / webp,单次可多选</p>
       </Dragger>
 
-      <Card title="已上传截图" size="small" style={{ marginTop: 16 }}>
-        <List
-          dataSource={screenshots}
-          renderItem={(item) => (
-            <List.Item>
-              <span>{item.filename}</span>
-              <Tag>
-                {item.status === "extracted" ? "已识别" : item.status === "failed" ? "失败" : "待识别"}
-              </Tag>
-              {item.status === "pending" && (
-                <Button
-                  size="small"
-                  icon={<ThunderboltOutlined />}
-                  loading={extractingId === item.id}
-                  onClick={() => handleExtract(item.id)}
-                >
-                  识别
-                </Button>
-              )}
-            </List.Item>
-          )}
-        />
-      </Card>
+      <Space style={{ margin: "12px 0" }}>
+        <Button type="primary" icon={<ThunderboltOutlined />} onClick={handleBatchExtract}>
+          批量识别
+        </Button>
+        <Button icon={<ReloadOutlined />} onClick={loadScreenshots}>
+          刷新
+        </Button>
+      </Space>
+      <Table
+        rowKey="id"
+        size="small"
+        rowSelection={rowSelection}
+        columns={columns}
+        dataSource={screenshots}
+        pagination={false}
+      />
     </div>
   );
 }

diff --git a/frontend/src/index.css b/frontend/src/index.css
index 6724f23..6417080 100644
--- a/frontend/src/index.css
+++ b/frontend/src/index.css
@@ -3,5 +3,7 @@
 }
 body {
   margin: 0;
+  font-size: 16px;
+  line-height: 1.6;
   font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
 }

diff --git a/frontend/src/pages/Settings.tsx b/frontend/src/pages/Settings.tsx
index de70a70..091fdf7 100644
--- a/frontend/src/pages/Settings.tsx
+++ b/frontend/src/pages/Settings.tsx
@@ -1,12 +1,20 @@
 import { useEffect, useState } from "react";
-import { Card, Form, Input, Select, Button, Alert, Space, message } from "antd";
+import { Card, Form, Input, Select, Button, Alert, Space, Divider, Descriptions, message } from "antd";
 import {
   api,
   type RuntimeSettings,
+  type ProviderKey,
   getApiBaseUrl,
   setApiBaseUrl,
 } from "../services/api";
 
+const PROVIDER_OPTIONS = [
+  { label: "OpenAI", value: "openai" as ProviderKey },
+  { label: "Anthropic", value: "anthropic" as ProviderKey },
+  { label: "DeepSeek", value: "deepseek" as ProviderKey },
+  { label: "自定义(OpenAI兼容)", value: "custom_openai" as ProviderKey },
+];
+
 export default function Settings() {
   const [form] = Form.useForm();
   const [loading, setLoading] = useState(false);
@@ -20,9 +28,11 @@ export default function Settings() {
       setRuntime(data);
       form.setFieldsValue({
         system_api_base_url: getApiBaseUrl(),
-        llm_provider: data.llm_provider,
+        ocr_provider: data.ocr_provider,
+        ocr_model: data.ocr_model,
+        inference_provider: data.inference_provider,
+        inference_model: data.inference_model,
         custom_openai_base_url: data.base_urls?.custom_openai || "",
-        custom_openai_model: data.models?.custom_openai || "gpt-4o-mini",
       });
     } catch {
       message.error("加载设置失败");
@@ -35,21 +45,15 @@ export default function Settings() {
     loadSettings();
   }, []);
 
-  const onFinish = async (values: {
-    system_api_base_url?: string;
-    llm_provider: "openai" | "anthropic" | "deepseek" | "custom_openai";
-    openai_api_key?: string;
-    anthropic_api_key?: string;
-    deepseek_api_key?: string;
-    custom_openai_api_key?: string;
-    custom_openai_base_url?: string;
-    custom_openai_model?: string;
-  }) => {
+  const onFinish = async (values: Record<string, string | undefined>) => {
     setSaving(true);
     try {
       setApiBaseUrl(values.system_api_base_url || "");
-      const payload = {
-        llm_provider: values.llm_provider,
+      const payload: Record<string, string | undefined> = {
+        ocr_provider: values.ocr_provider,
+        ocr_model: values.ocr_model?.trim() || undefined,
+        inference_provider: values.inference_provider,
+        inference_model: values.inference_model?.trim() || undefined,
         openai_api_key: values.openai_api_key?.trim() || undefined,
         anthropic_api_key: values.anthropic_api_key?.trim() || undefined,
         deepseek_api_key: values.deepseek_api_key?.trim() || undefined,
@@ -59,7 +63,7 @@ export default function Settings() {
       };
       const data = await api.settings.update(payload);
       setRuntime(data);
-      message.success("设置已保存并生效(含系统 API BaseURL)");
+      message.success("设置已保存");
     } catch {
       message.error("保存失败");
     } finally {
@@ -68,35 +72,42 @@ export default function Settings() {
   };
 
   return (
-    <Card>
+    <Card style={{ maxWidth: 760 }}>
       <Form form={form} layout="vertical" onFinish={onFinish}>
         <Form.Item label="系统 API BaseURL" name="system_api_base_url">
           <Input />
         </Form.Item>
-        <Form.Item label="LLM 提供商" name="llm_provider">
-          <Select />
-        </Form.Item>
+        <Form.Item label="OCR 模型提供商(截图识别)" name="ocr_provider">
+          <Select options={PROVIDER_OPTIONS} />
+        </Form.Item>
+        <Form.Item label="OCR 模型名(留空使用厂商默认)" name="ocr_model">
+          <Input />
+        </Form.Item>
+        <Divider orientation="left">推理模型(报告生成等文本推理)</Divider>
+        <Form.Item label="推理模型提供商" name="inference_provider">
+          <Select options={PROVIDER_OPTIONS} />
+        </Form.Item>
+        <Form.Item label="推理模型名(留空使用厂商默认)" name="inference_model">
+          <Input />
+        </Form.Item>
+        <Divider orientation="left">API Key 与厂商配置</Divider>
@@ -113,29 +124,33 @@ export default function Settings() {
         >
-          <Input />
+          <Input.Password />
         </Form.Item>
+        <Form.Item label="自定义厂商默认模型" name="custom_openai_model">
+          <Input />
+        </Form.Item>
         <Button type="primary" htmlType="submit" loading={saving}>
           保存设置
         </Button>
       </Form>
 
       {runtime && (
-        <Space direction="vertical" style={{ marginTop: 24 }}>
-          <div>系统 API BaseURL: {getApiBaseUrl()}</div>
-          <div>当前提供商: {runtime.llm_provider}</div>
-          <div>OpenAI Key: {runtime.has_keys.openai ? "已配置" : "未配置"}</div>
-          <div>Anthropic Key: {runtime.has_keys.anthropic ? "已配置" : "未配置"}</div>
-          <div>DeepSeek Key: {runtime.has_keys.deepseek ? "已配置" : "未配置"}</div>
-          <div>自定义厂商 Key: {runtime.has_keys.custom_openai ? "已配置" : "未配置"}</div>
-          <div>自定义厂商 BaseURL: {runtime.base_urls.custom_openai || "-"}</div>
-        </Space>
+        <Descriptions bordered size="small" column={1} style={{ marginTop: 24 }}>
+          <Descriptions.Item label="系统 API BaseURL">{getApiBaseUrl()}</Descriptions.Item>
+          <Descriptions.Item label="自定义厂商 BaseURL">{runtime.base_urls.custom_openai || "-"}</Descriptions.Item>
+          <Descriptions.Item label="OCR 提供商">{runtime.ocr_provider}</Descriptions.Item>
+          <Descriptions.Item label="OCR 模型">{runtime.ocr_model}</Descriptions.Item>
+          <Descriptions.Item label="推理提供商">{runtime.inference_provider}</Descriptions.Item>
+          <Descriptions.Item label="推理模型">{runtime.inference_model}</Descriptions.Item>
+          <Descriptions.Item label="OpenAI Key">{runtime.has_keys.openai ? "已配置" : "未配置"}</Descriptions.Item>
+          <Descriptions.Item label="Anthropic Key">{runtime.has_keys.anthropic ? "已配置" : "未配置"}</Descriptions.Item>
+          <Descriptions.Item label="DeepSeek Key">{runtime.has_keys.deepseek ? "已配置" : "未配置"}</Descriptions.Item>
+          <Descriptions.Item label="自定义厂商 Key">{runtime.has_keys.custom_openai ? "已配置" : "未配置"}</Descriptions.Item>
+        </Descriptions>
       )}
     </Card>
   );
 }

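The api.ts changes below mirror the new progress fields on `ScreenshotItem`. Any client can
consume them the same way the table does. A polling sketch in Python -- note the base URL and the
`/cases/{id}/screenshots` path are assumptions, since the route prefix is not shown in this patch:

    # Sketch: poll a screenshot until extraction settles. Paths are assumed.
    import time
    import httpx

    BASE = "http://127.0.0.1:8000"  # assumed backend address

    def wait_for_screenshot(case_id: int, screenshot_id: int, timeout_s: float = 120.0) -> dict:
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            shots = httpx.get(f"{BASE}/cases/{case_id}/screenshots").json()  # assumed route
            shot = next(s for s in shots if s["id"] == screenshot_id)
            print(shot["progress_percent"], shot["progress_step"], shot["progress_detail"])
            if shot["status"] in ("extracted", "failed"):
                return shot
            time.sleep(1.5)  # matches the frontend's 1.5 s polling interval
        raise TimeoutError("extraction did not settle in time")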
diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index 66e9771..448e1c5 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -56,6 +56,13 @@ export interface ScreenshotItem { filename: string; file_path: string; status: string; + progress_step: string | null; + progress_percent: number; + progress_detail: string | null; + started_at: string | null; + finished_at: string | null; + duration_ms: number | null; + error_message: string | null; created_at: string; } @@ -72,16 +79,24 @@ export interface FlowGraph { edges: Array<{ source: string; target: string; amount: number; count?: number }>; } +export type ProviderKey = "openai" | "anthropic" | "deepseek" | "custom_openai"; + export interface RuntimeSettings { - llm_provider: "openai" | "anthropic" | "deepseek" | "custom_openai"; - providers: Array<"openai" | "anthropic" | "deepseek" | "custom_openai">; - models: Record; + ocr_provider: ProviderKey; + ocr_model: string; + inference_provider: ProviderKey; + inference_model: string; + providers: ProviderKey[]; + provider_defaults: Record; base_urls: Record; has_keys: Record; } export interface RuntimeSettingsUpdate { - llm_provider?: "openai" | "anthropic" | "deepseek" | "custom_openai"; + ocr_provider?: ProviderKey; + ocr_model?: string; + inference_provider?: ProviderKey; + inference_model?: string; openai_api_key?: string; anthropic_api_key?: string; deepseek_api_key?: string;
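Finally, the fallback behavior introduced in config.py is worth spelling out: a per-role model is
used only when explicitly set; otherwise the provider default applies. A short sketch against the
functions added above:

    # Sketch: role-based model resolution from app/config.py above.
    from app.config import get_ocr_model, get_inference_model, update_runtime_settings

    # Switch OCR to Anthropic without naming a model: the provider default wins.
    update_runtime_settings({"ocr_provider": "anthropic", "ocr_model": None})
    assert get_ocr_model() == "claude-3-5-sonnet-20241022"

    # An explicit per-role model takes precedence over the provider default.
    update_runtime_settings({"inference_provider": "deepseek", "inference_model": "deepseek-reasoner"})
    assert get_inference_model() == "deepseek-reasoner"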