# problem-bank/backend/routers/imports.py

from base64 import b64encode
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Literal
from uuid import uuid4

from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from sqlalchemy.orm import Session

from backend.auth import get_current_user
from backend.database import SessionLocal, get_db
from backend.models import ImportHistory, Question
from backend.repositories import import_job_repo
from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut
from backend.services.excel_service import create_template_bytes, parse_excel_file
from backend.services.file_utils import save_upload, save_upload_for_job
from backend.services.parser import OpenAICompatibleParserService, extract_metadata

router = APIRouter(dependencies=[Depends(get_current_user)])
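# A minimal mounting sketch; the actual application factory is not part of this
# file, and the "/imports" prefix is an assumption for illustration:
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#   app.include_router(router, prefix="/imports", tags=["imports"])
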
BATCH_TASKS: dict[str, dict] = {}
BATCH_TASKS_LOCK = Lock()
BatchMethod = Literal["excel", "ai"]
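# BATCH_TASKS keeps batch progress in process memory only: entries disappear on
# restart and are not shared between workers. The Jobs API near the bottom of
# this file is the persistent alternative, backed by import_job_repo.
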
def _set_task(task_id: str, **fields) -> None:
    """Update a task entry under the lock; unknown task ids are ignored."""
    with BATCH_TASKS_LOCK:
        task = BATCH_TASKS.get(task_id)
        if not task:
            return
        task.update(fields)


def _row_from_parsed(question: dict, metadata: dict) -> dict:
    """Map one parsed question (Chinese-keyed dict) onto Question column names."""
    return {
        "chapter": metadata["chapter"],
        "primary_knowledge": "",
        "secondary_knowledge": metadata["secondary_knowledge"],
        "question_type": metadata["question_type"],
        "difficulty": metadata["difficulty"],
        "stem": question.get("题干", ""),
        "option_a": question.get("选项A", ""),
        "option_b": question.get("选项B", ""),
        "option_c": question.get("选项C", ""),
        "option_d": question.get("选项D", ""),
        "answer": question.get("正确答案", ""),
        "explanation": question.get("解析", ""),
        "notes": question.get("备注", ""),
        "source_file": metadata["source_file"],
    }


def _build_ai_rows(path: Path) -> list[dict]:
    """Parse a file with the AI parser and return rows ready for Question(**row)."""
    parser = OpenAICompatibleParserService()
    metadata = extract_metadata(path.name)
    questions = parser.parse_file(str(path))
    return [_row_from_parsed(q, metadata) for q in questions]

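# Illustrative shape of one parsed question as _row_from_parsed consumes it.
# The keys are the ones this module reads from the parser output; the values
# below are invented for illustration:
#
#   {"题干": "...", "选项A": "...", "选项B": "...", "选项C": "...",
#    "选项D": "...", "正确答案": "A", "解析": "...", "备注": ""}
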
def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None:
    """Background worker: import each saved file, recording per-file results."""
    db = SessionLocal()
    try:
        for index, item in enumerate(files, start=1):
            path = Path(item["path"])
            filename = item["filename"]
            _set_task(task_id, current_file=filename)
            try:
                if method == "excel":
                    if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
                        raise ValueError("仅支持 Excel 文件")
                    rows = parse_excel_file(path)
                else:
                    rows = _build_ai_rows(path)
                questions = [Question(**row) for row in rows]
                if questions:
                    db.add_all(questions)
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=len(questions),
                        status="success",
                    )
                )
                db.commit()
                with BATCH_TASKS_LOCK:
                    task = BATCH_TASKS[task_id]
                    task["success_count"] += 1
                    task["total_questions"] += len(questions)
                    task["results"].append(
                        {
                            "filename": filename,
                            "status": "success",
                            "question_count": len(questions),
                        }
                    )
            except Exception as exc:
                # Roll back the failed file's questions, then record the failure.
                db.rollback()
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=0,
                        status="failed",
                    )
                )
                db.commit()
                with BATCH_TASKS_LOCK:
                    task = BATCH_TASKS[task_id]
                    task["failed_count"] += 1
                    task["results"].append(
                        {
                            "filename": filename,
                            "status": "failed",
                            "error": str(exc),
                            "question_count": 0,
                        }
                    )
            finally:
                _set_task(task_id, processed=index)
        _set_task(
            task_id,
            status="completed",
            current_file="",
            ended_at=datetime.utcnow().isoformat(),
        )
    except Exception as exc:
        _set_task(
            task_id,
            status="failed",
            error=str(exc),
            ended_at=datetime.utcnow().isoformat(),
        )
    finally:
        db.close()


@router.post("/ai/parse")
def parse_by_ai(file: UploadFile = File(...)) -> dict:
path = save_upload(file)
parser = OpenAICompatibleParserService()
metadata = extract_metadata(file.filename or path.name)
try:
questions = parser.parse_file(str(path))
except ValueError as exc:
raise HTTPException(status_code=502, detail=str(exc)) from exc
preview = []
for q in questions:
preview.append(
{
"chapter": metadata["chapter"],
"primary_knowledge": "",
"secondary_knowledge": metadata["secondary_knowledge"],
"question_type": metadata["question_type"],
"difficulty": metadata["difficulty"],
"stem": q.get("题干", ""),
"option_a": q.get("选项A", ""),
"option_b": q.get("选项B", ""),
"option_c": q.get("选项C", ""),
"option_d": q.get("选项D", ""),
"answer": q.get("正确答案", ""),
"explanation": q.get("解析", ""),
"notes": q.get("备注", ""),
"source_file": metadata["source_file"],
}
)
return {"filename": file.filename, "preview": preview}
@router.post("/ai/confirm", response_model=list[QuestionOut])
def confirm_ai_import(payload: list[dict], db: Session = Depends(get_db)) -> list[QuestionOut]:
if not payload:
raise HTTPException(status_code=400, detail="没有可导入数据")
items = [Question(**item) for item in payload]
db.add_all(items)
db.add(
ImportHistory(
filename=items[0].source_file if items else "",
method="ai",
question_count=len(items),
status="success",
)
)
db.commit()
for item in items:
db.refresh(item)
return [QuestionOut.model_validate(item) for item in items]
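# A /ai/confirm payload is a JSON array whose items mirror the preview rows
# returned by /ai/parse, e.g. (values invented for illustration):
#
#   [{"chapter": "...", "primary_knowledge": "", "secondary_knowledge": "...",
#     "question_type": "...", "difficulty": "...", "stem": "...",
#     "option_a": "...", "option_b": "...", "option_c": "...", "option_d": "...",
#     "answer": "A", "explanation": "...", "notes": "", "source_file": "..."}]
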
@router.post("/excel", response_model=list[QuestionOut])
def import_excel(file: UploadFile = File(...), db: Session = Depends(get_db)) -> list[QuestionOut]:
path = save_upload(file)
if Path(path).suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
raise HTTPException(status_code=400, detail="仅支持 Excel 文件")
rows = parse_excel_file(Path(path))
items = [Question(**row) for row in rows]
db.add_all(items)
db.add(
ImportHistory(
filename=file.filename or "",
method="excel",
question_count=len(items),
status="success",
)
)
db.commit()
for item in items:
db.refresh(item)
return [QuestionOut.model_validate(item) for item in items]
@router.get("/template")
def download_template() -> dict:
content = create_template_bytes()
return {"filename": "question_template.xlsx", "content_base64": content.hex()}
@router.get("/history", response_model=list[ImportHistoryOut])
def import_history(db: Session = Depends(get_db)) -> list[ImportHistoryOut]:
rows = db.query(ImportHistory).order_by(ImportHistory.created_at.desc()).limit(100).all()
return [ImportHistoryOut.model_validate(r) for r in rows]
@router.post("/batch/start")
def start_batch_import(
background_tasks: BackgroundTasks,
files: list[UploadFile] = File(...),
method: BatchMethod = Form("ai"),
) -> dict:
if not files:
raise HTTPException(status_code=400, detail="请至少上传一个文件")
task_id = uuid4().hex
saved_files: list[dict] = []
for f in files:
saved_path = save_upload(f)
saved_files.append({"path": str(saved_path), "filename": f.filename or Path(saved_path).name})
with BATCH_TASKS_LOCK:
BATCH_TASKS[task_id] = {
"task_id": task_id,
"status": "running",
"method": method,
"total": len(saved_files),
"processed": 0,
"current_file": "",
"success_count": 0,
"failed_count": 0,
"total_questions": 0,
"results": [],
"error": "",
"started_at": datetime.utcnow().isoformat(),
"ended_at": "",
}
background_tasks.add_task(_run_batch_import, task_id, saved_files, method)
return {"task_id": task_id}
@router.get("/batch/{task_id}")
def get_batch_import_progress(task_id: str) -> dict:
with BATCH_TASKS_LOCK:
task = BATCH_TASKS.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
total = task["total"] or 1
progress = round((task["processed"] / total) * 100, 2)
return {**task, "progress": progress}
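# Example client flow for a batch import, as a sketch. It assumes the router is
# mounted under "/imports" and that `client` is something like httpx.Client or
# FastAPI's TestClient:
#
#   resp = client.post("/imports/batch/start", files=uploads, data={"method": "excel"})
#   task_id = resp.json()["task_id"]
#   state = client.get(f"/imports/batch/{task_id}").json()
#   # Poll until state["status"] is "completed" or "failed"; state["progress"]
#   # is a percentage rounded to two decimal places.
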
# ----- Persistent-queue Jobs API -----
@router.post("/jobs", response_model=ImportJobOut)
def create_import_job(
files: list[UploadFile] = File(...),
method: BatchMethod = Form("ai"),
db: Session = Depends(get_db),
) -> ImportJobOut:
if not files:
raise HTTPException(status_code=400, detail="请至少上传一个文件")
job = import_job_repo.create_job_empty(db, method)
items: list[dict] = []
for seq, f in enumerate(files, start=1):
path = save_upload_for_job(job.id, seq, f)
items.append({"filename": f.filename or path.name, "stored_path": str(path)})
import_job_repo.add_job_items(db, job.id, items)
job = import_job_repo.get_job(db, job.id)
return ImportJobOut.model_validate(job)
@router.get("/jobs/{job_id}", response_model=ImportJobOut)
def get_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
job = import_job_repo.get_job(db, job_id)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
return ImportJobOut.model_validate(job)
@router.get("/jobs", response_model=list[ImportJobOut])
def list_import_jobs(
status: str | None = Query(None, description="queued,running 等,逗号分隔"),
limit: int = Query(50, le=100),
db: Session = Depends(get_db),
) -> list[ImportJobOut]:
statuses = [s.strip() for s in status.split(",") if s.strip()] if status else None
jobs = import_job_repo.list_jobs(db, statuses=statuses, limit=limit)
return [ImportJobOut.model_validate(j) for j in jobs]
@router.post("/jobs/{job_id}/retry", response_model=ImportJobOut)
def retry_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
job = import_job_repo.get_job(db, job_id)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
failed = import_job_repo.get_failed_items(db, job_id)
if not failed:
raise HTTPException(status_code=400, detail="没有可重试的失败项")
items = [{"filename": it.filename, "stored_path": it.stored_path} for it in failed]
new_job = import_job_repo.create_job(db, job.method, items)
new_job = import_job_repo.get_job(db, new_job.id)
return ImportJobOut.model_validate(new_job)
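# Note that retrying never mutates the original job: it creates a fresh job
# containing only the failed files, so the original job's record stays intact.
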
@router.post("/jobs/{job_id}/cancel", response_model=ImportJobOut)
def cancel_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
job = import_job_repo.cancel_job(db, job_id)
if not job:
raise HTTPException(status_code=400, detail="任务不存在或无法取消")
return ImportJobOut.model_validate(job)