fix: mock
This commit is contained in:
@@ -65,7 +65,13 @@ async def upload_images(
|
||||
|
||||
# trigger OCR tasks in-process background (non-blocking for API response)
|
||||
from app.workers.ocr_tasks import process_images_ocr_batch_async
|
||||
pending_ids = [str(img.id) for img in results if img.ocr_status.value == "pending"]
|
||||
pending_imgs = [img for img in results if img.ocr_status.value == "pending"]
|
||||
for img in pending_imgs:
|
||||
img.ocr_status = OcrStatus.processing
|
||||
if pending_imgs:
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
pending_ids = [str(img.id) for img in pending_imgs]
|
||||
if pending_ids:
|
||||
asyncio.create_task(
|
||||
process_images_ocr_batch_async(
|
||||
@@ -171,20 +177,25 @@ async def start_case_ocr(
|
||||
image_ids = payload.image_ids if payload else []
|
||||
if image_ids:
|
||||
images = await repo.list_by_ids_in_case(case_id, image_ids)
|
||||
# Never submit images that are already processing: this prevents
|
||||
# duplicate OCR tasks when users trigger OCR from multiple pages.
|
||||
images = [img for img in images if img.ocr_status != OcrStatus.processing]
|
||||
# For explicit re-run, mark selected images as processing immediately
|
||||
# so frontend can reflect state transition without full page refresh.
|
||||
for img in images:
|
||||
img.ocr_status = OcrStatus.processing
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
if images:
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
else:
|
||||
images = await repo.list_for_ocr(case_id, include_done=include_done)
|
||||
# Mark queued images as processing immediately, including when OCR is
|
||||
# triggered from workspace page, so UI can show progress right away.
|
||||
for img in images:
|
||||
img.ocr_status = OcrStatus.processing
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
if images:
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
|
||||
from app.workers.ocr_tasks import process_images_ocr_batch_async
|
||||
|
||||
|
||||
@@ -2,9 +2,12 @@ from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
ENV_FILE_PATH = Path(__file__).resolve().parents[2] / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file=ENV_FILE_PATH,
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
@@ -18,6 +21,7 @@ class Settings(BaseSettings):
|
||||
OCR_API_KEY: str = ""
|
||||
OCR_API_URL: str = ""
|
||||
OCR_MODEL: str = ""
|
||||
OCR_ALLOW_MOCK_FALLBACK: bool = False
|
||||
OCR_PARALLELISM: int = 4
|
||||
LLM_API_KEY: str = ""
|
||||
LLM_API_URL: str = ""
|
||||
|
||||
@@ -51,8 +51,12 @@ class ImageRepository(BaseRepository[EvidenceImage]):
|
||||
|
||||
async def list_for_ocr(self, case_id: UUID, include_done: bool = False) -> list[EvidenceImage]:
|
||||
query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
|
||||
if not include_done:
|
||||
query = query.where(EvidenceImage.ocr_status != OcrStatus.done)
|
||||
# Always exclude currently-processing images to avoid duplicate OCR
|
||||
# submission from different trigger paths (upload/workspace/screenshots).
|
||||
if include_done:
|
||||
query = query.where(EvidenceImage.ocr_status != OcrStatus.processing)
|
||||
else:
|
||||
query = query.where(EvidenceImage.ocr_status.in_([OcrStatus.pending, OcrStatus.failed]))
|
||||
result = await self.session.execute(query.order_by(EvidenceImage.uploaded_at.desc()))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@@ -43,6 +43,17 @@ def _ocr_available() -> bool:
|
||||
return bool(url and key and model)
|
||||
|
||||
|
||||
def _missing_ocr_fields() -> list[str]:
|
||||
missing: list[str] = []
|
||||
if not (settings.OCR_API_URL or settings.LLM_API_URL):
|
||||
missing.append("OCR_API_URL(or LLM_API_URL)")
|
||||
if not (settings.OCR_API_KEY or settings.LLM_API_KEY):
|
||||
missing.append("OCR_API_KEY(or LLM_API_KEY)")
|
||||
if not (settings.OCR_MODEL or settings.LLM_MODEL):
|
||||
missing.append("OCR_MODEL(or LLM_MODEL)")
|
||||
return missing
|
||||
|
||||
|
||||
def _llm_available() -> bool:
|
||||
url, key, model = _llm_config()
|
||||
return bool(url and key and model)
|
||||
@@ -54,6 +65,10 @@ async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
|
||||
"""Identify the source app and page type of a screenshot."""
|
||||
if _ocr_available():
|
||||
return await _classify_via_api(image_path)
|
||||
if not settings.OCR_ALLOW_MOCK_FALLBACK:
|
||||
missing = ", ".join(_missing_ocr_fields()) or "unknown"
|
||||
raise RuntimeError(f"OCR configuration missing: {missing}")
|
||||
logger.warning("OCR unavailable, falling back to mock classification for image: %s", image_path)
|
||||
return _classify_mock(image_path)
|
||||
|
||||
|
||||
@@ -63,6 +78,10 @@ async def extract_transaction_fields(
|
||||
"""Extract structured transaction fields from a screenshot."""
|
||||
if _ocr_available():
|
||||
return await _extract_via_api(image_path, source_app, page_type)
|
||||
if not settings.OCR_ALLOW_MOCK_FALLBACK:
|
||||
missing = ", ".join(_missing_ocr_fields()) or "unknown"
|
||||
raise RuntimeError(f"OCR configuration missing: {missing}")
|
||||
logger.warning("OCR unavailable, falling back to mock extraction for image: %s", image_path)
|
||||
mock_data = _extract_mock(image_path, source_app, page_type)
|
||||
return mock_data, json.dumps(mock_data, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ async def process_images_ocr_batch_async(image_ids: list[str], max_concurrency:
|
||||
"""Process many images with bounded OCR concurrency."""
|
||||
if not image_ids:
|
||||
return
|
||||
# De-duplicate in-memory to prevent repeated processing of same image id
|
||||
# in a single batch submission.
|
||||
image_ids = list(dict.fromkeys(image_ids))
|
||||
concurrency = max(1, max_concurrency)
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user