190 lines
7.4 KiB
Python
190 lines
7.4 KiB
Python
"""Screenshot upload and extraction API."""
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import get_settings
|
|
from app.models.database import get_db
|
|
from app.models import Case, Screenshot, Transaction
|
|
from app.schemas import ScreenshotResponse, ScreenshotListResponse, TransactionListResponse
|
|
from app.services.extractor import extract_and_save
|
|
|
|
# Router on which this module's screenshot endpoints are registered; the path
# prefix it is mounted under is configured elsewhere (not visible here).
router = APIRouter()
|
|
|
|
|
|
def _allowed(filename: str) -> bool:
    """Return whether *filename* has an extension from the configured allow-list.

    The extension is taken from the final path component, stripped of its
    leading dot(s) and lower-cased before the lookup, so ``"a.PNG"`` is checked
    as ``"png"``. A name without an extension is checked as ``""``.
    """
    suffix = Path(filename).suffix
    normalized = suffix.lstrip(".").lower()
    permitted = get_settings().allowed_extensions
    return normalized in permitted
|
|
|
|
|
|
@router.get("/{case_id}/screenshots", response_model=ScreenshotListResponse)
async def list_screenshots(case_id: int, db: AsyncSession = Depends(get_db)):
    """List all screenshots belonging to a case, ordered by creation time.

    Raises:
        HTTPException: 404 when no case with ``case_id`` exists.
    """
    case_result = await db.execute(select(Case).where(Case.id == case_id))
    existing_case = case_result.scalar_one_or_none()
    if not existing_case:
        raise HTTPException(status_code=404, detail="Case not found")

    query = (
        select(Screenshot)
        .where(Screenshot.case_id == case_id)
        .order_by(Screenshot.created_at)
    )
    rows = (await db.execute(query)).scalars().all()
    items = [ScreenshotResponse.model_validate(row) for row in rows]
    return ScreenshotListResponse(items=items)
|
|
|
|
|
|
@router.post("/{case_id}/screenshots", response_model=ScreenshotListResponse)
async def upload_screenshots(
    case_id: int,
    files: list[UploadFile] = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Store uploaded screenshots for a case and register each as ``pending``.

    Each upload whose filename passes ``_allowed`` is written to
    ``<upload_dir>/<case_id>/<12 random hex chars><original suffix>``. Uploads
    with no filename or a disallowed extension are skipped without error. One
    ``Screenshot`` row is created per stored file; its ``file_path`` is relative
    to the resolved upload directory.

    If anything fails before the rows are committed, the files written by this
    request are deleted so no orphaned files without a DB row are left behind.

    Returns:
        ScreenshotListResponse with the newly created screenshots only.

    Raises:
        HTTPException: 404 when no case with ``case_id`` exists.
    """
    r = await db.execute(select(Case).where(Case.id == case_id))
    case = r.scalar_one_or_none()
    if not case:
        raise HTTPException(status_code=404, detail="Case not found")

    settings = get_settings()
    upload_dir = settings.upload_dir.resolve()
    case_dir = upload_dir / str(case_id)
    case_dir.mkdir(parents=True, exist_ok=True)

    created: list[Screenshot] = []
    written: list[Path] = []
    try:
        for f in files:
            if not f.filename or not _allowed(f.filename):
                continue
            # Store under a random stem so the user-supplied name never becomes
            # a filesystem path; only the original suffix is kept.
            stem = uuid.uuid4().hex[:12]
            suffix = Path(f.filename).suffix
            path = case_dir / f"{stem}{suffix}"
            content = await f.read()
            # Track the path before writing so a partially written file is
            # cleaned up as well.
            written.append(path)
            path.write_bytes(content)
            screenshot = Screenshot(
                case_id=case_id,
                filename=f.filename,
                file_path=str(path.relative_to(upload_dir)),
                status="pending",
            )
            db.add(screenshot)
            created.append(screenshot)
        await db.commit()
    except BaseException:
        # The rows were not committed (or the commit failed), so remove the
        # files this request wrote. Cleanup is best effort: a failure to delete
        # must not replace the original exception. BaseException also covers
        # task cancellation.
        for written_path in written:
            try:
                written_path.unlink()
            except OSError:
                pass
        raise

    for s in created:
        await db.refresh(s)
    return ScreenshotListResponse(items=[ScreenshotResponse.model_validate(s) for s in created])
|
|
|
|
|
|
@router.post("/{case_id}/screenshots/{screenshot_id}/extract", response_model=TransactionListResponse)
async def extract_transactions(
    case_id: int,
    screenshot_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Run transaction extraction on one screenshot and record its progress.

    The screenshot's timing, error and progress fields are reset at the start of
    every run. Progress updates from ``extract_and_save`` are written to the row
    and committed as they arrive. On completion the row is marked ``extracted``
    (or ``failed``) with ``finished_at`` and ``duration_ms`` for this run.

    Returns:
        TransactionListResponse built from whatever ``extract_and_save`` returns.

    Raises:
        HTTPException: 404 when the screenshot does not belong to the case or its
            file is missing on disk; 502 with a categorized message (see
            ``_classify_error``) when extraction fails.
    """
    r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id, Screenshot.case_id == case_id))
    screenshot = r.scalar_one_or_none()
    if not screenshot:
        raise HTTPException(status_code=404, detail="Screenshot not found")

    settings = get_settings()
    full_path = settings.upload_dir.resolve() / screenshot.file_path
    if not full_path.exists():
        raise HTTPException(status_code=404, detail="File not found on disk")
    image_bytes = full_path.read_bytes()

    started_at = datetime.utcnow()
    # Reset timing on every new run so the recorded duration covers this
    # analysis only, not an accumulation of earlier attempts.
    screenshot.started_at = started_at
    screenshot.finished_at = None
    screenshot.duration_ms = None
    screenshot.error_message = None
    screenshot.progress_step = "starting"
    screenshot.progress_percent = 0
    screenshot.progress_detail = "准备开始识别"
    await db.commit()

    async def update_progress(step: str, percent: int, detail: str) -> None:
        # Progress callback passed to extract_and_save. Each call marks the row
        # "processing" and commits immediately (presumably so other requests
        # polling this row can see live progress — TODO confirm).
        screenshot.status = "processing"
        screenshot.progress_step = step
        screenshot.progress_percent = percent
        screenshot.progress_detail = detail
        await db.commit()

    try:
        await update_progress("file_loaded", 10, "截图读取完成")
        transactions = await extract_and_save(
            case_id,
            screenshot_id,
            image_bytes,
            progress_hook=update_progress,
        )
    except Exception as e:
        error_detail = _classify_error(e)
        # The failure may have come from this session itself (e.g. a commit in
        # update_progress). Roll back first so the bookkeeping below does not
        # raise a new error that hides the original one.
        await db.rollback()
        await _finish_screenshot(
            db,
            screenshot_id,
            started_at,
            status="failed",
            step="failed",
            detail="识别失败",
            error=error_detail,
        )
        raise HTTPException(status_code=502, detail=error_detail) from e

    await _finish_screenshot(
        db,
        screenshot_id,
        started_at,
        status="extracted",
        step="completed",
        detail="识别完成",
        error=None,
    )
    return TransactionListResponse(items=transactions)


