update: uploads
This commit is contained in:
@@ -7,7 +7,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
class Settings(BaseSettings):
|
||||
api_key: str = ""
|
||||
model_name: str = "gpt-4.1"
|
||||
dmxapi_url: str = "https://www.dmxapi.cn/v1/responses"
|
||||
openai_api_url: str = "https://api.openai.com/v1/responses"
|
||||
jwt_secret_key: str = "change-me-in-env"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 60 * 12
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import threading
|
||||
|
||||
from backend.database import Base, engine
|
||||
from backend.database import Base, engine, SessionLocal
|
||||
from backend.routers import auth, categories, exports, imports, practice, questions, stats
|
||||
from backend.services.import_queue_service import reset_stale_running_jobs, run_worker_loop
|
||||
|
||||
app = FastAPI(title="Problem Bank API", version="1.0.0")
|
||||
|
||||
@@ -16,6 +18,17 @@ app.add_middleware(
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
reset_stale_running_jobs(db)
|
||||
finally:
|
||||
db.close()
|
||||
thread = threading.Thread(target=run_worker_loop, daemon=True)
|
||||
thread.start()
|
||||
|
||||
app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(questions.router, prefix="/api/questions", tags=["questions"])
|
||||
app.include_router(imports.router, prefix="/api/import", tags=["imports"])
|
||||
|
||||
@@ -1,11 +1,29 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, Text
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from backend.database import Base
|
||||
|
||||
|
||||
# 导入任务/文件项统一状态
|
||||
JOB_STATUS_QUEUED = "queued"
|
||||
JOB_STATUS_RUNNING = "running"
|
||||
JOB_STATUS_SUCCESS = "success"
|
||||
JOB_STATUS_FAILED = "failed"
|
||||
JOB_STATUS_CANCELLED = "cancelled"
|
||||
JOB_STATUS_RETRYING = "retrying"
|
||||
|
||||
JOB_STATUSES = (
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
JOB_STATUS_SUCCESS,
|
||||
JOB_STATUS_FAILED,
|
||||
JOB_STATUS_CANCELLED,
|
||||
JOB_STATUS_RETRYING,
|
||||
)
|
||||
|
||||
|
||||
class Question(Base):
|
||||
__tablename__ = "questions"
|
||||
|
||||
@@ -52,3 +70,51 @@ class ImportHistory(Base):
|
||||
question_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
status: Mapped[str] = mapped_column(Text, default="success")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class ImportJob(Base):
|
||||
"""导入任务(持久化队列项)。"""
|
||||
|
||||
__tablename__ = "import_jobs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default=JOB_STATUS_QUEUED, index=True)
|
||||
method: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
total: Mapped[int] = mapped_column(Integer, default=0)
|
||||
processed: Mapped[int] = mapped_column(Integer, default=0)
|
||||
success_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
failed_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
current_index: Mapped[int] = mapped_column(Integer, default=0)
|
||||
current_file: Mapped[str] = mapped_column(Text, default="")
|
||||
error: Mapped[str] = mapped_column(Text, default="")
|
||||
attempt: Mapped[int] = mapped_column(Integer, default=1)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
ended_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
items: Mapped[list["ImportJobItem"]] = relationship(
|
||||
"ImportJobItem", back_populates="job", order_by="ImportJobItem.seq", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
class ImportJobItem(Base):
|
||||
"""导入任务内的单个文件项。"""
|
||||
|
||||
__tablename__ = "import_job_items"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
job_id: Mapped[int] = mapped_column(ForeignKey("import_jobs.id", ondelete="CASCADE"), nullable=False)
|
||||
seq: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
filename: Mapped[str] = mapped_column(Text, default="")
|
||||
stored_path: Mapped[str] = mapped_column(Text, default="")
|
||||
status: Mapped[str] = mapped_column(String(32), default=JOB_STATUS_QUEUED)
|
||||
attempt: Mapped[int] = mapped_column(Integer, default=1)
|
||||
error: Mapped[str] = mapped_column(Text, default="")
|
||||
question_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
ended_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
job: Mapped["ImportJob"] = relationship("ImportJob", back_populates="items")
|
||||
|
||||
1
backend/repositories/__init__.py
Normal file
1
backend/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Repositories for DB access
|
||||
181
backend/repositories/import_job_repo.py
Normal file
181
backend/repositories/import_job_repo.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Import job persistence: create, get, list, update, claim for FIFO worker."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.models import (
|
||||
JOB_STATUS_FAILED,
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
ImportJob,
|
||||
ImportJobItem,
|
||||
)
|
||||
|
||||
|
||||
def create_job(
|
||||
db: Session,
|
||||
method: str,
|
||||
items: list[dict],
|
||||
) -> ImportJob:
|
||||
"""Create a new import job with items. items: [{"filename": str, "stored_path": str}, ...]."""
|
||||
job = ImportJob(
|
||||
status=JOB_STATUS_QUEUED,
|
||||
method=method,
|
||||
total=len(items),
|
||||
processed=0,
|
||||
success_count=0,
|
||||
failed_count=0,
|
||||
)
|
||||
db.add(job)
|
||||
db.flush()
|
||||
for seq, it in enumerate(items, start=1):
|
||||
db.add(
|
||||
ImportJobItem(
|
||||
job_id=job.id,
|
||||
seq=seq,
|
||||
filename=it.get("filename", ""),
|
||||
stored_path=it.get("stored_path", ""),
|
||||
status=JOB_STATUS_QUEUED,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def create_job_empty(db: Session, method: str) -> ImportJob:
|
||||
"""Create job with no items; caller then adds items and sets total."""
|
||||
job = ImportJob(
|
||||
status=JOB_STATUS_QUEUED,
|
||||
method=method,
|
||||
total=0,
|
||||
processed=0,
|
||||
success_count=0,
|
||||
failed_count=0,
|
||||
)
|
||||
db.add(job)
|
||||
db.flush()
|
||||
return job
|
||||
|
||||
|
||||
def add_job_items(db: Session, job_id: int, items: list[dict]) -> None:
|
||||
"""Append items to job and set job.total. items: [{"filename": str, "stored_path": str}, ...]."""
|
||||
job = db.query(ImportJob).filter(ImportJob.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
for seq, it in enumerate(items, start=1):
|
||||
db.add(
|
||||
ImportJobItem(
|
||||
job_id=job_id,
|
||||
seq=seq,
|
||||
filename=it.get("filename", ""),
|
||||
stored_path=it.get("stored_path", ""),
|
||||
status=JOB_STATUS_QUEUED,
|
||||
)
|
||||
)
|
||||
job.total = len(items)
|
||||
db.commit()
|
||||
|
||||
|
||||
def get_job(db: Session, job_id: int) -> ImportJob | None:
|
||||
"""Get job by id with items loaded."""
|
||||
return db.query(ImportJob).filter(ImportJob.id == job_id).first()
|
||||
|
||||
|
||||
def list_jobs(
|
||||
db: Session,
|
||||
statuses: list[str] | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[ImportJob]:
|
||||
"""List jobs, optionally filtered by status(es), newest first."""
|
||||
q = db.query(ImportJob).order_by(ImportJob.created_at.desc())
|
||||
if statuses:
|
||||
q = q.filter(ImportJob.status.in_(statuses))
|
||||
return q.limit(limit).all()
|
||||
|
||||
|
||||
def update_job(db: Session, job_id: int, **kwargs) -> ImportJob | None:
|
||||
"""Update job fields. Returns updated job or None if not found."""
|
||||
job = db.query(ImportJob).filter(ImportJob.id == job_id).first()
|
||||
if not job:
|
||||
return None
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(job, k):
|
||||
setattr(job, k, v)
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def update_job_item(db: Session, item_id: int, **kwargs) -> ImportJobItem | None:
|
||||
"""Update a single job item."""
|
||||
item = db.query(ImportJobItem).filter(ImportJobItem.id == item_id).first()
|
||||
if not item:
|
||||
return None
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(item, k):
|
||||
setattr(item, k, v)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
def claim_oldest_queued(db: Session) -> ImportJob | None:
|
||||
"""
|
||||
Claim the oldest queued job by setting status to running.
|
||||
Returns the job if claimed, None if no queued job.
|
||||
Used by single-worker FIFO loop. (SQLite-compatible: no row locking.)
|
||||
"""
|
||||
job = (
|
||||
db.query(ImportJob)
|
||||
.filter(ImportJob.status == JOB_STATUS_QUEUED)
|
||||
.order_by(ImportJob.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
return None
|
||||
job.status = JOB_STATUS_RUNNING
|
||||
job.started_at = datetime.utcnow()
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def get_queued_job_for_worker(db: Session) -> ImportJob | None:
|
||||
"""
|
||||
Get oldest queued job without locking (SQLite has limited FOR UPDATE support).
|
||||
Caller must then update status to running in same or follow-up transaction.
|
||||
"""
|
||||
job = (
|
||||
db.query(ImportJob)
|
||||
.filter(ImportJob.status == JOB_STATUS_QUEUED)
|
||||
.order_by(ImportJob.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
def cancel_job(db: Session, job_id: int) -> ImportJob | None:
|
||||
"""Set job status to cancelled if it is queued or running."""
|
||||
job = db.query(ImportJob).filter(ImportJob.id == job_id).first()
|
||||
if not job or job.status not in (JOB_STATUS_QUEUED, JOB_STATUS_RUNNING):
|
||||
return None
|
||||
job.status = "cancelled"
|
||||
job.ended_at = datetime.utcnow()
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def get_failed_items(db: Session, job_id: int) -> list[ImportJobItem]:
|
||||
"""Return items with status failed for retry."""
|
||||
return (
|
||||
db.query(ImportJobItem)
|
||||
.filter(ImportJobItem.job_id == job_id, ImportJobItem.status == JOB_STATUS_FAILED)
|
||||
.order_by(ImportJobItem.seq.asc())
|
||||
.all()
|
||||
)
|
||||
@@ -10,3 +10,4 @@ python-multipart
|
||||
requests
|
||||
openpyxl
|
||||
python-dotenv
|
||||
pytest
|
||||
|
||||
@@ -1,25 +1,153 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from threading import Lock
|
||||
from typing import Literal
|
||||
from uuid import uuid4
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.auth import get_current_user
|
||||
from backend.database import get_db
|
||||
from backend.database import SessionLocal, get_db
|
||||
from backend.models import ImportHistory, Question
|
||||
from backend.schemas import ImportHistoryOut, QuestionOut
|
||||
from backend.repositories import import_job_repo
|
||||
from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut
|
||||
from backend.services.excel_service import create_template_bytes, parse_excel_file
|
||||
from backend.services.file_utils import save_upload
|
||||
from backend.services.parser import DMXAPIService, extract_metadata
|
||||
from backend.services.file_utils import save_upload, save_upload_for_job
|
||||
from backend.services.parser import OpenAICompatibleParserService, extract_metadata
|
||||
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
BATCH_TASKS: dict[str, dict] = {}
|
||||
BATCH_TASKS_LOCK = Lock()
|
||||
BatchMethod = Literal["excel", "ai"]
|
||||
|
||||
|
||||
|
||||
def _set_task(task_id: str, **fields) -> None:
|
||||
with BATCH_TASKS_LOCK:
|
||||
task = BATCH_TASKS.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
task.update(fields)
|
||||
|
||||
|
||||
def _build_ai_rows(path: Path) -> list[dict]:
|
||||
parser = OpenAICompatibleParserService()
|
||||
metadata = extract_metadata(path.name)
|
||||
questions = parser.parse_file(str(path))
|
||||
rows = []
|
||||
for q in questions:
|
||||
rows.append(
|
||||
{
|
||||
"chapter": metadata["chapter"],
|
||||
"primary_knowledge": "",
|
||||
"secondary_knowledge": metadata["secondary_knowledge"],
|
||||
"question_type": metadata["question_type"],
|
||||
"difficulty": metadata["difficulty"],
|
||||
"stem": q.get("题干", ""),
|
||||
"option_a": q.get("选项A", ""),
|
||||
"option_b": q.get("选项B", ""),
|
||||
"option_c": q.get("选项C", ""),
|
||||
"option_d": q.get("选项D", ""),
|
||||
"answer": q.get("正确答案", ""),
|
||||
"explanation": q.get("解析", ""),
|
||||
"notes": q.get("备注", ""),
|
||||
"source_file": metadata["source_file"],
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for index, item in enumerate(files, start=1):
|
||||
path = Path(item["path"])
|
||||
filename = item["filename"]
|
||||
_set_task(task_id, current_file=filename)
|
||||
try:
|
||||
if method == "excel":
|
||||
if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
|
||||
raise ValueError("仅支持 Excel 文件")
|
||||
rows = parse_excel_file(path)
|
||||
else:
|
||||
rows = _build_ai_rows(path)
|
||||
|
||||
questions = [Question(**row) for row in rows]
|
||||
if questions:
|
||||
db.add_all(questions)
|
||||
db.add(
|
||||
ImportHistory(
|
||||
filename=filename,
|
||||
method=method,
|
||||
question_count=len(questions),
|
||||
status="success",
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
with BATCH_TASKS_LOCK:
|
||||
task = BATCH_TASKS[task_id]
|
||||
task["success_count"] += 1
|
||||
task["total_questions"] += len(questions)
|
||||
task["results"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "success",
|
||||
"question_count": len(questions),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
db.add(
|
||||
ImportHistory(
|
||||
filename=filename,
|
||||
method=method,
|
||||
question_count=0,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
with BATCH_TASKS_LOCK:
|
||||
task = BATCH_TASKS[task_id]
|
||||
task["failed_count"] += 1
|
||||
task["results"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
"question_count": 0,
|
||||
}
|
||||
)
|
||||
finally:
|
||||
_set_task(task_id, processed=index)
|
||||
|
||||
_set_task(
|
||||
task_id,
|
||||
status="completed",
|
||||
current_file="",
|
||||
ended_at=datetime.utcnow().isoformat(),
|
||||
)
|
||||
except Exception as exc:
|
||||
_set_task(
|
||||
task_id,
|
||||
status="failed",
|
||||
error=str(exc),
|
||||
ended_at=datetime.utcnow().isoformat(),
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/ai/parse")
|
||||
def parse_by_ai(file: UploadFile = File(...)) -> dict:
|
||||
path = save_upload(file)
|
||||
parser = DMXAPIService()
|
||||
parser = OpenAICompatibleParserService()
|
||||
metadata = extract_metadata(file.filename or path.name)
|
||||
questions = parser.parse_file(str(path))
|
||||
try:
|
||||
questions = parser.parse_file(str(path))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=502, detail=str(exc)) from exc
|
||||
preview = []
|
||||
for q in questions:
|
||||
preview.append(
|
||||
@@ -95,3 +223,110 @@ def download_template() -> dict:
|
||||
def import_history(db: Session = Depends(get_db)) -> list[ImportHistoryOut]:
|
||||
rows = db.query(ImportHistory).order_by(ImportHistory.created_at.desc()).limit(100).all()
|
||||
return [ImportHistoryOut.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("/batch/start")
|
||||
def start_batch_import(
|
||||
background_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
method: BatchMethod = Form("ai"),
|
||||
) -> dict:
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="请至少上传一个文件")
|
||||
|
||||
task_id = uuid4().hex
|
||||
saved_files: list[dict] = []
|
||||
for f in files:
|
||||
saved_path = save_upload(f)
|
||||
saved_files.append({"path": str(saved_path), "filename": f.filename or Path(saved_path).name})
|
||||
|
||||
with BATCH_TASKS_LOCK:
|
||||
BATCH_TASKS[task_id] = {
|
||||
"task_id": task_id,
|
||||
"status": "running",
|
||||
"method": method,
|
||||
"total": len(saved_files),
|
||||
"processed": 0,
|
||||
"current_file": "",
|
||||
"success_count": 0,
|
||||
"failed_count": 0,
|
||||
"total_questions": 0,
|
||||
"results": [],
|
||||
"error": "",
|
||||
"started_at": datetime.utcnow().isoformat(),
|
||||
"ended_at": "",
|
||||
}
|
||||
background_tasks.add_task(_run_batch_import, task_id, saved_files, method)
|
||||
return {"task_id": task_id}
|
||||
|
||||
|
||||
@router.get("/batch/{task_id}")
|
||||
def get_batch_import_progress(task_id: str) -> dict:
|
||||
with BATCH_TASKS_LOCK:
|
||||
task = BATCH_TASKS.get(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
total = task["total"] or 1
|
||||
progress = round((task["processed"] / total) * 100, 2)
|
||||
return {**task, "progress": progress}
|
||||
|
||||
|
||||
# ----- 持久化队列 Jobs API -----
|
||||
|
||||
@router.post("/jobs", response_model=ImportJobOut)
|
||||
def create_import_job(
|
||||
files: list[UploadFile] = File(...),
|
||||
method: BatchMethod = Form("ai"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ImportJobOut:
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="请至少上传一个文件")
|
||||
job = import_job_repo.create_job_empty(db, method)
|
||||
items: list[dict] = []
|
||||
for seq, f in enumerate(files, start=1):
|
||||
path = save_upload_for_job(job.id, seq, f)
|
||||
items.append({"filename": f.filename or path.name, "stored_path": str(path)})
|
||||
import_job_repo.add_job_items(db, job.id, items)
|
||||
job = import_job_repo.get_job(db, job.id)
|
||||
return ImportJobOut.model_validate(job)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=ImportJobOut)
|
||||
def get_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
|
||||
job = import_job_repo.get_job(db, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
return ImportJobOut.model_validate(job)
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=list[ImportJobOut])
|
||||
def list_import_jobs(
|
||||
status: str | None = Query(None, description="queued,running 等,逗号分隔"),
|
||||
limit: int = Query(50, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
) -> list[ImportJobOut]:
|
||||
statuses = [s.strip() for s in status.split(",") if s.strip()] if status else None
|
||||
jobs = import_job_repo.list_jobs(db, statuses=statuses, limit=limit)
|
||||
return [ImportJobOut.model_validate(j) for j in jobs]
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/retry", response_model=ImportJobOut)
|
||||
def retry_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
|
||||
job = import_job_repo.get_job(db, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
failed = import_job_repo.get_failed_items(db, job_id)
|
||||
if not failed:
|
||||
raise HTTPException(status_code=400, detail="没有可重试的失败项")
|
||||
items = [{"filename": it.filename, "stored_path": it.stored_path} for it in failed]
|
||||
new_job = import_job_repo.create_job(db, job.method, items)
|
||||
new_job = import_job_repo.get_job(db, new_job.id)
|
||||
return ImportJobOut.model_validate(new_job)
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel", response_model=ImportJobOut)
|
||||
def cancel_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut:
|
||||
job = import_job_repo.cancel_job(db, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=400, detail="任务不存在或无法取消")
|
||||
return ImportJobOut.model_validate(job)
|
||||
|
||||
@@ -103,6 +103,45 @@ class ImportHistoryOut(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ImportJobItemOut(BaseModel):
|
||||
id: int
|
||||
job_id: int
|
||||
seq: int
|
||||
filename: str
|
||||
stored_path: str
|
||||
status: str
|
||||
attempt: int
|
||||
error: str
|
||||
question_count: int
|
||||
started_at: datetime | None
|
||||
ended_at: datetime | None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ImportJobOut(BaseModel):
|
||||
id: int
|
||||
status: str
|
||||
method: str
|
||||
total: int
|
||||
processed: int
|
||||
success_count: int
|
||||
failed_count: int
|
||||
current_index: int
|
||||
current_file: str
|
||||
error: str
|
||||
attempt: int
|
||||
created_at: datetime
|
||||
started_at: datetime | None
|
||||
ended_at: datetime | None
|
||||
updated_at: datetime
|
||||
items: list[ImportJobItemOut] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PracticeStartRequest(BaseModel):
|
||||
chapter: Optional[str] = None
|
||||
secondary_knowledge: Optional[str] = None
|
||||
|
||||
@@ -18,3 +18,14 @@ def save_upload(upload_file: UploadFile) -> Path:
|
||||
with target_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(upload_file.file, buffer)
|
||||
return target_path
|
||||
|
||||
|
||||
def save_upload_for_job(job_id: int, seq: int, upload_file: UploadFile) -> Path:
|
||||
"""Save file with unique path under upload_dir/job_id/ to avoid overwrites."""
|
||||
base = ensure_upload_dir() / str(job_id)
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
name = upload_file.filename or "file"
|
||||
target_path = base / f"{seq}_{name}"
|
||||
with target_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(upload_file.file, buffer)
|
||||
return target_path
|
||||
|
||||
180
backend/services/import_queue_service.py
Normal file
180
backend/services/import_queue_service.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Import queue: single-consumer FIFO worker and job execution.
|
||||
Run run_worker_loop() in a background thread; on startup call reset_stale_running_jobs().
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.database import SessionLocal
|
||||
from backend.models import (
|
||||
JOB_STATUS_FAILED,
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
JOB_STATUS_SUCCESS,
|
||||
ImportHistory,
|
||||
ImportJob,
|
||||
ImportJobItem,
|
||||
Question,
|
||||
)
|
||||
from backend.repositories import import_job_repo as repo
|
||||
from backend.services.excel_service import parse_excel_file
|
||||
from backend.services.parser import OpenAICompatibleParserService, extract_metadata
|
||||
|
||||
|
||||
def _build_ai_rows(path: Path) -> list[dict]:
|
||||
parser = OpenAICompatibleParserService()
|
||||
metadata = extract_metadata(path.name)
|
||||
questions = parser.parse_file(str(path))
|
||||
rows = []
|
||||
for q in questions:
|
||||
rows.append(
|
||||
{
|
||||
"chapter": metadata["chapter"],
|
||||
"primary_knowledge": "",
|
||||
"secondary_knowledge": metadata["secondary_knowledge"],
|
||||
"question_type": metadata["question_type"],
|
||||
"difficulty": metadata["difficulty"],
|
||||
"stem": q.get("题干", ""),
|
||||
"option_a": q.get("选项A", ""),
|
||||
"option_b": q.get("选项B", ""),
|
||||
"option_c": q.get("选项C", ""),
|
||||
"option_d": q.get("选项D", ""),
|
||||
"answer": q.get("正确答案", ""),
|
||||
"explanation": q.get("解析", ""),
|
||||
"notes": q.get("备注", ""),
|
||||
"source_file": metadata["source_file"],
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def _process_one_item(
|
||||
db: Session,
|
||||
job: ImportJob,
|
||||
item: ImportJobItem,
|
||||
method: str,
|
||||
) -> None:
|
||||
path = Path(item.stored_path)
|
||||
filename = item.filename
|
||||
job.current_file = filename
|
||||
job.current_index = item.seq
|
||||
job.updated_at = datetime.utcnow()
|
||||
item.status = JOB_STATUS_RUNNING
|
||||
item.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
if method == "excel":
|
||||
if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
|
||||
raise ValueError("仅支持 Excel 文件")
|
||||
rows = parse_excel_file(path)
|
||||
else:
|
||||
rows = _build_ai_rows(path)
|
||||
|
||||
questions = [Question(**row) for row in rows]
|
||||
if questions:
|
||||
db.add_all(questions)
|
||||
db.add(
|
||||
ImportHistory(
|
||||
filename=filename,
|
||||
method=method,
|
||||
question_count=len(questions),
|
||||
status="success",
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
item.status = JOB_STATUS_SUCCESS
|
||||
item.question_count = len(questions)
|
||||
item.ended_at = datetime.utcnow()
|
||||
job.success_count += 1
|
||||
job.processed += 1
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
db.add(
|
||||
ImportHistory(
|
||||
filename=filename,
|
||||
method=method,
|
||||
question_count=0,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
item.status = JOB_STATUS_FAILED
|
||||
item.error = str(exc)
|
||||
item.ended_at = datetime.utcnow()
|
||||
job.failed_count += 1
|
||||
job.processed += 1
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
|
||||
def process_job(db: Session, job_id: int) -> None:
|
||||
"""Execute a single job: process all items in order, then set job terminal status."""
|
||||
job = repo.get_job(db, job_id)
|
||||
if not job or job.status != JOB_STATUS_RUNNING:
|
||||
return
|
||||
method = job.method
|
||||
items = sorted(job.items, key=lambda x: x.seq)
|
||||
# Resume: ensure processed/success_count/failed_count reflect already-completed items
|
||||
job.processed = sum(1 for it in items if it.status in (JOB_STATUS_SUCCESS, JOB_STATUS_FAILED))
|
||||
job.success_count = sum(1 for it in items if it.status == JOB_STATUS_SUCCESS)
|
||||
job.failed_count = sum(1 for it in items if it.status == JOB_STATUS_FAILED)
|
||||
db.commit()
|
||||
for item in items:
|
||||
if item.status in (JOB_STATUS_SUCCESS, JOB_STATUS_FAILED):
|
||||
continue
|
||||
_process_one_item(db, job, item, method)
|
||||
db.refresh(job)
|
||||
|
||||
job = repo.get_job(db, job_id)
|
||||
if not job:
|
||||
return
|
||||
if job.failed_count > 0 and job.success_count == 0:
|
||||
job.status = JOB_STATUS_FAILED
|
||||
job.error = "部分或全部文件处理失败"
|
||||
else:
|
||||
job.status = JOB_STATUS_SUCCESS
|
||||
job.error = ""
|
||||
job.ended_at = datetime.utcnow()
|
||||
job.current_file = ""
|
||||
job.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
|
||||
def reset_stale_running_jobs(db: Session) -> int:
|
||||
"""On startup: set any job left in 'running' back to 'queued' so worker can pick it up."""
|
||||
count = 0
|
||||
for job in db.query(ImportJob).filter(ImportJob.status == JOB_STATUS_RUNNING).all():
|
||||
job.status = JOB_STATUS_QUEUED
|
||||
count += 1
|
||||
if count:
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
|
||||
def run_worker_loop(interval_seconds: float = 1.0) -> None:
|
||||
"""
|
||||
Single-consumer FIFO loop. Call from a background thread.
|
||||
Claims oldest queued job, processes it, then repeats. Sleeps when no job.
|
||||
"""
|
||||
while True:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = repo.claim_oldest_queued(db)
|
||||
if job:
|
||||
process_job(db, job.id)
|
||||
else:
|
||||
time.sleep(interval_seconds)
|
||||
except Exception:
|
||||
if db:
|
||||
db.rollback()
|
||||
time.sleep(interval_seconds)
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
@@ -8,11 +10,13 @@ import requests
|
||||
from backend.config import settings
|
||||
|
||||
|
||||
class DMXAPIService:
|
||||
|
||||
|
||||
class OpenAICompatibleParserService:
|
||||
def __init__(self) -> None:
|
||||
self.api_key = settings.api_key
|
||||
self.model_name = settings.model_name
|
||||
self.api_url = settings.dmxapi_url
|
||||
self.api_url = settings.openai_api_url
|
||||
|
||||
def parse_file(self, file_path: str) -> list[dict]:
|
||||
path = Path(file_path)
|
||||
@@ -51,33 +55,94 @@ class DMXAPIService:
|
||||
def _convert_to_pdf(self, path: Path) -> Path:
|
||||
pdf_path = path.with_suffix(".pdf")
|
||||
pdf_path.unlink(missing_ok=True)
|
||||
source_path = path
|
||||
temp_dir: str | None = None
|
||||
|
||||
cmd = [
|
||||
"pandoc",
|
||||
str(path),
|
||||
"-o",
|
||||
str(pdf_path),
|
||||
"--pdf-engine=xelatex",
|
||||
"-V",
|
||||
"CJKmainfont=PingFang SC",
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=90)
|
||||
if result.returncode != 0:
|
||||
fallback = [
|
||||
if path.suffix.lower() == ".doc":
|
||||
source_path, temp_dir = self._convert_doc_to_docx(path)
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"pandoc",
|
||||
str(path),
|
||||
str(source_path),
|
||||
"-o",
|
||||
str(pdf_path),
|
||||
"--pdf-engine=weasyprint",
|
||||
"--pdf-engine=xelatex",
|
||||
"-V",
|
||||
"CJKmainfont=PingFang SC",
|
||||
]
|
||||
result = subprocess.run(fallback, capture_output=True, text=True, timeout=90)
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=90)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"文件转 PDF 失败: {result.stderr}")
|
||||
return pdf_path
|
||||
fallback = [
|
||||
"pandoc",
|
||||
str(source_path),
|
||||
"-o",
|
||||
str(pdf_path),
|
||||
"--pdf-engine=weasyprint",
|
||||
]
|
||||
result = subprocess.run(fallback, capture_output=True, text=True, timeout=90)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"文件转 PDF 失败: {result.stderr}")
|
||||
return pdf_path
|
||||
finally:
|
||||
if temp_dir:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def _convert_doc_to_docx(self, path: Path) -> tuple[Path, str]:
|
||||
temp_dir = tempfile.mkdtemp(prefix="pb_doc_convert_")
|
||||
converted_path = Path(temp_dir) / f"{path.stem}.docx"
|
||||
convert_errors: list[str] = []
|
||||
|
||||
if shutil.which("soffice"):
|
||||
result = subprocess.run(
|
||||
[
|
||||
"soffice",
|
||||
"--headless",
|
||||
"--convert-to",
|
||||
"docx",
|
||||
"--outdir",
|
||||
temp_dir,
|
||||
str(path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0 and converted_path.exists():
|
||||
return converted_path, temp_dir
|
||||
convert_errors.append(f"soffice: {(result.stderr or result.stdout).strip()}")
|
||||
else:
|
||||
convert_errors.append("soffice: 未安装")
|
||||
|
||||
if shutil.which("textutil"):
|
||||
result = subprocess.run(
|
||||
[
|
||||
"textutil",
|
||||
"-convert",
|
||||
"docx",
|
||||
"-output",
|
||||
str(converted_path),
|
||||
str(path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0 and converted_path.exists():
|
||||
return converted_path, temp_dir
|
||||
convert_errors.append(f"textutil: {(result.stderr or result.stdout).strip()}")
|
||||
else:
|
||||
convert_errors.append("textutil: 未安装")
|
||||
|
||||
raise ValueError(
|
||||
"检测到 .doc 文件,pandoc 不支持直接转换。"
|
||||
"已尝试自动转换为 .docx 但失败。请先把 .doc 另存为 .docx 后重试。"
|
||||
f" 详细信息: {' | '.join(convert_errors)}"
|
||||
)
|
||||
|
||||
def _parse_with_file_url(self, file_url: str, original_filename: str) -> list[dict]:
|
||||
if not self.api_key:
|
||||
raise ValueError("未配置 API_KEY,无法调用 DMXAPI")
|
||||
raise ValueError("未配置 API_KEY,无法调用 OpenAI 兼容接口")
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
@@ -99,7 +164,7 @@ class DMXAPIService:
|
||||
self.api_url, headers=headers, data=json.dumps(payload), timeout=180
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"DMXAPI 请求失败: {response.status_code} {response.text}")
|
||||
raise ValueError(f"OpenAI 兼容接口请求失败: {response.status_code} {response.text}")
|
||||
return self._extract_questions(response.json())
|
||||
|
||||
def _build_instruction(self, filename: str) -> str:
|
||||
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Backend tests
|
||||
17
backend/tests/conftest.py
Normal file
17
backend/tests/conftest.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from backend.database import Base


@pytest.fixture
def db_session():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    Every test gets its own engine and a newly created schema, so tests are
    fully isolated from one another.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    session = Session()
    try:
        yield session
    finally:
        session.close()
        # Improvement: also dispose the engine so the in-memory database and
        # its pooled connection are released as soon as the test finishes.
        engine.dispose()
|
||||
82
backend/tests/test_import_job_repo.py
Normal file
82
backend/tests/test_import_job_repo.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for import job repository: create, get, list, claim FIFO, cancel, failed items."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.models import JOB_STATUS_QUEUED, JOB_STATUS_RUNNING
|
||||
from backend.repositories import import_job_repo as repo
|
||||
|
||||
|
||||
def test_create_job(db_session):
    """A new job starts queued with seq-numbered items and zero progress."""
    payload = [
        {"filename": "a.xlsx", "stored_path": "/up/1/a.xlsx"},
        {"filename": "b.xlsx", "stored_path": "/up/1/b.xlsx"},
    ]
    job = repo.create_job(db_session, "excel", payload)

    assert job.id is not None
    assert job.status == JOB_STATUS_QUEUED
    assert job.method == "excel"
    assert (job.total, job.processed) == (2, 0)
    assert len(job.items) == 2
    # Items keep upload order and are numbered from 1.
    for expected_seq, (item, spec) in enumerate(zip(job.items, payload), start=1):
        assert item.filename == spec["filename"]
        assert item.seq == expected_seq
|
||||
|
||||
|
||||
def test_get_job(db_session):
    """get_job returns the stored job with its items, or None for unknown ids."""
    created = repo.create_job(
        db_session, "ai", [{"filename": "f.pdf", "stored_path": "/f.pdf"}]
    )

    loaded = repo.get_job(db_session, created.id)
    assert loaded is not None
    assert loaded.id == created.id
    assert len(loaded.items) == 1

    # A nonexistent id yields None rather than raising.
    assert repo.get_job(db_session, 99999) is None
|
||||
|
||||
|
||||
def test_list_jobs(db_session):
    """list_jobs returns recent jobs and honours the status filter."""
    for name in ("a", "b"):
        repo.create_job(
            db_session, "excel", [{"filename": name, "stored_path": f"/{name}"}]
        )

    assert len(repo.list_jobs(db_session, limit=10)) >= 2

    queued_only = repo.list_jobs(db_session, statuses=[JOB_STATUS_QUEUED], limit=10)
    assert all(job.status == JOB_STATUS_QUEUED for job in queued_only)
|
||||
|
||||
|
||||
def test_claim_oldest_queued_fifo(db_session):
    """Jobs are claimed oldest-first; a claimed job transitions to running."""
    older = repo.create_job(db_session, "excel", [{"filename": "a", "stored_path": "/a"}])
    newer = repo.create_job(db_session, "excel", [{"filename": "b", "stored_path": "/b"}])

    first_claim = repo.claim_oldest_queued(db_session)
    assert first_claim is not None
    assert first_claim.id == older.id
    assert first_claim.status == JOB_STATUS_RUNNING

    second_claim = repo.claim_oldest_queued(db_session)
    assert second_claim is not None
    assert second_claim.id == newer.id

    # Nothing left in the queue.
    assert repo.claim_oldest_queued(db_session) is None
|
||||
|
||||
|
||||
def test_cancel_job(db_session):
    """cancel_job marks a job cancelled once; repeated cancels return None."""
    job = repo.create_job(db_session, "excel", [{"filename": "a", "stored_path": "/a"}])

    cancelled = repo.cancel_job(db_session, job.id)
    assert cancelled is not None
    assert cancelled.status == "cancelled"

    # The new status is persisted, not just set on the returned object.
    reloaded = repo.get_job(db_session, job.id)
    assert reloaded.status == "cancelled"

    # A job that is already cancelled cannot be cancelled again.
    assert repo.cancel_job(db_session, job.id) is None
|
||||
|
||||
|
||||
def test_get_failed_items(db_session):
    """get_failed_items returns only the items whose status is failed."""
    job = repo.create_job(
        db_session,
        "excel",
        [
            {"filename": "a", "stored_path": "/a"},
            {"filename": "b", "stored_path": "/b"},
        ],
    )

    # No item has failed yet.
    assert not repo.get_failed_items(db_session, job.id)

    first_item = job.items[0]
    first_item.status = "failed"
    first_item.error = "err"
    db_session.commit()

    failed = repo.get_failed_items(db_session, job.id)
    assert [item.filename for item in failed] == ["a"]
|
||||
Reference in New Issue
Block a user