75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy import func
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.auth import get_current_user
|
|
from backend.database import get_db
|
|
from backend.models import Category, Question
|
|
from backend.schemas import CategoryCreate, CategoryOut
|
|
|
|
# Router for the category endpoints. Every route registered on it declares
# get_current_user as a dependency (presumably enforcing authentication --
# confirm against backend.auth).
router = APIRouter(dependencies=[Depends(get_current_user)])
|
|
|
|
|
|
def _build_tree(rows: list[Category], counts: dict[str, int]) -> list[dict]:
    """Turn flat category rows into a nested list of plain-dict nodes.

    Args:
        rows: Category records exposing ``id``, ``name``, ``level`` and
            ``parent_id``.
        counts: Mapping looked up by category *name*; names that are absent
            get a count of 0.

    Returns:
        The root nodes. A node is a root when its ``parent_id`` is falsy or
        does not match the id of any row in ``rows``; every other node is
        appended to its parent's ``"children"`` list.
    """
    by_id = {}
    for category in rows:
        by_id[category.id] = {
            "id": category.id,
            "name": category.name,
            "level": category.level,
            "parent_id": category.parent_id,
            "count": counts.get(category.name, 0),
            "children": [],
        }

    roots = []
    for entry in by_id.values():
        parent_key = entry["parent_id"]
        # A falsy parent id is never looked up; it always marks a root.
        parent = by_id.get(parent_key) if parent_key else None
        if parent is None:
            roots.append(entry)
        else:
            parent["children"].append(entry)
    return roots
|
|
|
|
|
|
@router.get("")
def list_categories(db: Session = Depends(get_db)) -> dict:
    """Return every category as a nested tree under the ``"items"`` key.

    Categories are loaded ordered by ``level`` then ``id``. Question rows are
    counted per ``Question.chapter`` value, and those totals are handed to
    ``_build_tree`` as its ``counts`` mapping.
    """
    category_query = db.query(Category).order_by(Category.level.asc(), Category.id.asc())
    categories = category_query.all()

    chapter_totals = (
        db.query(Question.chapter, func.count(Question.id))
        .group_by(Question.chapter)
        .all()
    )
    per_chapter = {chapter: total for chapter, total in chapter_totals}

    return {"items": _build_tree(categories, per_chapter)}
|
|
|
|
|
|
@router.post("", response_model=CategoryOut)
def create_category(payload: CategoryCreate, db: Session = Depends(get_db)) -> CategoryOut:
    """Insert a new category built from ``payload`` and return it.

    Raises:
        HTTPException: 404 when ``payload.parent_id`` is truthy but no
            category with that id exists.
    """
    parent_id = payload.parent_id
    # A falsy parent_id skips the parent lookup entirely.
    if parent_id and not db.get(Category, parent_id):
        raise HTTPException(status_code=404, detail="父分类不存在")

    fields = payload.model_dump()
    created = Category(**fields)
    db.add(created)
    db.commit()
    # Reload the row after the commit before it is validated for the response.
    db.refresh(created)
    return CategoryOut.model_validate(created)
|
|
|
|
|
|
@router.put("/{category_id}", response_model=CategoryOut)
def update_category(category_id: int, payload: CategoryCreate, db: Session = Depends(get_db)) -> CategoryOut:
    """Overwrite an existing category with the fields carried by ``payload``.

    Args:
        category_id: Primary key of the category to update.
        payload: New field values; every key of ``payload.model_dump()`` is
            assigned onto the stored row.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        The refreshed category validated as ``CategoryOut``.

    Raises:
        HTTPException: 404 if the category, or a truthy ``payload.parent_id``,
            does not resolve to an existing row; 400 if the requested parent
            is the category itself or one of its descendants.
    """
    item = db.get(Category, category_id)
    if not item:
        raise HTTPException(status_code=404, detail="分类不存在")

    if payload.parent_id:
        # Same parent-existence check that create_category performs.
        parent = db.get(Category, payload.parent_id)
        if not parent:
            raise HTTPException(status_code=404, detail="父分类不存在")

        # Walk up from the requested parent. If we reach the category being
        # updated, the new link would close a cycle, and the nodes on that
        # cycle would be attached only to each other in _build_tree and so
        # disappear from the listing. `seen` stops the walk if the stored
        # data already contains a cycle.
        seen = set()
        ancestor = parent
        while ancestor is not None and ancestor.id not in seen:
            if ancestor.id == category_id:
                raise HTTPException(
                    status_code=400,
                    detail="不能将分类自身或其子分类设为父分类",
                )
            seen.add(ancestor.id)
            ancestor = db.get(Category, ancestor.parent_id) if ancestor.parent_id else None

    for key, value in payload.model_dump().items():
        setattr(item, key, value)
    db.commit()
    db.refresh(item)
    return CategoryOut.model_validate(item)
|
|
|
|
|
|
@router.delete("/{category_id}")
def delete_category(category_id: int, db: Session = Depends(get_db)) -> dict:
    """Delete the category with the given id.

    Returns:
        ``{"ok": True}`` once the deletion has been committed.

    Raises:
        HTTPException: 404 when no category with ``category_id`` exists.
    """
    target = db.get(Category, category_id)
    if target:
        db.delete(target)
        db.commit()
        return {"ok": True}
    raise HTTPException(status_code=404, detail="分类不存在")
|