2026-03-06 15:52:34 +08:00
|
|
|
from datetime import datetime
|
2026-03-05 11:50:15 +08:00
|
|
|
from pathlib import Path
|
2026-03-06 15:52:34 +08:00
|
|
|
from threading import Lock
|
|
|
|
|
from typing import Literal
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
2026-03-05 11:50:15 +08:00
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from backend.auth import get_current_user
|
2026-03-06 15:52:34 +08:00
|
|
|
from backend.database import SessionLocal, get_db
|
2026-03-05 11:50:15 +08:00
|
|
|
from backend.models import ImportHistory, Question
|
2026-03-06 15:52:34 +08:00
|
|
|
from backend.repositories import import_job_repo
|
|
|
|
|
from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut
|
2026-03-05 11:50:15 +08:00
|
|
|
from backend.services.excel_service import create_template_bytes, parse_excel_file
|
2026-03-06 15:52:34 +08:00
|
|
|
from backend.services.file_utils import save_upload, save_upload_for_job
|
|
|
|
|
from backend.services.parser import OpenAICompatibleParserService, extract_metadata
|
2026-03-05 11:50:15 +08:00
|
|
|
|
|
|
|
|
# All routes in this module require an authenticated user.
router = APIRouter(dependencies=[Depends(get_current_user)])


# In-memory registry of batch-import task states, keyed by task id.
# NOTE(review): lives in process memory only — progress is lost on restart
# and is not shared across multiple workers; the persistent "jobs" API
# further down avoids this limitation.
BATCH_TASKS: dict[str, dict] = {}

# Guards every read/write of BATCH_TASKS (mutated from background tasks).
BATCH_TASKS_LOCK = Lock()

