"""Question-import API routes.

Provides single-file Excel/AI imports with preview, an in-memory batch-import
task registry (lost on restart), and a persistent job-queue API backed by
``import_job_repo``.
"""

import base64
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Literal
from uuid import uuid4

from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from sqlalchemy.orm import Session

from backend.auth import get_current_user
from backend.database import SessionLocal, get_db
from backend.models import ImportHistory, Question
from backend.repositories import import_job_repo
from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut
from backend.services.excel_service import create_template_bytes, parse_excel_file
from backend.services.file_utils import save_upload, save_upload_for_job
from backend.services.parser import OpenAICompatibleParserService, extract_metadata

# Every route in this router requires an authenticated user.
router = APIRouter(dependencies=[Depends(get_current_user)])

# In-memory registry of batch-import tasks, keyed by task_id. State here does
# not survive a process restart; the persistent alternative is the /jobs API
# below. All reads/writes must hold BATCH_TASKS_LOCK.
BATCH_TASKS: dict[str, dict] = {}
BATCH_TASKS_LOCK = Lock()

BatchMethod = Literal["excel", "ai"]

# Accepted Excel extensions (previously duplicated inline in two places).
_EXCEL_SUFFIXES = frozenset({".xlsx", ".xlsm", ".xltx", ".xltm"})


def _utc_now_iso() -> str:
    """Current UTC time as an ISO-8601 string.

    Uses timezone-aware ``datetime.now(timezone.utc)`` because
    ``datetime.utcnow()`` is deprecated (Python 3.12+).
    NOTE(review): timestamps now carry a "+00:00" offset that the old naive
    ``utcnow().isoformat()`` lacked — confirm API consumers tolerate this.
    """
    return datetime.now(timezone.utc).isoformat()


def _set_task(task_id: str, **fields) -> None:
    """Update fields of a batch task under the lock; silently no-op if unknown."""
    with BATCH_TASKS_LOCK:
        task = BATCH_TASKS.get(task_id)
        if task:
            task.update(fields)


def _question_row(question: dict, metadata: dict) -> dict:
    """Map one AI-parsed question (Chinese keys) plus filename metadata onto
    ``Question`` constructor kwargs.

    Extracted because this 15-field mapping was duplicated verbatim in
    ``_build_ai_rows`` and ``parse_by_ai``.
    """
    return {
        "chapter": metadata["chapter"],
        "primary_knowledge": "",
        "secondary_knowledge": metadata["secondary_knowledge"],
        "question_type": metadata["question_type"],
        "difficulty": metadata["difficulty"],
        "stem": question.get("题干", ""),
        "option_a": question.get("选项A", ""),
        "option_b": question.get("选项B", ""),
        "option_c": question.get("选项C", ""),
        "option_d": question.get("选项D", ""),
        "answer": question.get("正确答案", ""),
        "explanation": question.get("解析", ""),
        "notes": question.get("备注", ""),
        "source_file": metadata["source_file"],
    }


def _build_ai_rows(path: Path) -> list[dict]:
    """Parse *path* with the AI parser; return one row dict per question."""
    parser = OpenAICompatibleParserService()
    metadata = extract_metadata(path.name)
    questions = parser.parse_file(str(path))
    return [_question_row(q, metadata) for q in questions]


