Files
fund-tracer/backend/app/api/screenshots.py
2026-03-10 14:25:21 +08:00

190 lines
7.4 KiB
Python

"""Screenshot upload and extraction API."""
import uuid
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import get_settings
from app.models.database import get_db
from app.models import Case, Screenshot, Transaction
from app.schemas import ScreenshotResponse, ScreenshotListResponse, TransactionListResponse
from app.services.extractor import extract_and_save
# Router that collects the case-scoped screenshot endpoints defined below.
router = APIRouter()
def _allowed(filename: str) -> bool:
    """Tell whether *filename* carries an extension the settings permit.

    The final suffix of the name is taken, its leading dot removed and the
    remainder lower-cased before it is looked up in
    ``get_settings().allowed_extensions``.
    """
    suffix = Path(filename).suffix
    normalized = (suffix or "").lstrip(".").lower()
    permitted = get_settings().allowed_extensions
    return normalized in permitted
@router.get("/{case_id}/screenshots", response_model=ScreenshotListResponse)
async def list_screenshots(case_id: int, db: AsyncSession = Depends(get_db)):
    """Return every screenshot of a case, ordered by ``created_at``.

    Raises:
        HTTPException: 404 when no case with ``case_id`` exists.
    """
    case_query = select(Case).where(Case.id == case_id)
    case_row = (await db.execute(case_query)).scalar_one_or_none()
    if not case_row:
        raise HTTPException(status_code=404, detail="Case not found")

    screenshot_query = (
        select(Screenshot)
        .where(Screenshot.case_id == case_id)
        .order_by(Screenshot.created_at)
    )
    rows = (await db.execute(screenshot_query)).scalars().all()

    items = []
    for row in rows:
        items.append(ScreenshotResponse.model_validate(row))
    return ScreenshotListResponse(items=items)
