fix: gate mock OCR fallback behind OCR_ALLOW_MOCK_FALLBACK and prevent duplicate OCR submission of processing images

This commit is contained in:
2026-03-13 23:29:55 +08:00
parent b7e973e2b6
commit c72fbc9a14
7 changed files with 165 additions and 21 deletions

View File

@@ -65,7 +65,13 @@ async def upload_images(
# trigger OCR tasks in-process background (non-blocking for API response)
from app.workers.ocr_tasks import process_images_ocr_batch_async
pending_ids = [str(img.id) for img in results if img.ocr_status.value == "pending"]
pending_imgs = [img for img in results if img.ocr_status.value == "pending"]
for img in pending_imgs:
img.ocr_status = OcrStatus.processing
if pending_imgs:
await db.flush()
await db.commit()
pending_ids = [str(img.id) for img in pending_imgs]
if pending_ids:
asyncio.create_task(
process_images_ocr_batch_async(
@@ -171,20 +177,25 @@ async def start_case_ocr(
image_ids = payload.image_ids if payload else []
if image_ids:
images = await repo.list_by_ids_in_case(case_id, image_ids)
# Never submit images that are already processing: this prevents
# duplicate OCR tasks when users trigger OCR from multiple pages.
images = [img for img in images if img.ocr_status != OcrStatus.processing]
# For explicit re-run, mark selected images as processing immediately
# so frontend can reflect state transition without full page refresh.
for img in images:
img.ocr_status = OcrStatus.processing
await db.flush()
await db.commit()
if images:
await db.flush()
await db.commit()
else:
images = await repo.list_for_ocr(case_id, include_done=include_done)
# Mark queued images as processing immediately, including when OCR is
# triggered from workspace page, so UI can show progress right away.
for img in images:
img.ocr_status = OcrStatus.processing
await db.flush()
await db.commit()
if images:
await db.flush()
await db.commit()
from app.workers.ocr_tasks import process_images_ocr_batch_async

View File

@@ -2,9 +2,12 @@ from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
ENV_FILE_PATH = Path(__file__).resolve().parents[2] / ".env"
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file=ENV_FILE_PATH,
env_file_encoding="utf-8",
extra="ignore",
)
@@ -18,6 +21,7 @@ class Settings(BaseSettings):
OCR_API_KEY: str = ""
OCR_API_URL: str = ""
OCR_MODEL: str = ""
OCR_ALLOW_MOCK_FALLBACK: bool = False
OCR_PARALLELISM: int = 4
LLM_API_KEY: str = ""
LLM_API_URL: str = ""

View File

@@ -51,8 +51,12 @@ class ImageRepository(BaseRepository[EvidenceImage]):
async def list_for_ocr(self, case_id: UUID, include_done: bool = False) -> list[EvidenceImage]:
    """List a case's evidence images that are eligible for OCR submission.

    Args:
        case_id: Case whose images are queried.
        include_done: When True, images already OCR'd (status ``done``) are
            also returned (re-run scenario); when False, only ``pending`` and
            ``failed`` images are returned.

    Returns:
        Matching images ordered by upload time, newest first.
    """
    query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
    # Always exclude currently-processing images to avoid duplicate OCR
    # submission from different trigger paths (upload/workspace/screenshots).
    if include_done:
        query = query.where(EvidenceImage.ocr_status != OcrStatus.processing)
    else:
        # Explicit allow-list: pending and failed are the only states that
        # should be (re)queued when done images are not requested.
        query = query.where(EvidenceImage.ocr_status.in_([OcrStatus.pending, OcrStatus.failed]))
    result = await self.session.execute(query.order_by(EvidenceImage.uploaded_at.desc()))
    return list(result.scalars().all())

View File

@@ -43,6 +43,17 @@ def _ocr_available() -> bool:
return bool(url and key and model)
def _missing_ocr_fields() -> list[str]:
    """Name the OCR settings that are unset (with no LLM fallback value)."""
    # Each OCR setting may fall back to its LLM counterpart; a field is only
    # reported missing when both are empty.
    requirements = (
        (settings.OCR_API_URL or settings.LLM_API_URL, "OCR_API_URL(or LLM_API_URL)"),
        (settings.OCR_API_KEY or settings.LLM_API_KEY, "OCR_API_KEY(or LLM_API_KEY)"),
        (settings.OCR_MODEL or settings.LLM_MODEL, "OCR_MODEL(or LLM_MODEL)"),
    )
    return [label for value, label in requirements if not value]
def _llm_available() -> bool:
    """Return True when the LLM url, key, and model are all configured."""
    config = _llm_config()
    return all(bool(part) for part in config)
@@ -54,6 +65,10 @@ async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
"""Identify the source app and page type of a screenshot."""
if _ocr_available():
return await _classify_via_api(image_path)
if not settings.OCR_ALLOW_MOCK_FALLBACK:
missing = ", ".join(_missing_ocr_fields()) or "unknown"
raise RuntimeError(f"OCR configuration missing: {missing}")
logger.warning("OCR unavailable, falling back to mock classification for image: %s", image_path)
return _classify_mock(image_path)
@@ -63,6 +78,10 @@ async def extract_transaction_fields(
"""Extract structured transaction fields from a screenshot."""
if _ocr_available():
return await _extract_via_api(image_path, source_app, page_type)
if not settings.OCR_ALLOW_MOCK_FALLBACK:
missing = ", ".join(_missing_ocr_fields()) or "unknown"
raise RuntimeError(f"OCR configuration missing: {missing}")
logger.warning("OCR unavailable, falling back to mock extraction for image: %s", image_path)
mock_data = _extract_mock(image_path, source_app, page_type)
return mock_data, json.dumps(mock_data, ensure_ascii=False)

View File

@@ -28,6 +28,9 @@ async def process_images_ocr_batch_async(image_ids: list[str], max_concurrency:
"""Process many images with bounded OCR concurrency."""
if not image_ids:
return
# De-duplicate in-memory to prevent repeated processing of same image id
# in a single batch submission.
image_ids = list(dict.fromkeys(image_ids))
concurrency = max(1, max_concurrency)
semaphore = asyncio.Semaphore(concurrency)