# Supported batch-import strategies: structured Excel upload or AI parsing.
BatchMethod = Literal["excel", "ai"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_task(task_id: str, **fields) -> None:
    """Merge *fields* into the registered state of task *task_id*.

    Unknown task ids are ignored silently.  The lookup and the update both
    happen under the registry lock so concurrent readers never observe a
    partially applied update.
    """
    with BATCH_TASKS_LOCK:
        entry = BATCH_TASKS.get(task_id)
        if entry:
            entry.update(fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_ai_rows(path: Path) -> list[dict]:
    """Run the AI parser on *path* and map its output to Question columns.

    Metadata (chapter, knowledge point, type, difficulty, source file) is
    derived from the file name; per-question fields come from the parser's
    Chinese-keyed dicts.  Missing keys default to empty strings.
    """
    parser = OpenAICompatibleParserService()
    meta = extract_metadata(path.name)
    parsed = parser.parse_file(str(path))
    return [
        {
            "chapter": meta["chapter"],
            "primary_knowledge": "",
            "secondary_knowledge": meta["secondary_knowledge"],
            "question_type": meta["question_type"],
            "difficulty": meta["difficulty"],
            "stem": entry.get("题干", ""),
            "option_a": entry.get("选项A", ""),
            "option_b": entry.get("选项B", ""),
            "option_c": entry.get("选项C", ""),
            "option_d": entry.get("选项D", ""),
            "answer": entry.get("正确答案", ""),
            "explanation": entry.get("解析", ""),
            "notes": entry.get("备注", ""),
            "source_file": meta["source_file"],
        }
        for entry in parsed
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None:
    """Background worker: import each saved file and record per-file results.

    Runs outside the request cycle, so it opens its own session via
    SessionLocal instead of the get_db dependency.  Each file is committed
    independently: one bad file is rolled back, recorded as failed, and the
    loop continues with the next file.  Progress is mirrored into the
    in-memory BATCH_TASKS entry for *task_id*.
    """
    db = SessionLocal()
    try:
        # start=1 so `processed` counts files, not indices.
        for index, item in enumerate(files, start=1):
            path = Path(item["path"])
            filename = item["filename"]
            _set_task(task_id, current_file=filename)
            try:
                if method == "excel":
                    # Reject non-Excel uploads before parsing.
                    if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
                        raise ValueError("仅支持 Excel 文件")
                    rows = parse_excel_file(path)
                else:
                    rows = _build_ai_rows(path)

                questions = [Question(**row) for row in rows]
                if questions:
                    db.add_all(questions)
                # A history record is written even for zero-question files.
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=len(questions),
                        status="success",
                    )
                )
                db.commit()

                # Update counters only after the commit succeeded, so the
                # reported numbers never exceed what is actually persisted.
                with BATCH_TASKS_LOCK:
                    task = BATCH_TASKS[task_id]
                    task["success_count"] += 1
                    task["total_questions"] += len(questions)
                    task["results"].append(
                        {
                            "filename": filename,
                            "status": "success",
                            "question_count": len(questions),
                        }
                    )
            except Exception as exc:
                # Per-file failure: discard partial inserts, record the
                # failure in history, and keep processing remaining files.
                db.rollback()
                db.add(
                    ImportHistory(
                        filename=filename,
                        method=method,
                        question_count=0,
                        status="failed",
                    )
                )
                db.commit()
                with BATCH_TASKS_LOCK:
                    task = BATCH_TASKS[task_id]
                    task["failed_count"] += 1
                    task["results"].append(
                        {
                            "filename": filename,
                            "status": "failed",
                            "error": str(exc),
                            "question_count": 0,
                        }
                    )
            finally:
                # Progress advances whether the file succeeded or failed.
                _set_task(task_id, processed=index)

        _set_task(
            task_id,
            status="completed",
            current_file="",
            ended_at=datetime.utcnow().isoformat(),
        )
    except Exception as exc:
        # Broad catch is deliberate: this is the top-level boundary of a
        # background task; an unhandled exception would vanish silently.
        _set_task(
            task_id,
            status="failed",
            error=str(exc),
            ended_at=datetime.utcnow().isoformat(),
        )
    finally:
        db.close()
|
|
|
|
|
|
2026-03-05 11:50:15 +08:00
|
|
|
|
|
|
|
|
@router.post("/ai/parse")
def parse_by_ai(file: UploadFile = File(...)) -> dict:
    """Parse one uploaded file with the AI parser and return a preview.

    Nothing is persisted here; the client reviews the preview rows and
    submits them to /ai/confirm.  A parser ValueError surfaces as HTTP 502
    (upstream AI service problem rather than a client error).
    """
    stored = save_upload(file)
    parser = OpenAICompatibleParserService()
    meta = extract_metadata(file.filename or stored.name)
    try:
        parsed = parser.parse_file(str(stored))
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    preview = [
        {
            "chapter": meta["chapter"],
            "primary_knowledge": "",
            "secondary_knowledge": meta["secondary_knowledge"],
            "question_type": meta["question_type"],
            "difficulty": meta["difficulty"],
            "stem": entry.get("题干", ""),
            "option_a": entry.get("选项A", ""),
            "option_b": entry.get("选项B", ""),
            "option_c": entry.get("选项C", ""),
            "option_d": entry.get("选项D", ""),
            "answer": entry.get("正确答案", ""),
            "explanation": entry.get("解析", ""),
            "notes": entry.get("备注", ""),
            "source_file": meta["source_file"],
        }
        for entry in parsed
    ]
    return {"filename": file.filename, "preview": preview}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/ai/confirm", response_model=list[QuestionOut])
def confirm_ai_import(payload: list[dict], db: Session = Depends(get_db)) -> list[QuestionOut]:
    """Persist AI-parsed rows that the user confirmed in the preview step.

    Rejects an empty payload with 400.  One ImportHistory record is written
    for the whole batch, using the first row's source file as its filename.
    """
    if not payload:
        raise HTTPException(status_code=400, detail="没有可导入数据")

    questions = [Question(**row) for row in payload]
    db.add_all(questions)
    history = ImportHistory(
        filename=questions[0].source_file if questions else "",
        method="ai",
        question_count=len(questions),
        status="success",
    )
    db.add(history)
    db.commit()

    # Refresh so DB-generated fields (ids, timestamps) appear in the response.
    for question in questions:
        db.refresh(question)
    return [QuestionOut.model_validate(question) for question in questions]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/excel", response_model=list[QuestionOut])
def import_excel(file: UploadFile = File(...), db: Session = Depends(get_db)) -> list[QuestionOut]:
    """Import questions from an uploaded Excel workbook in one transaction.

    Non-Excel extensions are rejected with 400 (checked on the stored path,
    after the upload has been saved).  A success history record is written
    alongside the questions.
    """
    stored = save_upload(file)
    if Path(stored).suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
        raise HTTPException(status_code=400, detail="仅支持 Excel 文件")

    parsed_rows = parse_excel_file(Path(stored))
    questions = [Question(**row) for row in parsed_rows]
    db.add_all(questions)
    db.add(
        ImportHistory(
            filename=file.filename or "",
            method="excel",
            question_count=len(questions),
            status="success",
        )
    )
    db.commit()

    # Refresh so DB-generated fields (ids, timestamps) appear in the response.
    for question in questions:
        db.refresh(question)
    return [QuestionOut.model_validate(question) for question in questions]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/template")
def download_template() -> dict:
    """Return the Excel import template, base64-encoded.

    Returns:
        dict with ``filename`` and ``content_base64`` (genuine base64 of the
        template workbook bytes).

    Bug fix: the key advertises base64, but the bytes were previously
    encoded with ``bytes.hex()`` — any client running a base64 decode on
    the value got garbage.  The value is now real base64.
    """
    import base64  # local import: only this endpoint needs it

    content = create_template_bytes()
    encoded = base64.b64encode(content).decode("ascii")
    return {"filename": "question_template.xlsx", "content_base64": encoded}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/history", response_model=list[ImportHistoryOut])
def import_history(db: Session = Depends(get_db)) -> list[ImportHistoryOut]:
    """Return the 100 most recent import-history records, newest first."""
    records = (
        db.query(ImportHistory)
        .order_by(ImportHistory.created_at.desc())
        .limit(100)
        .all()
    )
    return [ImportHistoryOut.model_validate(record) for record in records]
|
2026-03-06 15:52:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/batch/start")
def start_batch_import(
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(...),
    method: BatchMethod = Form("ai"),
) -> dict:
    """Save the uploads, register an in-memory task, and import in background.

    Returns the task id; the client polls /batch/{task_id} for progress.
    """
    if not files:
        raise HTTPException(status_code=400, detail="请至少上传一个文件")

    task_id = uuid4().hex

    # Persist every upload to disk first so the background worker can read
    # the files after the request (and its temp streams) is gone.
    saved_files: list[dict] = []
    for upload in files:
        stored = save_upload(upload)
        saved_files.append(
            {"path": str(stored), "filename": upload.filename or Path(stored).name}
        )

    initial_state = {
        "task_id": task_id,
        "status": "running",
        "method": method,
        "total": len(saved_files),
        "processed": 0,
        "current_file": "",
        "success_count": 0,
        "failed_count": 0,
        "total_questions": 0,
        "results": [],
        "error": "",
        "started_at": datetime.utcnow().isoformat(),
        "ended_at": "",
    }
    with BATCH_TASKS_LOCK:
        BATCH_TASKS[task_id] = initial_state

    background_tasks.add_task(_run_batch_import, task_id, saved_files, method)
    return {"task_id": task_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/batch/{task_id}")
def get_batch_import_progress(task_id: str) -> dict:
    """Return a task's current state plus a ``progress`` percentage (0-100)."""
    with BATCH_TASKS_LOCK:
        state = BATCH_TASKS.get(task_id)
        if not state:
            raise HTTPException(status_code=404, detail="任务不存在")
        # Guard against division by zero for a (theoretical) empty task.
        denominator = state["total"] or 1
        percent = round((state["processed"] / denominator) * 100, 2)
        return {**state, "progress": percent}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----- 持久化队列 Jobs API -----
|
|
|
|
|
|
|
|
|
|
@router.post("/jobs", response_model=ImportJobOut)
def create_import_job(
    files: list[UploadFile] = File(...),
    method: BatchMethod = Form("ai"),
    db: Session = Depends(get_db),
) -> ImportJobOut:
    """Create a persistent import job and attach all uploaded files to it.

    The job row is created first so its id can be used when storing the
    uploads; items are then registered in one batch.
    """
    if not files:
        raise HTTPException(status_code=400, detail="请至少上传一个文件")

    job = import_job_repo.create_job_empty(db, method)

    stored_items: list[dict] = []
    for seq, upload in enumerate(files, start=1):
        stored = save_upload_for_job(job.id, seq, upload)
        stored_items.append(
            {"filename": upload.filename or stored.name, "stored_path": str(stored)}
        )
    import_job_repo.add_job_items(db, job.id, stored_items)

    # Re-read so the response includes the freshly attached items.
    refreshed = import_job_repo.get_job(db, job.id)
    return ImportJobOut.model_validate(refreshed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/jobs/{job_id}", response_model=ImportJobOut)
def get_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Fetch a single persistent import job by id; 404 when it is missing."""
    record = import_job_repo.get_job(db, job_id)
    if not record:
        raise HTTPException(status_code=404, detail="任务不存在")
    return ImportJobOut.model_validate(record)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/jobs", response_model=list[ImportJobOut])
def list_import_jobs(
    status: str | None = Query(None, description="queued,running 等,逗号分隔"),
    limit: int = Query(50, le=100),
    db: Session = Depends(get_db),
) -> list[ImportJobOut]:
    """List recent import jobs, optionally filtered by comma-separated statuses."""
    if status:
        wanted = [part.strip() for part in status.split(",") if part.strip()]
    else:
        wanted = None
    jobs = import_job_repo.list_jobs(db, statuses=wanted, limit=limit)
    return [ImportJobOut.model_validate(job) for job in jobs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/jobs/{job_id}/retry", response_model=ImportJobOut)
def retry_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Clone the failed items of *job_id* into a fresh job and return it.

    404 when the source job is missing; 400 when it has no failed items.
    The stored files are reused — nothing is re-uploaded.
    """
    source = import_job_repo.get_job(db, job_id)
    if not source:
        raise HTTPException(status_code=404, detail="任务不存在")

    failed_items = import_job_repo.get_failed_items(db, job_id)
    if not failed_items:
        raise HTTPException(status_code=400, detail="没有可重试的失败项")

    retry_payload = [
        {"filename": entry.filename, "stored_path": entry.stored_path}
        for entry in failed_items
    ]
    created = import_job_repo.create_job(db, source.method, retry_payload)
    created = import_job_repo.get_job(db, created.id)
    return ImportJobOut.model_validate(created)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/jobs/{job_id}/cancel", response_model=ImportJobOut)
def cancel_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
    """Cancel a job; 400 when it does not exist or is not in a cancellable state."""
    cancelled = import_job_repo.cancel_job(db, job_id)
    if not cancelled:
        raise HTTPException(status_code=400, detail="任务不存在或无法取消")
    return ImportJobOut.model_validate(cancelled)
|