"""Screenshot upload and extraction API.""" import uuid from datetime import datetime from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, UploadFile, File from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings from app.models.database import get_db from app.models import Case, Screenshot, Transaction from app.schemas import ScreenshotResponse, ScreenshotListResponse, TransactionListResponse from app.services.extractor import extract_and_save router = APIRouter() def _allowed(filename: str) -> bool: ext = (Path(filename).suffix or "").lstrip(".").lower() return ext in get_settings().allowed_extensions @router.get("/{case_id}/screenshots", response_model=ScreenshotListResponse) async def list_screenshots(case_id: int, db: AsyncSession = Depends(get_db)): r = await db.execute(select(Case).where(Case.id == case_id)) if not r.scalar_one_or_none(): raise HTTPException(status_code=404, detail="Case not found") r = await db.execute(select(Screenshot).where(Screenshot.case_id == case_id).order_by(Screenshot.created_at)) screenshots = r.scalars().all() return ScreenshotListResponse(items=[ScreenshotResponse.model_validate(s) for s in screenshots]) @router.post("/{case_id}/screenshots", response_model=ScreenshotListResponse) async def upload_screenshots( case_id: int, files: list[UploadFile] = File(...), db: AsyncSession = Depends(get_db), ): r = await db.execute(select(Case).where(Case.id == case_id)) case = r.scalar_one_or_none() if not case: raise HTTPException(status_code=404, detail="Case not found") settings = get_settings() upload_dir = settings.upload_dir.resolve() case_dir = upload_dir / str(case_id) case_dir.mkdir(parents=True, exist_ok=True) created: list[Screenshot] = [] for f in files: if not f.filename or not _allowed(f.filename): continue stem = uuid.uuid4().hex[:12] suffix = Path(f.filename).suffix path = case_dir / f"{stem}{suffix}" content = await f.read() path.write_bytes(content) rel_path = str(path.relative_to(upload_dir)) screenshot = Screenshot( case_id=case_id, filename=f.filename, file_path=rel_path, status="pending", ) db.add(screenshot) created.append(screenshot) await db.commit() for s in created: await db.refresh(s) return ScreenshotListResponse(items=[ScreenshotResponse.model_validate(s) for s in created]) @router.post("/{case_id}/screenshots/{screenshot_id}/extract", response_model=TransactionListResponse) async def extract_transactions( case_id: int, screenshot_id: int, db: AsyncSession = Depends(get_db), ): r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id, Screenshot.case_id == case_id)) screenshot = r.scalar_one_or_none() if not screenshot: raise HTTPException(status_code=404, detail="Screenshot not found") settings = get_settings() full_path = settings.upload_dir.resolve() / screenshot.file_path if not full_path.exists(): raise HTTPException(status_code=404, detail="File not found on disk") image_bytes = full_path.read_bytes() started_at = datetime.utcnow() # 每次开始新一轮识别都重置计时,确保耗时是“本次分析”而不是历史累计 screenshot.started_at = started_at screenshot.finished_at = None screenshot.duration_ms = None screenshot.error_message = None screenshot.progress_step = "starting" screenshot.progress_percent = 0 screenshot.progress_detail = "准备开始识别" await db.commit() async def update_progress(step: str, percent: int, detail: str): screenshot.status = "processing" screenshot.progress_step = step screenshot.progress_percent = percent screenshot.progress_detail = detail await db.commit() try: await update_progress("file_loaded", 10, "截图读取完成") transactions = await extract_and_save( case_id, screenshot_id, image_bytes, progress_hook=update_progress, ) except Exception as e: error_detail = _classify_error(e) r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id)) sc = r.scalar_one_or_none() if sc: sc.status = "failed" sc.progress_step = "failed" sc.progress_percent = 100 sc.progress_detail = "识别失败" sc.finished_at = datetime.utcnow() if sc.started_at: sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000) sc.error_message = error_detail await db.commit() raise HTTPException(status_code=502, detail=error_detail) r = await db.execute(select(Screenshot).where(Screenshot.id == screenshot_id)) sc = r.scalar_one_or_none() if sc: sc.status = "extracted" sc.progress_step = "completed" sc.progress_percent = 100 sc.progress_detail = "识别完成" sc.finished_at = datetime.utcnow() if sc.started_at: sc.duration_ms = int((sc.finished_at - sc.started_at).total_seconds() * 1000) sc.error_message = None await db.commit() return TransactionListResponse(items=transactions) def _classify_error(e: Exception) -> str: """Produce a human-readable, categorized error message.""" name = type(e).__name__ msg = str(e) if isinstance(e, ValueError): return f"配置错误: {msg}" # OpenAI SDK errors try: from openai import AuthenticationError, RateLimitError, APIConnectionError, BadRequestError, APIStatusError, APITimeoutError if isinstance(e, AuthenticationError): return f"API Key 无效或已过期 ({name}): {msg}" if isinstance(e, RateLimitError): return f"API 调用频率超限,请稍后重试 ({name}): {msg}" if isinstance(e, APITimeoutError): return f"模型服务响应超时,请检查 BaseURL/模型可用性或稍后重试 ({name}): {msg}" if isinstance(e, APIConnectionError): return f"无法连接到模型服务,请检查网络或 BaseURL ({name}): {msg}" if isinstance(e, BadRequestError): return f"请求被模型服务拒绝(可能模型名错误或不支持图片) ({name}): {msg}" if isinstance(e, APIStatusError): return f"模型服务返回错误 (HTTP {e.status_code}): {msg}" except ImportError: pass # Anthropic SDK errors try: from anthropic import AuthenticationError as AnthAuthError, RateLimitError as AnthRateError from anthropic import APIConnectionError as AnthConnError, BadRequestError as AnthBadReq if isinstance(e, AnthAuthError): return f"Anthropic API Key 无效或已过期: {msg}" if isinstance(e, AnthRateError): return f"Anthropic API 调用频率超限: {msg}" if isinstance(e, AnthConnError): return f"无法连接到 Anthropic 服务: {msg}" if isinstance(e, AnthBadReq): return f"Anthropic 请求被拒绝: {msg}" except ImportError: pass # Connection / network if "connect" in msg.lower() or "timeout" in msg.lower(): return f"网络连接失败或超时: {msg}" return f"识别失败 ({name}): {msg}"