def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None:
    """Background worker for /batch/start: import each saved file in turn.

    Runs outside the request cycle, so it opens its own session. Each file is
    committed (or rolled back) independently, and an ImportHistory row is
    written whether the file succeeds or fails. Per-file results and counters
    are recorded on the in-memory task under BATCH_TASKS_LOCK.
    """
    db = SessionLocal()
    try:
        for index, item in enumerate(files, start=1):
            path = Path(item["path"])
            filename = item["filename"]
            _set_task(task_id, current_file=filename)
            try:
                if method == "excel":
                    if path.suffix.lower() not in _EXCEL_SUFFIXES:
                        raise ValueError("仅支持 Excel 文件")
                    rows = parse_excel_file(path)
                else:
                    rows = _build_ai_rows(path)
                questions = [Question(**row) for row in rows]
                if questions:
                    db.add_all(questions)
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=len(questions),
                        status="success",
                    )
                )
                db.commit()
                with BATCH_TASKS_LOCK:
                    # Guarded .get (was BATCH_TASKS[task_id]): consistent with
                    # _set_task and immune to the task having been dropped.
                    task = BATCH_TASKS.get(task_id)
                    if task:
                        task["success_count"] += 1
                        task["total_questions"] += len(questions)
                        task["results"].append(
                            {
                                "filename": filename,
                                "status": "success",
                                "question_count": len(questions),
                            }
                        )
            except Exception as exc:
                db.rollback()
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=0,
                        status="failed",
                    )
                )
                db.commit()
                with BATCH_TASKS_LOCK:
                    task = BATCH_TASKS.get(task_id)
                    if task:
                        task["failed_count"] += 1
                        task["results"].append(
                            {
                                "filename": filename,
                                "status": "failed",
                                "error": str(exc),
                                "question_count": 0,
                            }
                        )
            finally:
                _set_task(task_id, processed=index)
        _set_task(
            task_id,
            status="completed",
            current_file="",
            ended_at=_utc_now_iso(),
        )
    except Exception as exc:
        # Unexpected failure outside the per-file handler (e.g. a commit in the
        # except-branch raised): mark the whole task failed.
        _set_task(
            task_id,
            status="failed",
            error=str(exc),
            ended_at=_utc_now_iso(),
        )
    finally:
        db.close()


@router.post("/ai/parse")
def parse_by_ai(file: UploadFile = File(...)) -> dict:
    """Parse an uploaded file with the AI parser and return a preview.

    Nothing is persisted here; the client confirms via POST /ai/confirm.
    """
    path = save_upload(file)
    parser = OpenAICompatibleParserService()
    metadata = extract_metadata(file.filename or path.name)
    try:
        questions = parser.parse_file(str(path))
    except ValueError as exc:
        # Parser failures surface as 502: the upstream AI service misbehaved.
        raise HTTPException(status_code=502, detail=str(exc)) from exc
    preview = [_question_row(q, metadata) for q in questions]
    return {"filename": file.filename, "preview": preview}


@router.post("/ai/confirm", response_model=list[QuestionOut])
def confirm_ai_import(payload: list[dict], db: Session = Depends(get_db)) -> list[QuestionOut]:
    """Persist previously previewed AI rows and log one ImportHistory entry."""
    if not payload:
        raise HTTPException(status_code=400, detail="没有可导入数据")
    items = [Question(**item) for item in payload]
    db.add_all(items)
    db.add(
        ImportHistory(
            # payload is guaranteed non-empty here (checked above), so the old
            # "if items else ''" conditional was dead code.
            filename=items[0].source_file,
            method="ai",
            question_count=len(items),
            status="success",
        )
    )
    db.commit()
    for item in items:
        db.refresh(item)
    return [QuestionOut.model_validate(item) for item in items]


@router.post("/excel", response_model=list[QuestionOut])
def import_excel(file: UploadFile = File(...), db: Session = Depends(get_db)) -> list[QuestionOut]:
    """Import questions from an uploaded Excel workbook and log the import."""
    path = Path(save_upload(file))
    if path.suffix.lower() not in _EXCEL_SUFFIXES:
        raise HTTPException(status_code=400, detail="仅支持 Excel 文件")
    rows = parse_excel_file(path)
    items = [Question(**row) for row in rows]
    db.add_all(items)
    db.add(
        ImportHistory(
            filename=file.filename or "",
            method="excel",
            question_count=len(items),
            status="success",
        )
    )
    db.commit()
    for item in items:
        db.refresh(item)
    return [QuestionOut.model_validate(item) for item in items]


@router.get("/template")
def download_template() -> dict:
    """Return the Excel import template, base64-encoded.

    BUG FIX: the previous implementation returned ``content.hex()`` under the
    key ``content_base64`` — hex is not base64, so clients base64-decoding the
    value got corrupted bytes. The value is now genuinely base64.
    """
    content = create_template_bytes()
    return {
        "filename": "question_template.xlsx",
        "content_base64": base64.b64encode(content).decode("ascii"),
    }


@router.get("/history", response_model=list[ImportHistoryOut])
def import_history(db: Session = Depends(get_db)) -> list[ImportHistoryOut]:
    """Return the 100 most recent import-history rows, newest first."""
    rows = (
        db.query(ImportHistory)
        .order_by(ImportHistory.created_at.desc())
        .limit(100)
        .all()
    )
    return [ImportHistoryOut.model_validate(r) for r in rows]