async def _finish_screenshot(
    db: AsyncSession,
    screenshot_id: int,
    started_at: datetime,
    *,
    status: str,
    step: str,
    detail: str,
    error,
) -> None:
    """Mark an extraction run on a screenshot row as finished and commit.

    Re-loads the row by id; if it no longer exists, nothing is written.
    ``progress_percent`` is set to 100, ``finished_at`` to the current naive UTC
    time, and ``error_message`` to *error* (a str, or None on success).

    The duration is measured from ``started_at`` as captured by the caller for
    this run. The value re-read from the database could differ, for example
    through stored precision or a concurrent run resetting it.
    """
    r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
    sc = r.scalar_one_or_none()
    if sc is None:
        return
    finished_at = datetime.utcnow()
    sc.status = status
    sc.progress_step = step
    sc.progress_percent = 100
    sc.progress_detail = detail
    sc.finished_at = finished_at
    sc.duration_ms = int((finished_at - started_at).total_seconds() * 1000)
    sc.error_message = error
    await db.commit()
|
|
|
|
|
|
def _classify_error(e: Exception) -> str:
|
|
"""Produce a human-readable, categorized error message."""
|
|
name = type(e).__name__
|
|
msg = str(e)
|
|
|
|
if isinstance(e, ValueError):
|
|
return f"配置错误: {msg}"
|
|
|
|
# OpenAI SDK errors
|
|
try:
|
|
from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError
|
|
if isinstance(e, AuthenticationError):
|
|
return f"API Key 无效或已过期 ({name}): {msg}"
|
|
if isinstance(e, RateLimitError):
|
|
return f"API 调用频率超限,请稍后重试 ({name}): {msg}"
|
|
if isinstance(e, APITimeoutError):
|
|
return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}"
|
|
if isinstance(e, APIConnectionError):
|
|
return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}"
|
|
if isinstance(e, BadRequestError):
|
|
return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}"
|
|
if isinstance(e, APIStatusError):
|
|
return f"模型服务返回错误 (HTTP {e.status_code}): {msg}"
|
|
except ImportError:
|
|
pass
|
|
|
|
# Anthropic SDK errors
|
|
try:
|
|
from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError
|
|
from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq
|
|
if isinstance(e, AnthAuthError):
|
|
return f"Anthropic API Key 无效或已过期: {msg}"
|
|
if isinstance(e, AnthRateError):
|
|
return f"Anthropic API 调用频率超限: {msg}"
|
|
if isinstance(e, AnthConnError):
|
|
return f"无法连接到 Anthropic 服务: {msg}"
|
|
if isinstance(e, AnthBadReq):
|
|
return f"Anthropic 请求被拒绝: {msg}"
|
|
except ImportError:
|
|
pass
|
|
|
|
# Connection / network
|
|
if "connect" in msg.lower() or "timeout" in msg.lower():
|
|
return f"网络连接失败或超时: {msg}"
|
|
|
|
return f"识别失败 ({name}): {msg}"
|