@router.post("/{case_id}/screenshots", response_model=ScreenshotListResponse)
async def upload_screenshots(
    case_id: int,
    files: list[UploadFile] = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Store uploaded screenshots for a case and register them as pending.

    Each accepted upload is written to ``<upload_dir>/<case_id>/`` under a
    random 12-hex-character stem that keeps the original suffix, and a
    ``Screenshot`` row with ``status="pending"`` is created whose
    ``file_path`` is relative to the upload directory. Uploads without a
    filename, or whose extension ``_allowed`` rejects, are skipped silently.

    If writing a file or committing the rows fails, the files written during
    this request are removed again (best effort) and the session is rolled
    back, so no files are left on disk without a matching row.

    Args:
        case_id: Id of the case the screenshots belong to.
        files: Uploaded files from the multipart request.
        db: Async database session.

    Returns:
        ScreenshotListResponse containing only the screenshots created by
        this request.

    Raises:
        HTTPException: 404 when no case with ``case_id`` exists.
    """
    r = await db.execute(select(Case).where(Case.id == case_id))
    if not r.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="Case not found")

    settings = get_settings()
    upload_dir = settings.upload_dir.resolve()
    case_dir = upload_dir / str(case_id)
    case_dir.mkdir(parents=True, exist_ok=True)

    created: list[Screenshot] = []
    written: list[Path] = []
    try:
        for f in files:
            if not f.filename or not _allowed(f.filename):
                continue
            # Never reuse the client-supplied name on disk; only its suffix.
            stem = uuid.uuid4().hex[:12]
            suffix = Path(f.filename).suffix
            path = case_dir / f"{stem}{suffix}"
            content = await f.read()
            # Track the path before writing so a partially written file is
            # cleaned up as well.
            written.append(path)
            path.write_bytes(content)
            screenshot = Screenshot(
                case_id=case_id,
                filename=f.filename,
                file_path=str(path.relative_to(upload_dir)),
                status="pending",
            )
            db.add(screenshot)
            created.append(screenshot)
        await db.commit()
    except Exception:
        for orphan in written:
            try:
                orphan.unlink(missing_ok=True)
            except OSError:
                # Cleanup is best effort; the original error is what matters.
                pass
        await db.rollback()
        raise

    for s in created:
        await db.refresh(s)
    return ScreenshotListResponse(items=[ScreenshotResponse.model_validate(s) for s in created])
async def _finish_extraction(
    db: AsyncSession,
    screenshot_id: int,
    *,
    status: str,
    step: str,
    detail: str,
    error_message=None,
) -> None:
    """Write the terminal state of an extraction run onto the screenshot row.

    The row is re-queried by id, marked 100% complete with the given
    status/step/detail, stamped with ``finished_at`` and, when ``started_at``
    is set, with the elapsed ``duration_ms``. Nothing is written when the row
    no longer exists.
    """
    r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id))
    sc = r.scalar_one_or_none()
    if not sc:
        return
    sc.status = status
    sc.progress_step = step
    sc.progress_percent = 100
    sc.progress_detail = detail
    # NOTE(review): utcnow() produces naive UTC timestamps and is deprecated
    # since Python 3.12; kept because it must match how started_at is stored.
    sc.finished_at = datetime.utcnow()
    if sc.started_at:
        sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000)
    sc.error_message = error_message
    await db.commit()


@router.post("/{case_id}/screenshots/{screenshot_id}/extract", response_model=TransactionListResponse)
async def extract_transactions(
    case_id: int,
    screenshot_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Run extraction on one stored screenshot and return its transactions.

    The screenshot's timing, error and progress fields are reset, the image
    bytes are read from disk and handed to ``extract_and_save`` together with
    a progress hook that persists each reported step. The final outcome
    ("extracted" or "failed") is then recorded on the screenshot row.

    Args:
        case_id: Id of the case the screenshot must belong to.
        screenshot_id: Id of the screenshot to analyse.
        db: Async database session.

    Returns:
        TransactionListResponse wrapping what ``extract_and_save`` returned.

    Raises:
        HTTPException: 404 when the screenshot row or its file is missing;
            502 with a classified message when extraction fails.
    """
    r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id, Screenshot.case_id == case_id))
    screenshot = r.scalar_one_or_none()
    if not screenshot:
        raise HTTPException(status_code=404, detail="Screenshot not found")
    settings = get_settings()
    full_path = settings.upload_dir.resolve() / screenshot.file_path
    if not full_path.exists():
        raise HTTPException(status_code=404, detail="File not found on disk")
    image_bytes = full_path.read_bytes()

    # Reset the timing fields at the start of every run so the reported
    # duration covers this analysis only, not an accumulated history.
    screenshot.started_at = datetime.utcnow()
    screenshot.finished_at = None
    screenshot.duration_ms = None
    screenshot.error_message = None
    screenshot.progress_step = "starting"
    screenshot.progress_percent = 0
    screenshot.progress_detail = "准备开始识别"
    await db.commit()

    async def update_progress(step: str, percent: int, detail: str):
        """Persist one progress report; passed to the extractor as its hook."""
        screenshot.status = "processing"
        screenshot.progress_step = step
        screenshot.progress_percent = percent
        screenshot.progress_detail = detail
        await db.commit()

    try:
        await update_progress("file_loaded", 10, "截图读取完成")
        transactions = await extract_and_save(
            case_id,
            screenshot_id,
            image_bytes,
            progress_hook=update_progress,
        )
    except Exception as e:
        error_detail = _classify_error(e)
        # If the failure came from a flush/commit on this session, the session
        # refuses further work until it is rolled back; do that first so the
        # failure can actually be recorded.
        await db.rollback()
        await _finish_extraction(
            db,
            screenshot_id,
            status="failed",
            step="failed",
            detail="识别失败",
            error_message=error_detail,
        )
        raise HTTPException(status_code=502, detail=error_detail) from e

    await _finish_extraction(
        db,
        screenshot_id,
        status="extracted",
        step="completed",
        detail="识别完成",
        error_message=None,
    )
    return TransactionListResponse(items=transactions)
def _classify_error(e: Exception) -> str:
"""Produce a human-readable, categorized error message."""
name = type(e).__name__
msg = str(e)
if isinstance(e, ValueError):
return f"配置错误: {msg}"
# OpenAI SDK errors
try:
from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError
if isinstance(e, AuthenticationError):
return f"API Key 无效或已过期 ({name}): {msg}"
if isinstance(e, RateLimitError):
return f"API 调用频率超限,请稍后重试 ({name}): {msg}"
if isinstance(e, APITimeoutError):
return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}"
if isinstance(e, APIConnectionError):
return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}"
if isinstance(e, BadRequestError):
return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}"
if isinstance(e, APIStatusError):
return f"模型服务返回错误 (HTTP {e.status_code}): {msg}"
except ImportError:
pass
# Anthropic SDK errors
try:
from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError
from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq
if isinstance(e, AnthAuthError):
return f"Anthropic API Key 无效或已过期: {msg}"
if isinstance(e, AnthRateError):
return f"Anthropic API 调用频率超限: {msg}"
if isinstance(e, AnthConnError):
return f"无法连接到 Anthropic 服务: {msg}"
if isinstance(e, AnthBadReq):
return f"Anthropic 请求被拒绝: {msg}"
except ImportError:
pass
# Connection / network
if "connect" in msg.lower() or "timeout" in msg.lower():
return f"网络连接失败或超时: {msg}"
return f"识别失败 ({name}): {msg}"