@router.post("/batch/start")
def start_batch_import(
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(...),
    method: BatchMethod = Form("ai"),
) -> dict:
    """Save the uploads, register an in-memory task, and import in the background."""
    if not files:
        raise HTTPException(status_code=400, detail="请至少上传一个文件")
    task_id = uuid4().hex
    saved_files: list[dict] = []
    for f in files:
        saved_path = save_upload(f)
        saved_files.append(
            {"path": str(saved_path), "filename": f.filename or Path(saved_path).name}
        )
    with BATCH_TASKS_LOCK:
        BATCH_TASKS[task_id] = {
            "task_id": task_id,
            "status": "running",
            "method": method,
            "total": len(saved_files),
            "processed": 0,
            "current_file": "",
            "success_count": 0,
            "failed_count": 0,
            "total_questions": 0,
            "results": [],
            "error": "",
            "started_at": _utc_now_iso(),
            "ended_at": "",
        }
    background_tasks.add_task(_run_batch_import, task_id, saved_files, method)
    return {"task_id": task_id}


@router.get("/batch/{task_id}")
def get_batch_import_progress(task_id: str) -> dict:
    """Snapshot a batch task's state plus a derived completion percentage."""
    with BATCH_TASKS_LOCK:
        task = BATCH_TASKS.get(task_id)
        if not task:
            raise HTTPException(status_code=404, detail="任务不存在")
        # Shallow-copy under the lock so the worker can't mutate mid-read.
        snapshot = dict(task)
    total = snapshot["total"] or 1  # guard against division by zero
    snapshot["progress"] = round((snapshot["processed"] / total) * 100, 2)
    return snapshot


# ----- Persistent job-queue API (survives restarts, unlike BATCH_TASKS) -----


@router.post("/jobs", response_model=ImportJobOut)
def create_import_job(
    files: list[UploadFile] = File(...),
    method: BatchMethod = Form("ai"),
    db: Session = Depends(get_db),
) -> ImportJobOut:
    """Create a queued import job with one item per uploaded file."""
    if not files:
        raise HTTPException(status_code=400, detail="请至少上传一个文件")
    job = import_job_repo.create_job_empty(db, method)
    items: list[dict] = []
    for seq, f in enumerate(files, start=1):
        path = save_upload_for_job(job.id, seq, f)
        items.append({"filename": f.filename or path.name, "stored_path": str(path)})
    import_job_repo.add_job_items(db, job.id, items)
    job = import_job_repo.get_job(db, job.id)  # reload with items attached
    return ImportJobOut.model_validate(job)


@router.get("/jobs/{job_id}", response_model=ImportJobOut)
def get_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Fetch a single import job by id."""
    job = import_job_repo.get_job(db, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")
    return ImportJobOut.model_validate(job)


@router.get("/jobs", response_model=list[ImportJobOut])
def list_import_jobs(
    status: str | None = Query(None, description="queued,running 等,逗号分隔"),
    limit: int = Query(50, le=100),
    db: Session = Depends(get_db),
) -> list[ImportJobOut]:
    """List recent jobs, optionally filtered by a comma-separated status list."""
    statuses = [s.strip() for s in status.split(",") if s.strip()] if status else None
    jobs = import_job_repo.list_jobs(db, statuses=statuses, limit=limit)
    return [ImportJobOut.model_validate(j) for j in jobs]


@router.post("/jobs/{job_id}/retry", response_model=ImportJobOut)
def retry_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Clone a job's failed items into a fresh job; the original is untouched."""
    job = import_job_repo.get_job(db, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")
    failed = import_job_repo.get_failed_items(db, job_id)
    if not failed:
        raise HTTPException(status_code=400, detail="没有可重试的失败项")
    items = [{"filename": it.filename, "stored_path": it.stored_path} for it in failed]
    new_job = import_job_repo.create_job(db, job.method, items)
    new_job = import_job_repo.get_job(db, new_job.id)
    return ImportJobOut.model_validate(new_job)


@router.post("/jobs/{job_id}/cancel", response_model=ImportJobOut)
def cancel_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Cancel a job if it is in a cancellable state (repo decides which)."""
    job = import_job_repo.cancel_job(db, job_id)
    if not job:
        raise HTTPException(status_code=400, detail="任务不存在或无法取消")
    return ImportJobOut.model_validate(job)