diff --git a/.env.example b/.env.example
index 2cb64da..c20e439 100644
--- a/.env.example
+++ b/.env.example
@@ -1,6 +1,6 @@
 API_KEY=
 MODEL_NAME=gpt-4.1
-DMXAPI_URL=https://www.dmxapi.cn/v1/responses
+OPENAI_API_URL=https://api.openai.com/v1/responses
 JWT_SECRET_KEY=please_change_me
 JWT_ALGORITHM=HS256
 ACCESS_TOKEN_EXPIRE_MINUTES=720
diff --git a/README.md b/README.md
index 0fcbb60..3db0021 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,14 @@
 3. Open:
    - `http://127.0.0.1:5173`
 
+## Import Center (queue)
+
+- Import jobs use a **persistent queue**: jobs and their file items are written to the database (`import_jobs` / `import_job_items`), and a single backend consumer executes them serially, FIFO by creation time.
+- Status values: `queued`, `running`, `success`, `failed`, `cancelled`, `retrying`.
+- After a page refresh, the frontend calls `GET /api/import/jobs?status=queued,running` to restore unfinished jobs and resume progress polling; no local cache is needed.
+- Failed items can be re-enqueued via "Retry failed items" (which creates a new job); queued and running jobs can be cancelled.
+- Uploaded files are stored under a unique per-job path (`upload_dir/{job_id}/{seq}_{filename}`) to avoid same-name overwrites.
+
 ## Default login
 
 - Username: `admin`
@@ -38,8 +46,9 @@
 ## Features
 
 - Question CRUD, search, filtering, bulk delete, bulk update
-- AI-assisted import (PDF/Word -> DMXAPI -> preview -> confirm & save)
+- AI-assisted import (PDF/Word -> OpenAI-compatible API -> preview -> confirm & save)
 - Excel bulk import, template download, export to JSON/CSV/Excel
+- **Import Center (persistent queue)**: strict FIFO serial execution with job state persisted to the database; unfinished jobs are restored automatically after a page refresh; queued/running jobs can be cancelled, and failed items can be re-enqueued in one click.
 - Category tree management (chapters/knowledge points)
 - Practice mode (question drawing, grading, explanation feedback)
 - Dashboard statistics (totals, question types, difficulty, chapters, import history)
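For a quick smoke test of the refresh-recovery flow described above, a minimal client-side sketch — assuming the API is served on `127.0.0.1:8000` and a JWT is available in a `PB_TOKEN` environment variable (both illustrative):

```python
import os
import time

import requests

BASE = "http://127.0.0.1:8000/api"  # assumed dev address; adjust to your deployment
HEADERS = {"Authorization": f"Bearer {os.environ['PB_TOKEN']}"}

# Recover unfinished jobs (what the frontend does after a refresh) ...
resp = requests.get(f"{BASE}/import/jobs", params={"status": "queued,running"}, headers=HEADERS)
resp.raise_for_status()
pending = {job["id"] for job in resp.json()}

# ... then poll each one until it reaches a terminal status.
while pending:
    for job_id in sorted(pending):
        job = requests.get(f"{BASE}/import/jobs/{job_id}", headers=HEADERS).json()
        if job["status"] in ("success", "failed", "cancelled"):
            pending.discard(job_id)
            print(job_id, job["status"], f"{job['processed']}/{job['total']}")
    time.sleep(1.5)
```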
diff --git a/backend/config.py b/backend/config.py
index b203816..c6ab747 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -7,7 +7,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 class Settings(BaseSettings):
     api_key: str = ""
     model_name: str = "gpt-4.1"
-    dmxapi_url: str = "https://www.dmxapi.cn/v1/responses"
+    openai_api_url: str = "https://api.openai.com/v1/responses"
     jwt_secret_key: str = "change-me-in-env"
     jwt_algorithm: str = "HS256"
     access_token_expire_minutes: int = 60 * 12
diff --git a/backend/main.py b/backend/main.py
index 1e7c447..1810ae3 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,8 +1,10 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+import threading
 
-from backend.database import Base, engine
+from backend.database import Base, engine, SessionLocal
 from backend.routers import auth, categories, exports, imports, practice, questions, stats
+from backend.services.import_queue_service import reset_stale_running_jobs, run_worker_loop
 
 app = FastAPI(title="Problem Bank API", version="1.0.0")
 
@@ -16,6 +18,17 @@ app.add_middleware(
 
 Base.metadata.create_all(bind=engine)
 
+
+@app.on_event("startup")
+def startup() -> None:
+    db = SessionLocal()
+    try:
+        reset_stale_running_jobs(db)
+    finally:
+        db.close()
+    thread = threading.Thread(target=run_worker_loop, daemon=True)
+    thread.start()
+
 app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
 app.include_router(questions.router, prefix="/api/questions", tags=["questions"])
 app.include_router(imports.router, prefix="/api/import", tags=["imports"])
items: [{"filename": str, "stored_path": str}, ...].""" + job = ImportJob( + status=JOB_STATUS_QUEUED, + method=method, + total=len(items), + processed=0, + success_count=0, + failed_count=0, + ) + db.add(job) + db.flush() + for seq, it in enumerate(items, start=1): + db.add( + ImportJobItem( + job_id=job.id, + seq=seq, + filename=it.get("filename", ""), + stored_path=it.get("stored_path", ""), + status=JOB_STATUS_QUEUED, + ) + ) + db.commit() + db.refresh(job) + return job + + +def create_job_empty(db: Session, method: str) -> ImportJob: + """Create job with no items; caller then adds items and sets total.""" + job = ImportJob( + status=JOB_STATUS_QUEUED, + method=method, + total=0, + processed=0, + success_count=0, + failed_count=0, + ) + db.add(job) + db.flush() + return job + + +def add_job_items(db: Session, job_id: int, items: list[dict]) -> None: + """Append items to job and set job.total. items: [{"filename": str, "stored_path": str}, ...].""" + job = db.query(ImportJob).filter(ImportJob.id == job_id).first() + if not job: + return + for seq, it in enumerate(items, start=1): + db.add( + ImportJobItem( + job_id=job_id, + seq=seq, + filename=it.get("filename", ""), + stored_path=it.get("stored_path", ""), + status=JOB_STATUS_QUEUED, + ) + ) + job.total = len(items) + db.commit() + + +def get_job(db: Session, job_id: int) -> ImportJob | None: + """Get job by id with items loaded.""" + return db.query(ImportJob).filter(ImportJob.id == job_id).first() + + +def list_jobs( + db: Session, + statuses: list[str] | None = None, + limit: int = 50, +) -> list[ImportJob]: + """List jobs, optionally filtered by status(es), newest first.""" + q = db.query(ImportJob).order_by(ImportJob.created_at.desc()) + if statuses: + q = q.filter(ImportJob.status.in_(statuses)) + return q.limit(limit).all() + + +def update_job(db: Session, job_id: int, **kwargs) -> ImportJob | None: + """Update job fields. Returns updated job or None if not found.""" + job = db.query(ImportJob).filter(ImportJob.id == job_id).first() + if not job: + return None + for k, v in kwargs.items(): + if hasattr(job, k): + setattr(job, k, v) + job.updated_at = datetime.utcnow() + db.commit() + db.refresh(job) + return job + + +def update_job_item(db: Session, item_id: int, **kwargs) -> ImportJobItem | None: + """Update a single job item.""" + item = db.query(ImportJobItem).filter(ImportJobItem.id == item_id).first() + if not item: + return None + for k, v in kwargs.items(): + if hasattr(item, k): + setattr(item, k, v) + db.commit() + db.refresh(item) + return item + + +def claim_oldest_queued(db: Session) -> ImportJob | None: + """ + Claim the oldest queued job by setting status to running. + Returns the job if claimed, None if no queued job. + Used by single-worker FIFO loop. (SQLite-compatible: no row locking.) + """ + job = ( + db.query(ImportJob) + .filter(ImportJob.status == JOB_STATUS_QUEUED) + .order_by(ImportJob.created_at.asc()) + .first() + ) + if not job: + return None + job.status = JOB_STATUS_RUNNING + job.started_at = datetime.utcnow() + job.updated_at = datetime.utcnow() + db.commit() + db.refresh(job) + return job + + +def get_queued_job_for_worker(db: Session) -> ImportJob | None: + """ + Get oldest queued job without locking (SQLite has limited FOR UPDATE support). + Caller must then update status to running in same or follow-up transaction. 
+ """ + job = ( + db.query(ImportJob) + .filter(ImportJob.status == JOB_STATUS_QUEUED) + .order_by(ImportJob.created_at.asc()) + .first() + ) + return job + + +def cancel_job(db: Session, job_id: int) -> ImportJob | None: + """Set job status to cancelled if it is queued or running.""" + job = db.query(ImportJob).filter(ImportJob.id == job_id).first() + if not job or job.status not in (JOB_STATUS_QUEUED, JOB_STATUS_RUNNING): + return None + job.status = "cancelled" + job.ended_at = datetime.utcnow() + job.updated_at = datetime.utcnow() + db.commit() + db.refresh(job) + return job + + +def get_failed_items(db: Session, job_id: int) -> list[ImportJobItem]: + """Return items with status failed for retry.""" + return ( + db.query(ImportJobItem) + .filter(ImportJobItem.job_id == job_id, ImportJobItem.status == JOB_STATUS_FAILED) + .order_by(ImportJobItem.seq.asc()) + .all() + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index cd720d8..811d3b6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,3 +10,4 @@ python-multipart requests openpyxl python-dotenv +pytest diff --git a/backend/routers/imports.py b/backend/routers/imports.py index 5c14736..a729be0 100644 --- a/backend/routers/imports.py +++ b/backend/routers/imports.py @@ -1,25 +1,153 @@ +from datetime import datetime from pathlib import Path - -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from threading import Lock +from typing import Literal +from uuid import uuid4 +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile from sqlalchemy.orm import Session from backend.auth import get_current_user -from backend.database import get_db +from backend.database import SessionLocal, get_db from backend.models import ImportHistory, Question -from backend.schemas import ImportHistoryOut, QuestionOut +from backend.repositories import import_job_repo +from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut from backend.services.excel_service import create_template_bytes, parse_excel_file -from backend.services.file_utils import save_upload -from backend.services.parser import DMXAPIService, extract_metadata +from backend.services.file_utils import save_upload, save_upload_for_job +from backend.services.parser import OpenAICompatibleParserService, extract_metadata router = APIRouter(dependencies=[Depends(get_current_user)]) +BATCH_TASKS: dict[str, dict] = {} +BATCH_TASKS_LOCK = Lock() +BatchMethod = Literal["excel", "ai"] + + + +def _set_task(task_id: str, **fields) -> None: + with BATCH_TASKS_LOCK: + task = BATCH_TASKS.get(task_id) + if not task: + return + task.update(fields) + + +def _build_ai_rows(path: Path) -> list[dict]: + parser = OpenAICompatibleParserService() + metadata = extract_metadata(path.name) + questions = parser.parse_file(str(path)) + rows = [] + for q in questions: + rows.append( + { + "chapter": metadata["chapter"], + "primary_knowledge": "", + "secondary_knowledge": metadata["secondary_knowledge"], + "question_type": metadata["question_type"], + "difficulty": metadata["difficulty"], + "stem": q.get("题干", ""), + "option_a": q.get("选项A", ""), + "option_b": q.get("选项B", ""), + "option_c": q.get("选项C", ""), + "option_d": q.get("选项D", ""), + "answer": q.get("正确答案", ""), + "explanation": q.get("解析", ""), + "notes": q.get("备注", ""), + "source_file": metadata["source_file"], + } + ) + return rows + + +def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None: + db = 
diff --git a/backend/requirements.txt b/backend/requirements.txt
index cd720d8..811d3b6 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -10,3 +10,4 @@ python-multipart
 requests
 openpyxl
 python-dotenv
+pytest
diff --git a/backend/routers/imports.py b/backend/routers/imports.py
index 5c14736..a729be0 100644
--- a/backend/routers/imports.py
+++ b/backend/routers/imports.py
@@ -1,25 +1,153 @@
+from datetime import datetime
 from pathlib import Path
-
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
+from threading import Lock
+from typing import Literal
+from uuid import uuid4
+
+from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
 from sqlalchemy.orm import Session
 
 from backend.auth import get_current_user
-from backend.database import get_db
+from backend.database import SessionLocal, get_db
 from backend.models import ImportHistory, Question
-from backend.schemas import ImportHistoryOut, QuestionOut
+from backend.repositories import import_job_repo
+from backend.schemas import ImportHistoryOut, ImportJobOut, QuestionOut
 from backend.services.excel_service import create_template_bytes, parse_excel_file
-from backend.services.file_utils import save_upload
-from backend.services.parser import DMXAPIService, extract_metadata
+from backend.services.file_utils import save_upload, save_upload_for_job
+from backend.services.parser import OpenAICompatibleParserService, extract_metadata
 
 router = APIRouter(dependencies=[Depends(get_current_user)])
 
+BATCH_TASKS: dict[str, dict] = {}
+BATCH_TASKS_LOCK = Lock()
+BatchMethod = Literal["excel", "ai"]
+
+
+def _set_task(task_id: str, **fields) -> None:
+    with BATCH_TASKS_LOCK:
+        task = BATCH_TASKS.get(task_id)
+        if not task:
+            return
+        task.update(fields)
+
+
+def _build_ai_rows(path: Path) -> list[dict]:
+    parser = OpenAICompatibleParserService()
+    metadata = extract_metadata(path.name)
+    questions = parser.parse_file(str(path))
+    rows = []
+    for q in questions:
+        rows.append(
+            {
+                "chapter": metadata["chapter"],
+                "primary_knowledge": "",
+                "secondary_knowledge": metadata["secondary_knowledge"],
+                "question_type": metadata["question_type"],
+                "difficulty": metadata["difficulty"],
+                "stem": q.get("题干", ""),
+                "option_a": q.get("选项A", ""),
+                "option_b": q.get("选项B", ""),
+                "option_c": q.get("选项C", ""),
+                "option_d": q.get("选项D", ""),
+                "answer": q.get("正确答案", ""),
+                "explanation": q.get("解析", ""),
+                "notes": q.get("备注", ""),
+                "source_file": metadata["source_file"],
+            }
+        )
+    return rows
+
+
+def _run_batch_import(task_id: str, files: list[dict], method: BatchMethod) -> None:
+    db = SessionLocal()
+    try:
+        for index, item in enumerate(files, start=1):
+            path = Path(item["path"])
+            filename = item["filename"]
+            _set_task(task_id, current_file=filename)
+            try:
+                if method == "excel":
+                    if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
+                        raise ValueError("仅支持 Excel 文件")
+                    rows = parse_excel_file(path)
+                else:
+                    rows = _build_ai_rows(path)
+
+                questions = [Question(**row) for row in rows]
+                if questions:
+                    db.add_all(questions)
+                db.add(
+                    ImportHistory(
+                        filename=filename,
+                        method=method,
+                        question_count=len(questions),
+                        status="success",
+                    )
+                )
+                db.commit()
+
+                with BATCH_TASKS_LOCK:
+                    task = BATCH_TASKS[task_id]
+                    task["success_count"] += 1
+                    task["total_questions"] += len(questions)
+                    task["results"].append(
+                        {
+                            "filename": filename,
+                            "status": "success",
+                            "question_count": len(questions),
+                        }
+                    )
+            except Exception as exc:
+                db.rollback()
+                db.add(
+                    ImportHistory(
+                        filename=filename,
+                        method=method,
+                        question_count=0,
+                        status="failed",
+                    )
+                )
+                db.commit()
+                with BATCH_TASKS_LOCK:
+                    task = BATCH_TASKS[task_id]
+                    task["failed_count"] += 1
+                    task["results"].append(
+                        {
+                            "filename": filename,
+                            "status": "failed",
+                            "error": str(exc),
+                            "question_count": 0,
+                        }
+                    )
+            finally:
+                _set_task(task_id, processed=index)
+
+        _set_task(
+            task_id,
+            status="completed",
+            current_file="",
+            ended_at=datetime.utcnow().isoformat(),
+        )
+    except Exception as exc:
+        _set_task(
+            task_id,
+            status="failed",
+            error=str(exc),
+            ended_at=datetime.utcnow().isoformat(),
+        )
+    finally:
+        db.close()
+
 
 @router.post("/ai/parse")
 def parse_by_ai(file: UploadFile = File(...)) -> dict:
     path = save_upload(file)
-    parser = DMXAPIService()
+    parser = OpenAICompatibleParserService()
     metadata = extract_metadata(file.filename or path.name)
-    questions = parser.parse_file(str(path))
+    try:
+        questions = parser.parse_file(str(path))
+    except ValueError as exc:
+        raise HTTPException(status_code=502, detail=str(exc)) from exc
     preview = []
     for q in questions:
         preview.append(
"progress": progress} + + +# ----- 持久化队列 Jobs API ----- + +@router.post("/jobs", response_model=ImportJobOut) +def create_import_job( + files: list[UploadFile] = File(...), + method: BatchMethod = Form("ai"), + db: Session = Depends(get_db), +) -> ImportJobOut: + if not files: + raise HTTPException(status_code=400, detail="请至少上传一个文件") + job = import_job_repo.create_job_empty(db, method) + items: list[dict] = [] + for seq, f in enumerate(files, start=1): + path = save_upload_for_job(job.id, seq, f) + items.append({"filename": f.filename or path.name, "stored_path": str(path)}) + import_job_repo.add_job_items(db, job.id, items) + job = import_job_repo.get_job(db, job.id) + return ImportJobOut.model_validate(job) + + +@router.get("/jobs/{job_id}", response_model=ImportJobOut) +def get_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut: + job = import_job_repo.get_job(db, job_id) + if not job: + raise HTTPException(status_code=404, detail="任务不存在") + return ImportJobOut.model_validate(job) + + +@router.get("/jobs", response_model=list[ImportJobOut]) +def list_import_jobs( + status: str | None = Query(None, description="queued,running 等,逗号分隔"), + limit: int = Query(50, le=100), + db: Session = Depends(get_db), +) -> list[ImportJobOut]: + statuses = [s.strip() for s in status.split(",") if s.strip()] if status else None + jobs = import_job_repo.list_jobs(db, statuses=statuses, limit=limit) + return [ImportJobOut.model_validate(j) for j in jobs] + + +@router.post("/jobs/{job_id}/retry", response_model=ImportJobOut) +def retry_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut: + job = import_job_repo.get_job(db, job_id) + if not job: + raise HTTPException(status_code=404, detail="任务不存在") + failed = import_job_repo.get_failed_items(db, job_id) + if not failed: + raise HTTPException(status_code=400, detail="没有可重试的失败项") + items = [{"filename": it.filename, "stored_path": it.stored_path} for it in failed] + new_job = import_job_repo.create_job(db, job.method, items) + new_job = import_job_repo.get_job(db, new_job.id) + return ImportJobOut.model_validate(new_job) + + +@router.post("/jobs/{job_id}/cancel", response_model=ImportJobOut) +def cancel_import_job(job_id: int, db: Session = Depends(get_db)) -> ImportJobOut: + job = import_job_repo.cancel_job(db, job_id) + if not job: + raise HTTPException(status_code=400, detail="任务不存在或无法取消") + return ImportJobOut.model_validate(job) diff --git a/backend/schemas.py b/backend/schemas.py index b4a8e0e..2942250 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -103,6 +103,45 @@ class ImportHistoryOut(BaseModel): from_attributes = True +class ImportJobItemOut(BaseModel): + id: int + job_id: int + seq: int + filename: str + stored_path: str + status: str + attempt: int + error: str + question_count: int + started_at: datetime | None + ended_at: datetime | None + + class Config: + from_attributes = True + + +class ImportJobOut(BaseModel): + id: int + status: str + method: str + total: int + processed: int + success_count: int + failed_count: int + current_index: int + current_file: str + error: str + attempt: int + created_at: datetime + started_at: datetime | None + ended_at: datetime | None + updated_at: datetime + items: list[ImportJobItemOut] = [] + + class Config: + from_attributes = True + + class PracticeStartRequest(BaseModel): chapter: Optional[str] = None secondary_knowledge: Optional[str] = None diff --git a/backend/services/file_utils.py b/backend/services/file_utils.py index 41c7488..835a217 100644 --- 
a/backend/services/file_utils.py
+++ b/backend/services/file_utils.py
@@ -18,3 +18,14 @@ def save_upload(upload_file: UploadFile) -> Path:
     with target_path.open("wb") as buffer:
         shutil.copyfileobj(upload_file.file, buffer)
     return target_path
+
+
+def save_upload_for_job(job_id: int, seq: int, upload_file: UploadFile) -> Path:
+    """Save file with unique path under upload_dir/{job_id}/ to avoid overwrites."""
+    base = ensure_upload_dir() / str(job_id)
+    base.mkdir(parents=True, exist_ok=True)
+    name = upload_file.filename or "file"
+    target_path = base / f"{seq}_{name}"
+    with target_path.open("wb") as buffer:
+        shutil.copyfileobj(upload_file.file, buffer)
+    return target_path
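A small sketch of the overwrite protection this helper provides (the base directory comes from `ensure_upload_dir()`, shown here as `uploads/` for illustration):

```python
from io import BytesIO

from fastapi import UploadFile

from backend.services.file_utils import save_upload_for_job

# Two files with the same name in one job land on distinct paths.
first = save_upload_for_job(7, 1, UploadFile(file=BytesIO(b"one"), filename="questions.xlsx"))
second = save_upload_for_job(7, 2, UploadFile(file=BytesIO(b"two"), filename="questions.xlsx"))
print(first)   # e.g. uploads/7/1_questions.xlsx
print(second)  # e.g. uploads/7/2_questions.xlsx
assert first != second  # save_upload() alone would have reused uploads/questions.xlsx
```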
diff --git a/backend/services/import_queue_service.py b/backend/services/import_queue_service.py
new file mode 100644
index 0000000..5994217
--- /dev/null
+++ b/backend/services/import_queue_service.py
@@ -0,0 +1,180 @@
+"""
+Import queue: single-consumer FIFO worker and job execution.
+Run run_worker_loop() in a background thread; on startup call reset_stale_running_jobs().
+"""
+
+from datetime import datetime
+from pathlib import Path
+import time
+
+from sqlalchemy.orm import Session
+
+from backend.database import SessionLocal
+from backend.models import (
+    JOB_STATUS_CANCELLED,
+    JOB_STATUS_FAILED,
+    JOB_STATUS_QUEUED,
+    JOB_STATUS_RUNNING,
+    JOB_STATUS_SUCCESS,
+    ImportHistory,
+    ImportJob,
+    ImportJobItem,
+    Question,
+)
+from backend.repositories import import_job_repo as repo
+from backend.services.excel_service import parse_excel_file
+from backend.services.parser import OpenAICompatibleParserService, extract_metadata
+
+
+def _build_ai_rows(path: Path) -> list[dict]:
+    parser = OpenAICompatibleParserService()
+    metadata = extract_metadata(path.name)
+    questions = parser.parse_file(str(path))
+    rows = []
+    for q in questions:
+        rows.append(
+            {
+                "chapter": metadata["chapter"],
+                "primary_knowledge": "",
+                "secondary_knowledge": metadata["secondary_knowledge"],
+                "question_type": metadata["question_type"],
+                "difficulty": metadata["difficulty"],
+                "stem": q.get("题干", ""),
+                "option_a": q.get("选项A", ""),
+                "option_b": q.get("选项B", ""),
+                "option_c": q.get("选项C", ""),
+                "option_d": q.get("选项D", ""),
+                "answer": q.get("正确答案", ""),
+                "explanation": q.get("解析", ""),
+                "notes": q.get("备注", ""),
+                "source_file": metadata["source_file"],
+            }
+        )
+    return rows
+
+
+def _process_one_item(
+    db: Session,
+    job: ImportJob,
+    item: ImportJobItem,
+    method: str,
+) -> None:
+    path = Path(item.stored_path)
+    filename = item.filename
+    job.current_file = filename
+    job.current_index = item.seq
+    job.updated_at = datetime.utcnow()
+    item.status = JOB_STATUS_RUNNING
+    item.started_at = datetime.utcnow()
+    db.commit()
+
+    try:
+        if method == "excel":
+            if path.suffix.lower() not in [".xlsx", ".xlsm", ".xltx", ".xltm"]:
+                raise ValueError("仅支持 Excel 文件")
+            rows = parse_excel_file(path)
+        else:
+            rows = _build_ai_rows(path)
+
+        questions = [Question(**row) for row in rows]
+        if questions:
+            db.add_all(questions)
+        db.add(
+            ImportHistory(
+                filename=filename,
+                method=method,
+                question_count=len(questions),
+                status="success",
+            )
+        )
+        db.commit()
+
+        item.status = JOB_STATUS_SUCCESS
+        item.question_count = len(questions)
+        item.ended_at = datetime.utcnow()
+        job.success_count += 1
+        job.processed += 1
+        job.updated_at = datetime.utcnow()
+        db.commit()
+    except Exception as exc:
+        db.rollback()
+        db.add(
+            ImportHistory(
+                filename=filename,
+                method=method,
+                question_count=0,
+                status="failed",
+            )
+        )
+        db.commit()
+        item.status = JOB_STATUS_FAILED
+        item.error = str(exc)
+        item.ended_at = datetime.utcnow()
+        job.failed_count += 1
+        job.processed += 1
+        job.updated_at = datetime.utcnow()
+        db.commit()
+
+
+def process_job(db: Session, job_id: int) -> None:
+    """Execute a single job: process all items in order, then set the job's terminal status."""
+    job = repo.get_job(db, job_id)
+    if not job or job.status != JOB_STATUS_RUNNING:
+        return
+    method = job.method
+    items = sorted(job.items, key=lambda x: x.seq)
+    # Resume: make processed/success_count/failed_count reflect already-completed items
+    job.processed = sum(1 for it in items if it.status in (JOB_STATUS_SUCCESS, JOB_STATUS_FAILED))
+    job.success_count = sum(1 for it in items if it.status == JOB_STATUS_SUCCESS)
+    job.failed_count = sum(1 for it in items if it.status == JOB_STATUS_FAILED)
+    db.commit()
+    for item in items:
+        if item.status in (JOB_STATUS_SUCCESS, JOB_STATUS_FAILED):
+            continue
+        job = repo.get_job(db, job_id)
+        if not job or job.status == JOB_STATUS_CANCELLED:
+            return  # cancelled mid-run: stop and keep the cancelled status
+        _process_one_item(db, job, item, method)
+        db.refresh(job)
+
+    job = repo.get_job(db, job_id)
+    if not job or job.status == JOB_STATUS_CANCELLED:
+        return
+    if job.failed_count > 0 and job.success_count == 0:
+        job.status = JOB_STATUS_FAILED
+        job.error = "部分或全部文件处理失败"
+    else:
+        job.status = JOB_STATUS_SUCCESS
+        job.error = ""
+    job.ended_at = datetime.utcnow()
+    job.current_file = ""
+    job.updated_at = datetime.utcnow()
+    db.commit()
+
+
+def reset_stale_running_jobs(db: Session) -> int:
+    """On startup: move any job left in 'running' back to 'queued' so the worker can pick it up."""
+    count = 0
+    for job in db.query(ImportJob).filter(ImportJob.status == JOB_STATUS_RUNNING).all():
+        job.status = JOB_STATUS_QUEUED
+        count += 1
+    if count:
+        db.commit()
+    return count
+
+
+def run_worker_loop(interval_seconds: float = 1.0) -> None:
+    """
+    Single-consumer FIFO loop. Call from a background thread.
+    Claims the oldest queued job, processes it, then repeats. Sleeps when nothing is queued.
+    """
+    while True:
+        db = SessionLocal()
+        try:
+            job = repo.claim_oldest_queued(db)
+            if job:
+                process_job(db, job.id)
+            else:
+                time.sleep(interval_seconds)
+        except Exception:
+            if db:
+                db.rollback()
+            time.sleep(interval_seconds)
+        finally:
+            db.close()
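The restart behaviour is worth tracing once. A sketch of what the startup hook plus the resume logic in `process_job` accomplish after a crash (an illustrative sequence, not a test from this patch):

```python
from backend.database import SessionLocal
from backend.services.import_queue_service import reset_stale_running_jobs

# Suppose the process died while job 7 was running: items 1-3 already reached
# success/failed in import_job_items, while items 4-5 are still queued.
db = SessionLocal()
try:
    requeued = reset_stale_running_jobs(db)  # job 7: running -> queued again
    print(f"requeued {requeued} stale job(s)")
finally:
    db.close()

# When run_worker_loop() claims job 7 again, process_job() recomputes
# processed/success_count/failed_count from the item rows and skips items
# already in a terminal state, so files 1-3 are not imported twice.
```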
+ """ + while True: + db = SessionLocal() + try: + job = repo.claim_oldest_queued(db) + if job: + process_job(db, job.id) + else: + time.sleep(interval_seconds) + except Exception: + if db: + db.rollback() + time.sleep(interval_seconds) + finally: + db.close() diff --git a/backend/services/parser.py b/backend/services/parser.py index b6a8357..5b520dd 100644 --- a/backend/services/parser.py +++ b/backend/services/parser.py @@ -1,6 +1,8 @@ import json import re +import shutil import subprocess +import tempfile from pathlib import Path import requests @@ -8,11 +10,13 @@ import requests from backend.config import settings -class DMXAPIService: + + +class OpenAICompatibleParserService: def __init__(self) -> None: self.api_key = settings.api_key self.model_name = settings.model_name - self.api_url = settings.dmxapi_url + self.api_url = settings.openai_api_url def parse_file(self, file_path: str) -> list[dict]: path = Path(file_path) @@ -51,33 +55,94 @@ class DMXAPIService: def _convert_to_pdf(self, path: Path) -> Path: pdf_path = path.with_suffix(".pdf") pdf_path.unlink(missing_ok=True) + source_path = path + temp_dir: str | None = None - cmd = [ - "pandoc", - str(path), - "-o", - str(pdf_path), - "--pdf-engine=xelatex", - "-V", - "CJKmainfont=PingFang SC", - ] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=90) - if result.returncode != 0: - fallback = [ + if path.suffix.lower() == ".doc": + source_path, temp_dir = self._convert_doc_to_docx(path) + + try: + cmd = [ "pandoc", - str(path), + str(source_path), "-o", str(pdf_path), - "--pdf-engine=weasyprint", + "--pdf-engine=xelatex", + "-V", + "CJKmainfont=PingFang SC", ] - result = subprocess.run(fallback, capture_output=True, text=True, timeout=90) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=90) if result.returncode != 0: - raise ValueError(f"文件转 PDF 失败: {result.stderr}") - return pdf_path + fallback = [ + "pandoc", + str(source_path), + "-o", + str(pdf_path), + "--pdf-engine=weasyprint", + ] + result = subprocess.run(fallback, capture_output=True, text=True, timeout=90) + if result.returncode != 0: + raise ValueError(f"文件转 PDF 失败: {result.stderr}") + return pdf_path + finally: + if temp_dir: + shutil.rmtree(temp_dir, ignore_errors=True) + + def _convert_doc_to_docx(self, path: Path) -> tuple[Path, str]: + temp_dir = tempfile.mkdtemp(prefix="pb_doc_convert_") + converted_path = Path(temp_dir) / f"{path.stem}.docx" + convert_errors: list[str] = [] + + if shutil.which("soffice"): + result = subprocess.run( + [ + "soffice", + "--headless", + "--convert-to", + "docx", + "--outdir", + temp_dir, + str(path), + ], + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode == 0 and converted_path.exists(): + return converted_path, temp_dir + convert_errors.append(f"soffice: {(result.stderr or result.stdout).strip()}") + else: + convert_errors.append("soffice: 未安装") + + if shutil.which("textutil"): + result = subprocess.run( + [ + "textutil", + "-convert", + "docx", + "-output", + str(converted_path), + str(path), + ], + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode == 0 and converted_path.exists(): + return converted_path, temp_dir + convert_errors.append(f"textutil: {(result.stderr or result.stdout).strip()}") + else: + convert_errors.append("textutil: 未安装") + + raise ValueError( + "检测到 .doc 文件,pandoc 不支持直接转换。" + "已尝试自动转换为 .docx 但失败。请先把 .doc 另存为 .docx 后重试。" + f" 详细信息: {' | '.join(convert_errors)}" + ) def _parse_with_file_url(self, file_url: str, 
diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py
new file mode 100644
index 0000000..a8a65ec
--- /dev/null
+++ b/backend/tests/__init__.py
@@ -0,0 +1 @@
+# Backend tests
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
new file mode 100644
index 0000000..c36aef3
--- /dev/null
+++ b/backend/tests/conftest.py
@@ -0,0 +1,17 @@
+import pytest
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from backend.database import Base
+
+
+@pytest.fixture
+def db_session():
+    engine = create_engine("sqlite:///:memory:")
+    Base.metadata.create_all(engine)
+    Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    session = Session()
+    try:
+        yield session
+    finally:
+        session.close()
diff --git a/backend/tests/test_import_job_repo.py b/backend/tests/test_import_job_repo.py
new file mode 100644
index 0000000..dd09489
--- /dev/null
+++ b/backend/tests/test_import_job_repo.py
@@ -0,0 +1,82 @@
+"""Tests for the import job repository: create, get, list, FIFO claim, cancel, failed items."""
+
+import pytest
+
+from backend.models import JOB_STATUS_QUEUED, JOB_STATUS_RUNNING
+from backend.repositories import import_job_repo as repo
+
+
+def test_create_job(db_session):
+    items = [
+        {"filename": "a.xlsx", "stored_path": "/up/1/a.xlsx"},
+        {"filename": "b.xlsx", "stored_path": "/up/1/b.xlsx"},
+    ]
+    job = repo.create_job(db_session, "excel", items)
+    assert job.id is not None
+    assert job.status == JOB_STATUS_QUEUED
+    assert job.method == "excel"
+    assert job.total == 2
+    assert job.processed == 0
+    assert len(job.items) == 2
+    assert job.items[0].filename == "a.xlsx" and job.items[0].seq == 1
+    assert job.items[1].filename == "b.xlsx" and job.items[1].seq == 2
+
+
+def test_get_job(db_session):
+    job = repo.create_job(db_session, "ai", [{"filename": "f.pdf", "stored_path": "/f.pdf"}])
+    loaded = repo.get_job(db_session, job.id)
+    assert loaded is not None
+    assert loaded.id == job.id
+    assert len(loaded.items) == 1
+    assert repo.get_job(db_session, 99999) is None
+
+
+def test_list_jobs(db_session):
+    repo.create_job(db_session, "excel", [{"filename": "a", "stored_path": "/a"}])
+    repo.create_job(db_session, "excel", [{"filename": "b", "stored_path": "/b"}])
+    all_jobs = repo.list_jobs(db_session, limit=10)
+    assert len(all_jobs) >= 2
+    statuses = repo.list_jobs(db_session, statuses=[JOB_STATUS_QUEUED], limit=10)
+    assert all(j.status == JOB_STATUS_QUEUED for j in statuses)
+
+
+def test_claim_oldest_queued_fifo(db_session):
+    j1 = repo.create_job(db_session, "excel", [{"filename": "a", "stored_path": "/a"}])
+    j2 = repo.create_job(db_session, "excel", [{"filename": "b", "stored_path": "/b"}])
+    claimed = repo.claim_oldest_queued(db_session)
+    assert claimed is not None
+    assert claimed.id == j1.id
+    assert claimed.status == JOB_STATUS_RUNNING
+    second = repo.claim_oldest_queued(db_session)
+    assert second is not None
+    assert second.id == j2.id
+    assert repo.claim_oldest_queued(db_session) is None
+
+
+def test_cancel_job(db_session):
+    job = repo.create_job(db_session, "excel", [{"filename": "a", "stored_path": "/a"}])
+    cancelled = repo.cancel_job(db_session, job.id)
+    assert cancelled is not None
+    assert cancelled.status == "cancelled"
+    loaded = repo.get_job(db_session, job.id)
+    assert loaded.status == "cancelled"
+    assert repo.cancel_job(db_session, job.id) is None  # already cancelled
+
+
+def test_get_failed_items(db_session):
+    job = repo.create_job(
+        db_session,
+        "excel",
+        [
+            {"filename": "a", "stored_path": "/a"},
+            {"filename": "b", "stored_path": "/b"},
+        ],
+    )
+    failed = repo.get_failed_items(db_session, job.id)
+    assert len(failed) == 0
+    job.items[0].status = "failed"
+    job.items[0].error = "err"
+    db_session.commit()
+    failed = repo.get_failed_items(db_session, job.id)
+    assert len(failed) == 1
+    assert failed[0].filename == "a"
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index c2989ea..fa632fb 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,5 +1,5 @@
 import { Layout, Menu, message } from "antd";
-import { useMemo, useState } from "react";
+import { useEffect, useMemo, useState } from "react";
 import { Navigate, Route, Routes, useLocation, useNavigate } from "react-router-dom";
 import LoginModal from "./components/LoginModal";
 import Categories from "./pages/Categories";
@@ -26,6 +26,14 @@ export default function App() {
   const location = useLocation();
   const [loggedIn, setLoggedIn] = useState(Boolean(localStorage.getItem("pb_token")));
 
+  useEffect(() => {
+    const handleUnauthorized = () => {
+      setLoggedIn(false);
+    };
+    window.addEventListener("pb-auth-unauthorized", handleUnauthorized);
+    return () => window.removeEventListener("pb-auth-unauthorized", handleUnauthorized);
+  }, []);
+
   const selectedKey = useMemo(() => {
     const hit = menuItems.find((m) => location.pathname.startsWith(m.key));
     return hit ? [hit.key] : ["/dashboard"];
[hit.key] : ["/dashboard"]; diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index c2b220a..f5c87ad 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -12,4 +12,80 @@ api.interceptors.request.use((config) => { return config; }); +api.interceptors.response.use( + (response) => response, + (error) => { + if (error?.response?.status === 401) { + localStorage.removeItem("pb_token"); + window.dispatchEvent(new CustomEvent("pb-auth-unauthorized")); + } + return Promise.reject(error); + } +); + +// 导入任务(持久化队列)类型与 API +export interface ImportJobItemOut { + id: number; + job_id: number; + seq: number; + filename: string; + stored_path: string; + status: string; + attempt: number; + error: string; + question_count: number; + started_at: string | null; + ended_at: string | null; +} + +export interface ImportJobOut { + id: number; + status: string; + method: string; + total: number; + processed: number; + success_count: number; + failed_count: number; + current_index: number; + current_file: string; + error: string; + attempt: number; + created_at: string; + started_at: string | null; + ended_at: string | null; + updated_at: string; + items: ImportJobItemOut[]; +} + +export async function createImportJob(files: File[], method: "excel" | "ai"): Promise { + const formData = new FormData(); + formData.append("method", method); + files.forEach((file) => formData.append("files", file)); + const { data } = await api.post("/import/jobs", formData, { + headers: { "Content-Type": "multipart/form-data" } + }); + return data; +} + +export async function getImportJob(jobId: number): Promise { + const { data } = await api.get(`/import/jobs/${jobId}`); + return data; +} + +export async function listImportJobs(status?: string): Promise { + const params = status ? 
diff --git a/frontend/src/pages/Import.tsx b/frontend/src/pages/Import.tsx
index 4a54dfc..d929385 100644
--- a/frontend/src/pages/Import.tsx
+++ b/frontend/src/pages/Import.tsx
@@ -1,92 +1,334 @@
-import { Button, Card, Table, Tabs, Upload, message } from "antd";
-import type { UploadProps } from "antd";
-import { useState } from "react";
-import api from "../api";
+import { Button, Card, Progress, Select, Space, Table, Tag, Upload, message } from "antd";
+import type { UploadFile, UploadProps } from "antd";
+import { useCallback, useEffect, useRef, useState } from "react";
+import {
+  createImportJob,
+  getImportJob,
+  listImportJobs,
+  retryImportJob,
+  cancelImportJob,
+  type ImportJobOut
+} from "../api";
+
+type ImportMethod = "excel" | "ai";
+
+const POLL_INTERVAL_MS = 1500;
+const TERMINAL_STATUSES = ["success", "failed", "cancelled"];
+const IMPORT_METHOD_KEY = "pb_import_method";
+
+function isTerminal(status: string) {
+  return TERMINAL_STATUSES.includes(status);
+}
+
+function loadStoredMethod(): ImportMethod {
+  try {
+    const v = sessionStorage.getItem(IMPORT_METHOD_KEY);
+    if (v === "excel" || v === "ai") return v;
+  } catch {
+    /* ignore */
+  }
+  return "ai";
+}
 
 export default function ImportPage() {
-  const [preview, setPreview] = useState<Record<string, unknown>[]>([]);
+  const [method, setMethod] = useState<ImportMethod>(loadStoredMethod);
+  const [jobs, setJobs] = useState<ImportJobOut[]>([]);
+  const [pendingFiles, setPendingFiles] = useState<UploadFile[]>([]);
   const [loading, setLoading] = useState(false);
+  const pollTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
 
-  const aiProps: UploadProps = {
-    name: "file",
-    customRequest: async ({ file, onSuccess, onError }) => {
+  // Persist the method so remounts (e.g. Strict Mode, route switch) restore the selection
+  const setMethodAndStore = useCallback((value: ImportMethod) => {
+    setMethod(value);
+    try {
+      sessionStorage.setItem(IMPORT_METHOD_KEY, value);
+    } catch {
+      /* ignore */
+    }
+  }, []);
+
+  const activeJobs = jobs.filter((j) => !isTerminal(j.status));
+  const hasActive = activeJobs.length > 0;
+  const pendingCount = pendingFiles.length;
+
+  const fetchJob = useCallback(async (jobId: number) => {
+    try {
+      const job = await getImportJob(jobId);
+      setJobs((prev) => {
+        const next = prev.map((j) => (j.id === jobId ? job : j));
+        if (!next.find((j) => j.id === jobId)) next.unshift(job);
+        return next;
+      });
+      return job;
+    } catch {
+      return null;
+    }
+  }, []);
+
+  const startPolling = useCallback(() => {
+    if (pollTimerRef.current) return;
+    pollTimerRef.current = setInterval(() => {
+      setJobs((prev) => {
+        const toPoll = prev.filter((j) => !isTerminal(j.status));
+        toPoll.forEach((j) => {
+          getImportJob(j.id).then((job) => {
+            setJobs((p) => p.map((x) => (x.id === job.id ? job : x)));
+          }).catch(() => {});
+        });
+        return prev;
+      });
+    }, POLL_INTERVAL_MS);
+  }, []);
+
+  const stopPolling = useCallback(() => {
+    if (pollTimerRef.current) {
+      clearInterval(pollTimerRef.current);
+      pollTimerRef.current = null;
+    }
+  }, []);
+
+  useEffect(() => {
+    if (!hasActive) {
+      stopPolling();
+      return;
+    }
+    startPolling();
+    return () => { stopPolling(); };
+  }, [hasActive, startPolling, stopPolling]);
+
+  useEffect(() => {
+    let cancelled = false;
+    (async () => {
       try {
-        setLoading(true);
-        const formData = new FormData();
-        formData.append("file", file as File);
-        const { data } = await api.post("/import/ai/parse", formData);
-        setPreview(data.preview || []);
-        message.success("AI 解析完成");
-        onSuccess?.({});
-      } catch (err) {
-        onError?.(err as Error);
-      } finally {
-        setLoading(false);
+        const list = await listImportJobs("queued,running");
+        if (!cancelled) {
+          setJobs((prev) => {
+            const byId = new Map(prev.map((j) => [j.id, j]));
+            list.forEach((j) => byId.set(j.id, j));
+            return Array.from(byId.values()).sort((a, b) => b.id - a.id);
+          });
+        }
+      } catch {
+        // ignore
       }
+    })();
+    return () => { cancelled = true; };
+  }, []);
+
+  const uploadProps: UploadProps = {
+    multiple: true,
+    beforeUpload: () => false,
+    fileList: pendingFiles,
+    onChange: (info) => {
+      setPendingFiles(info.fileList.slice(-100));
+    },
+    disabled: loading
+  };
+
+  const startImport = async () => {
+    if (!pendingFiles.length) {
+      message.info("请先添加文档");
+      return;
+    }
+    const files = pendingFiles.map((f) => f.originFileObj).filter(Boolean) as File[];
+    if (!files.length) {
+      message.info("没有可上传的文件");
+      return;
+    }
+    setLoading(true);
+    try {
+      const job = await createImportJob(files, method);
+      setJobs((prev) => [job, ...prev]);
+      setPendingFiles([]);
+      message.success("任务已入队,将按顺序处理");
+      startPolling();
+    } catch (err: unknown) {
+      const detail = err && typeof err === "object" && "response" in err
+        ? (err as { response?: { data?: { detail?: string } } }).response?.data?.detail
+        : "创建任务失败";
+      message.error(String(detail));
+    } finally {
+      setLoading(false);
     }
   };
 
-  const excelProps: UploadProps = {
-    name: "file",
-    customRequest: async ({ file, onSuccess, onError }) => {
-      try {
-        const formData = new FormData();
-        formData.append("file", file as File);
-        const { data } = await api.post("/import/excel", formData);
-        message.success(`Excel 导入成功,共 ${data.length} 题`);
-        onSuccess?.({});
-      } catch (err) {
-        onError?.(err as Error);
-      }
-    }
-  };
-
+  const handleRetry = async (jobId: number) => {
+    if (loading) return;
+    setLoading(true);
+    try {
+      const newJob = await retryImportJob(jobId);
+      setJobs((prev) => [newJob, ...prev]);
+      message.success("已创建重试任务");
+      startPolling();
+    } catch (err: unknown) {
+      const detail = err && typeof err === "object" && "response" in err
+        ? (err as { response?: { data?: { detail?: string } } }).response?.data?.detail
+        : "重试失败";
+      message.error(String(detail));
+    } finally {
+      setLoading(false);
+    }
+  };
 
-  const confirmSave = async () => {
-    await api.post("/import/ai/confirm", preview);
-    message.success(`已保存 ${preview.length} 道题`);
-    setPreview([]);
-  };
+  const handleCancel = async (jobId: number) => {
+    if (loading) return;
+    setLoading(true);
+    try {
+      const job = await cancelImportJob(jobId);
+      setJobs((prev) => prev.map((j) => (j.id === jobId ? job : j)));
+      message.success("已取消任务");
+    } catch {
+      message.error("取消失败");
+    } finally {
+      setLoading(false);
+    }
+  };
+
+  const clearFinished = () => {
+    setJobs((prev) => prev.filter((j) => !isTerminal(j.status)));
+  };
+
+  const progressPercent = (job: ImportJobOut) =>
+    job.total ? Number(((job.processed / job.total) * 100).toFixed(2)) : 0;
+
+  const statusTag = (status: string) => {
+    if (status === "queued") return <Tag>排队中</Tag>;
+    if (status === "running") return <Tag>处理中</Tag>;
+    if (status === "success") return <Tag>成功</Tag>;
+    if (status === "failed") return <Tag>失败</Tag>;
+    if (status === "cancelled") return <Tag>已取消</Tag>;
+    if (status === "retrying") return <Tag>重试中</Tag>;
+    return <Tag>{status}</Tag>;
+  };
 
   return (
-
-
-
-
+
+
+
+
+          导入方式:
+          <Select