diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 3181e12..0bbd9a1 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,8 +1,16 @@ from logging.config import fileConfig +import sys +from pathlib import Path from sqlalchemy import engine_from_config, pool from alembic import context +# Ensure `backend/` is on sys.path so `import app...` works +# no matter where `alembic` is executed from. +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + from app.core.config import settings from app.core.database import Base import app.models # noqa: F401 – ensure all models are imported diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako index 590f5b3..dafb58b 100644 --- a/backend/alembic/script.py.mako +++ b/backend/alembic/script.py.mako @@ -8,6 +8,7 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa +from sqlalchemy import Text ${imports if imports else ""} revision: str = ${repr(up_revision)} diff --git a/backend/alembic/versions/be562b8079e3_init_sqlite_test.py b/backend/alembic/versions/be562b8079e3_init_sqlite_test.py new file mode 100644 index 0000000..cc4783b --- /dev/null +++ b/backend/alembic/versions/be562b8079e3_init_sqlite_test.py @@ -0,0 +1,169 @@ +"""init_sqlite_test + +Revision ID: be562b8079e3 +Revises: +Create Date: 2026-03-11 17:33:30.695730 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import Text +from sqlalchemy.dialects import postgresql + +revision: str = 'be562b8079e3' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('cases', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_no', sa.String(length=64), nullable=False), + sa.Column('title', sa.String(length=256), nullable=False), + sa.Column('victim_name', sa.String(length=128), nullable=False), + sa.Column('handler', sa.String(length=128), nullable=False), + sa.Column('status', sa.Enum('pending', 'uploading', 'analyzing', 'reviewing', 'completed', name='casestatus'), nullable=False), + sa.Column('image_count', sa.Integer(), nullable=False), + sa.Column('total_amount', sa.Numeric(precision=14, scale=2), nullable=False), + sa.Column('created_by', sa.String(length=128), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_cases_case_no'), 'cases', ['case_no'], unique=True) + op.create_table('evidence_images', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('file_path', sa.String(length=512), nullable=False), + sa.Column('thumb_path', sa.String(length=512), nullable=False), + sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False), + sa.Column('page_type', sa.Enum('bill_list', 'bill_detail', 'transfer_receipt', 'sms_notice', 'balance', 'unknown', name='pagetype'), nullable=False), + sa.Column('ocr_status', sa.Enum('pending', 'processing', 'done', 'failed', name='ocrstatus'), nullable=False), + sa.Column('file_hash', sa.String(length=128), nullable=False), + sa.Column('file_size', sa.Integer(), nullable=False), + sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + 
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_evidence_images_case_id'), 'evidence_images', ['case_id'], unique=False) + op.create_index(op.f('ix_evidence_images_file_hash'), 'evidence_images', ['file_hash'], unique=True) + op.create_table('export_reports', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('report_type', sa.Enum('pdf', 'excel', 'word', name='reporttype'), nullable=False), + sa.Column('file_path', sa.String(length=512), nullable=False), + sa.Column('version', sa.Integer(), nullable=False), + sa.Column('content_snapshot', postgresql.JSONB(astext_type=Text()), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_export_reports_case_id'), 'export_reports', ['case_id'], unique=False) + op.create_table('fund_flow_edges', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('source_node', sa.String(length=256), nullable=False), + sa.Column('target_node', sa.String(length=256), nullable=False), + sa.Column('source_type', sa.String(length=32), nullable=False), + sa.Column('target_type', sa.String(length=32), nullable=False), + sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False), + sa.Column('tx_count', sa.Integer(), nullable=False), + sa.Column('earliest_time', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_fund_flow_edges_case_id'), 'fund_flow_edges', ['case_id'], unique=False) + op.create_table('transaction_clusters', + 
sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('primary_tx_id', sa.UUID(), nullable=True), + sa.Column('match_reason', sa.String(length=512), nullable=False), + sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_transaction_clusters_case_id'), 'transaction_clusters', ['case_id'], unique=False) + op.create_table('ocr_blocks', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('image_id', sa.UUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('bbox', postgresql.JSONB(astext_type=Text()), nullable=False), + sa.Column('seq_order', sa.Integer(), nullable=False), + sa.Column('confidence', sa.Float(), nullable=False), + sa.ForeignKeyConstraint(['image_id'], ['evidence_images.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_ocr_blocks_image_id'), 'ocr_blocks', ['image_id'], unique=False) + op.create_table('transaction_records', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('evidence_image_id', sa.UUID(), nullable=True), + sa.Column('cluster_id', sa.UUID(), nullable=True), + sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False), + sa.Column('trade_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False), + sa.Column('direction', sa.Enum('in_', 'out', name='direction'), nullable=False), + sa.Column('counterparty_name', sa.String(length=256), nullable=False), + sa.Column('counterparty_account', sa.String(length=256), nullable=False), + sa.Column('self_account_tail_no', sa.String(length=32), nullable=False), + sa.Column('order_no', sa.String(length=128), nullable=False), + sa.Column('remark', sa.Text(), nullable=False), + sa.Column('confidence', sa.Float(), nullable=False), + 
sa.Column('is_duplicate', sa.Boolean(), nullable=False), + sa.Column('is_transit', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.ForeignKeyConstraint(['cluster_id'], ['transaction_clusters.id'], ), + sa.ForeignKeyConstraint(['evidence_image_id'], ['evidence_images.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_transaction_records_case_id'), 'transaction_records', ['case_id'], unique=False) + op.create_index(op.f('ix_transaction_records_order_no'), 'transaction_records', ['order_no'], unique=False) + op.create_index(op.f('ix_transaction_records_trade_time'), 'transaction_records', ['trade_time'], unique=False) + op.create_table('fraud_assessments', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('case_id', sa.UUID(), nullable=False), + sa.Column('transaction_id', sa.UUID(), nullable=False), + sa.Column('confidence_level', sa.Enum('high', 'medium', 'low', name='confidencelevel'), nullable=False), + sa.Column('assessed_amount', sa.Numeric(precision=14, scale=2), nullable=False), + sa.Column('reason', sa.Text(), nullable=False), + sa.Column('exclude_reason', sa.Text(), nullable=False), + sa.Column('review_status', sa.Enum('pending', 'confirmed', 'rejected', 'needs_info', name='reviewstatus'), nullable=False), + sa.Column('review_note', sa.Text(), nullable=False), + sa.Column('reviewed_by', sa.String(length=128), nullable=False), + sa.Column('reviewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ), + sa.ForeignKeyConstraint(['transaction_id'], ['transaction_records.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_fraud_assessments_case_id'), 'fraud_assessments', ['case_id'], 
unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_fraud_assessments_case_id'), table_name='fraud_assessments') + op.drop_table('fraud_assessments') + op.drop_index(op.f('ix_transaction_records_trade_time'), table_name='transaction_records') + op.drop_index(op.f('ix_transaction_records_order_no'), table_name='transaction_records') + op.drop_index(op.f('ix_transaction_records_case_id'), table_name='transaction_records') + op.drop_table('transaction_records') + op.drop_index(op.f('ix_ocr_blocks_image_id'), table_name='ocr_blocks') + op.drop_table('ocr_blocks') + op.drop_index(op.f('ix_transaction_clusters_case_id'), table_name='transaction_clusters') + op.drop_table('transaction_clusters') + op.drop_index(op.f('ix_fund_flow_edges_case_id'), table_name='fund_flow_edges') + op.drop_table('fund_flow_edges') + op.drop_index(op.f('ix_export_reports_case_id'), table_name='export_reports') + op.drop_table('export_reports') + op.drop_index(op.f('ix_evidence_images_file_hash'), table_name='evidence_images') + op.drop_index(op.f('ix_evidence_images_case_id'), table_name='evidence_images') + op.drop_table('evidence_images') + op.drop_index(op.f('ix_cases_case_no'), table_name='cases') + op.drop_table('cases') + # ### end Alembic commands ### diff --git a/backend/app/api/v1/assessments.py b/backend/app/api/v1/assessments.py index 7f1fa2f..7139e01 100644 --- a/backend/app/api/v1/assessments.py +++ b/backend/app/api/v1/assessments.py @@ -2,10 +2,12 @@ from uuid import UUID from datetime import datetime, timezone from fastapi import APIRouter, Depends, Query, HTTPException +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.core.database import get_db -from app.models.assessment import ConfidenceLevel +from app.models.assessment import FraudAssessment, ConfidenceLevel from 
app.repositories.assessment_repo import AssessmentRepository from app.schemas.assessment import ( AssessmentOut, @@ -46,6 +48,14 @@ async def review_assessment( "reviewed_by": body.reviewed_by, "reviewed_at": datetime.now(timezone.utc), }) + + # eager-load the transaction relationship to avoid lazy-load in async context + result = await db.execute( + select(FraudAssessment) + .options(selectinload(FraudAssessment.transaction)) + .where(FraudAssessment.id == assessment_id) + ) + assessment = result.scalar_one() return assessment diff --git a/backend/app/api/v1/images.py b/backend/app/api/v1/images.py index ca17f25..5490708 100644 --- a/backend/app/api/v1/images.py +++ b/backend/app/api/v1/images.py @@ -1,15 +1,16 @@ from uuid import UUID +import asyncio -from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query +from fastapi import APIRouter, Depends, UploadFile, File, HTTPException from fastapi.responses import FileResponse from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.database import get_db -from app.models.evidence_image import EvidenceImage, SourceApp, PageType +from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus from app.repositories.image_repo import ImageRepository from app.repositories.case_repo import CaseRepository -from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection +from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection, CaseOcrStartIn from app.utils.hash import sha256_file from app.utils.file_storage import save_upload @@ -32,9 +33,11 @@ async def upload_images( for f in files: data = await f.read() - file_hash = sha256_file(data) + raw_hash = sha256_file(data) + # Scope hash by case to avoid cross-case unique conflicts while still deduplicating inside one case. 
+ scoped_hash = f"{raw_hash}:{case_id}" - existing = await img_repo.find_by_hash(file_hash) + existing = await img_repo.find_by_hash_in_case(case_id, [raw_hash, scoped_hash]) if existing: results.append(existing) continue @@ -44,7 +47,7 @@ async def upload_images( case_id=case_id, file_path=file_path, thumb_path=thumb_path, - file_hash=file_hash, + file_hash=scoped_hash, file_size=len(data), ) image = await img_repo.create(image) @@ -53,14 +56,11 @@ async def upload_images( case.image_count = await img_repo.count_by_case(case_id) await db.flush() - # trigger OCR tasks (non-blocking) - from app.workers.ocr_tasks import process_image_ocr + # trigger OCR tasks in-process background (non-blocking for API response) + from app.workers.ocr_tasks import process_image_ocr_async for img in results: if img.ocr_status.value == "pending": - try: - process_image_ocr.delay(str(img.id)) - except Exception: - pass + asyncio.create_task(process_image_ocr_async(str(img.id))) return results @@ -73,7 +73,21 @@ async def list_images( db: AsyncSession = Depends(get_db), ): repo = ImageRepository(db) - return await repo.list_by_case(case_id, source_app=source_app, page_type=page_type) + images = await repo.list_by_case(case_id, source_app=source_app, page_type=page_type) + return [ + ImageOut( + id=img.id, + case_id=img.case_id, + url=f"/api/v1/images/{img.id}/file", + thumb_url=f"/api/v1/images/{img.id}/file", + source_app=img.source_app, + page_type=img.page_type, + ocr_status=img.ocr_status, + file_hash=img.file_hash, + uploaded_at=img.uploaded_at, + ) + for img in images + ] @router.get("/images/{image_id}", response_model=ImageDetailOut) @@ -128,3 +142,43 @@ async def get_image_file(image_id: UUID, db: AsyncSession = Depends(get_db)): if not full_path.exists(): raise HTTPException(404, "文件不存在") return FileResponse(full_path) + + +@router.post("/cases/{case_id}/ocr/start") +async def start_case_ocr( + case_id: UUID, + payload: CaseOcrStartIn | None = None, + db: AsyncSession = 
Depends(get_db), +): + case_repo = CaseRepository(db) + case = await case_repo.get(case_id) + if not case: + raise HTTPException(404, "案件不存在") + + repo = ImageRepository(db) + include_done = payload.include_done if payload else False + image_ids = payload.image_ids if payload else [] + if image_ids: + images = await repo.list_by_ids_in_case(case_id, image_ids) + # For explicit re-run, mark selected images as processing immediately + # so frontend can reflect state transition without full page refresh. + for img in images: + img.ocr_status = OcrStatus.processing + await db.flush() + await db.commit() + else: + images = await repo.list_for_ocr(case_id, include_done=include_done) + + from app.workers.ocr_tasks import process_image_ocr_async + + submitted = 0 + for img in images: + asyncio.create_task(process_image_ocr_async(str(img.id))) + submitted += 1 + + return { + "caseId": str(case_id), + "submitted": submitted, + "totalCandidates": len(images), + "message": f"已提交 {submitted} 张截图的 OCR 任务", + } diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 6093ce3..f5443da 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -17,6 +17,7 @@ class Settings(BaseSettings): OCR_API_KEY: str = "" OCR_API_URL: str = "" + OCR_MODEL: str = "" LLM_API_KEY: str = "" LLM_API_URL: str = "" LLM_MODEL: str = "" diff --git a/backend/app/repositories/image_repo.py b/backend/app/repositories/image_repo.py index b277bda..38df492 100644 --- a/backend/app/repositories/image_repo.py +++ b/backend/app/repositories/image_repo.py @@ -3,7 +3,7 @@ from uuid import UUID from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession -from app.models.evidence_image import EvidenceImage, SourceApp, PageType +from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus from app.repositories.base import BaseRepository @@ -17,6 +17,17 @@ class ImageRepository(BaseRepository[EvidenceImage]): ) return 
result.scalar_one_or_none() + async def find_by_hash_in_case(self, case_id: UUID, file_hashes: list[str]) -> EvidenceImage | None: + if not file_hashes: + return None + result = await self.session.execute( + select(EvidenceImage).where( + EvidenceImage.case_id == case_id, + EvidenceImage.file_hash.in_(file_hashes), + ) + ) + return result.scalar_one_or_none() + async def list_by_case( self, case_id: UUID, @@ -37,3 +48,20 @@ class ImageRepository(BaseRepository[EvidenceImage]): select(func.count()).select_from(EvidenceImage).where(EvidenceImage.case_id == case_id) ) return result.scalar() or 0 + + async def list_for_ocr(self, case_id: UUID, include_done: bool = False) -> list[EvidenceImage]: + query = select(EvidenceImage).where(EvidenceImage.case_id == case_id) + if not include_done: + query = query.where(EvidenceImage.ocr_status != OcrStatus.done) + result = await self.session.execute(query.order_by(EvidenceImage.uploaded_at.desc())) + return list(result.scalars().all()) + + async def list_by_ids_in_case(self, case_id: UUID, image_ids: list[UUID]) -> list[EvidenceImage]: + if not image_ids: + return [] + result = await self.session.execute( + select(EvidenceImage) + .where(EvidenceImage.case_id == case_id, EvidenceImage.id.in_(image_ids)) + .order_by(EvidenceImage.uploaded_at.desc()) + ) + return list(result.scalars().all()) diff --git a/backend/app/schemas/analysis.py b/backend/app/schemas/analysis.py index 4aacbe6..ab6e2ec 100644 --- a/backend/app/schemas/analysis.py +++ b/backend/app/schemas/analysis.py @@ -1,7 +1,7 @@ -from pydantic import BaseModel +from app.schemas.base import CamelModel -class AnalysisStatusOut(BaseModel): +class AnalysisStatusOut(CamelModel): case_id: str status: str progress: int = 0 @@ -9,6 +9,6 @@ class AnalysisStatusOut(BaseModel): message: str = "" -class AnalysisTriggerOut(BaseModel): +class AnalysisTriggerOut(CamelModel): task_id: str message: str diff --git a/backend/app/schemas/assessment.py b/backend/app/schemas/assessment.py 
index 3ac8dc7..0c14afd 100644 --- a/backend/app/schemas/assessment.py +++ b/backend/app/schemas/assessment.py @@ -1,13 +1,12 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel - from app.models.assessment import ConfidenceLevel, ReviewStatus +from app.schemas.base import CamelModel from app.schemas.transaction import TransactionOut -class AssessmentOut(BaseModel): +class AssessmentOut(CamelModel): id: UUID case_id: UUID transaction_id: UUID @@ -21,19 +20,17 @@ class AssessmentOut(BaseModel): reviewed_by: str reviewed_at: datetime | None = None - model_config = {"from_attributes": True} - -class AssessmentListOut(BaseModel): +class AssessmentListOut(CamelModel): items: list[AssessmentOut] total: int -class ReviewSubmit(BaseModel): +class ReviewSubmit(CamelModel): review_status: ReviewStatus review_note: str = "" reviewed_by: str = "demo_user" -class InquirySuggestionOut(BaseModel): +class InquirySuggestionOut(CamelModel): suggestions: list[str] diff --git a/backend/app/schemas/base.py b/backend/app/schemas/base.py new file mode 100644 index 0000000..7fea96e --- /dev/null +++ b/backend/app/schemas/base.py @@ -0,0 +1,15 @@ +"""Base schema with camelCase JSON serialization.""" +from pydantic import BaseModel, ConfigDict + + +def to_camel(s: str) -> str: + parts = s.split("_") + return parts[0] + "".join(p.capitalize() for p in parts[1:]) + + +class CamelModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + alias_generator=to_camel, + populate_by_name=True, + ) diff --git a/backend/app/schemas/case.py b/backend/app/schemas/case.py index 50d4584..506668d 100644 --- a/backend/app/schemas/case.py +++ b/backend/app/schemas/case.py @@ -1,26 +1,25 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel - from app.models.case import CaseStatus +from app.schemas.base import CamelModel -class CaseCreate(BaseModel): +class CaseCreate(CamelModel): case_no: str title: str victim_name: str handler: 
str = "" -class CaseUpdate(BaseModel): +class CaseUpdate(CamelModel): title: str | None = None victim_name: str | None = None handler: str | None = None status: CaseStatus | None = None -class CaseOut(BaseModel): +class CaseOut(CamelModel): id: UUID case_no: str title: str @@ -32,9 +31,7 @@ class CaseOut(BaseModel): created_at: datetime updated_at: datetime - model_config = {"from_attributes": True} - -class CaseListOut(BaseModel): +class CaseListOut(CamelModel): items: list[CaseOut] total: int diff --git a/backend/app/schemas/image.py b/backend/app/schemas/image.py index c3d80c7..069069b 100644 --- a/backend/app/schemas/image.py +++ b/backend/app/schemas/image.py @@ -1,12 +1,11 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel - from app.models.evidence_image import SourceApp, PageType, OcrStatus +from app.schemas.base import CamelModel -class ImageOut(BaseModel): +class ImageOut(CamelModel): id: UUID case_id: UUID url: str = "" @@ -17,24 +16,25 @@ class ImageOut(BaseModel): file_hash: str uploaded_at: datetime - model_config = {"from_attributes": True} - -class OcrBlockOut(BaseModel): +class OcrBlockOut(CamelModel): id: UUID content: str bbox: dict seq_order: int confidence: float - model_config = {"from_attributes": True} - class ImageDetailOut(ImageOut): ocr_blocks: list[OcrBlockOut] = [] -class OcrFieldCorrection(BaseModel): +class OcrFieldCorrection(CamelModel): field_name: str old_value: str new_value: str + + +class CaseOcrStartIn(CamelModel): + include_done: bool = False + image_ids: list[UUID] = [] diff --git a/backend/app/schemas/report.py b/backend/app/schemas/report.py index 1fc2f6e..4143827 100644 --- a/backend/app/schemas/report.py +++ b/backend/app/schemas/report.py @@ -1,12 +1,11 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel - from app.models.report import ReportType +from app.schemas.base import CamelModel -class ReportCreate(BaseModel): +class ReportCreate(CamelModel): 
report_type: ReportType include_summary: bool = True include_transactions: bool = True @@ -17,7 +16,7 @@ class ReportCreate(BaseModel): include_screenshots: bool = False -class ReportOut(BaseModel): +class ReportOut(CamelModel): id: UUID case_id: UUID report_type: ReportType @@ -25,9 +24,7 @@ class ReportOut(BaseModel): version: int created_at: datetime - model_config = {"from_attributes": True} - -class ReportListOut(BaseModel): +class ReportListOut(CamelModel): items: list[ReportOut] total: int diff --git a/backend/app/schemas/transaction.py b/backend/app/schemas/transaction.py index 070f224..75ed1f5 100644 --- a/backend/app/schemas/transaction.py +++ b/backend/app/schemas/transaction.py @@ -1,13 +1,12 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel - from app.models.evidence_image import SourceApp from app.models.transaction import Direction +from app.schemas.base import CamelModel -class TransactionOut(BaseModel): +class TransactionOut(CamelModel): id: UUID case_id: UUID source_app: SourceApp @@ -25,21 +24,19 @@ class TransactionOut(BaseModel): is_duplicate: bool is_transit: bool - model_config = {"from_attributes": True} - -class TransactionListOut(BaseModel): +class TransactionListOut(CamelModel): items: list[TransactionOut] total: int -class FlowNodeOut(BaseModel): +class FlowNodeOut(CamelModel): id: str label: str type: str -class FlowEdgeOut(BaseModel): +class FlowEdgeOut(CamelModel): source: str target: str amount: float @@ -47,6 +44,6 @@ class FlowEdgeOut(BaseModel): trade_time: str -class FlowGraphOut(BaseModel): +class FlowGraphOut(CamelModel): nodes: list[FlowNodeOut] edges: list[FlowEdgeOut] diff --git a/backend/app/services/ocr_service.py b/backend/app/services/ocr_service.py index bba83cb..ef6fa4e 100644 --- a/backend/app/services/ocr_service.py +++ b/backend/app/services/ocr_service.py @@ -1,12 +1,20 @@ """OCR and multimodal extraction service. 
-Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface. -When API keys are not configured, falls back to a mock implementation that -returns placeholder data (sufficient for demo / competition). +Both classify_page and extract_transaction_fields use OpenAI-compatible +multimodal chat completion APIs. OCR and LLM can point to different +providers / models via separate env vars: + + OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction + LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions) + +When OCR keys are not set, falls back to LLM keys. +When neither is set, returns mock data (sufficient for demo). """ +import base64 import json import logging -from pathlib import Path +import ast +import re import httpx @@ -14,28 +22,163 @@ from app.core.config import settings from app.models.evidence_image import SourceApp, PageType logger = logging.getLogger(__name__) +ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request + + +def _ocr_config() -> tuple[str, str, str]: + """Return (api_url, api_key, model) for OCR, falling back to LLM config.""" + url = settings.OCR_API_URL or settings.LLM_API_URL + key = settings.OCR_API_KEY or settings.LLM_API_KEY + model = settings.OCR_MODEL or settings.LLM_MODEL + return url, key, model + + +def _llm_config() -> tuple[str, str, str]: + """Return (api_url, api_key, model) for text-only LLM repair.""" + return settings.LLM_API_URL, settings.LLM_API_KEY, settings.LLM_MODEL + + +def _ocr_available() -> bool: + url, key, model = _ocr_config() + return bool(url and key and model) + + +def _llm_available() -> bool: + url, key, model = _llm_config() + return bool(url and key and model) # ── provider-agnostic interface ────────────────────────────────────────── async def classify_page(image_path: str) -> tuple[SourceApp, PageType]: """Identify the source app and page type of a screenshot.""" - if settings.LLM_API_KEY and settings.LLM_API_URL: + 
if _ocr_available(): return await _classify_via_api(image_path) return _classify_mock(image_path) -async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict: +async def extract_transaction_fields( + image_path: str, source_app: SourceApp, page_type: PageType +) -> tuple[dict | list, str]: """Extract structured transaction fields from a screenshot.""" - if settings.LLM_API_KEY and settings.LLM_API_URL: + if _ocr_available(): return await _extract_via_api(image_path, source_app, page_type) - return _extract_mock(image_path, source_app, page_type) + mock_data = _extract_mock(image_path, source_app, page_type) + return mock_data, json.dumps(mock_data, ensure_ascii=False) -# ── real API implementation ────────────────────────────────────────────── +# ── OpenAI-compatible API implementation ───────────────────────────────── + +async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str: + """Send a vision request to the OCR model endpoint.""" + url, key, model = _ocr_config() + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post( + url, + headers={"Authorization": f"Bearer {key}"}, + json={ + "model": model, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}, + ], + } + ], + "max_tokens": max_tokens, + "temperature": 0, + }, + ) + resp.raise_for_status() + body = resp.json() + choice0 = (body.get("choices") or [{}])[0] + msg = choice0.get("message") or {} + content = msg.get("content", "") + finish_reason = choice0.get("finish_reason") + usage = body.get("usage") or {} + return content if isinstance(content, str) else str(content) + + +async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str: + """Send a text-only request to the reasoning LLM endpoint.""" + url, key, model = _llm_config() + async with httpx.AsyncClient(timeout=60) as client: + 
resp = await client.post( + url, + headers={"Authorization": f"Bearer {key}"}, + json={ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + }, + ) + resp.raise_for_status() + return resp.json()["choices"][0]["message"]["content"] + + +def _parse_json_response(text: str): + """Strip markdown fences and parse JSON from LLM output.""" + cleaned = text.strip() + if cleaned.startswith("```"): + cleaned = cleaned.split("\n", 1)[-1] + if cleaned.endswith("```"): + cleaned = cleaned.rsplit("```", 1)[0] + cleaned = cleaned.strip().removeprefix("json").strip() + # Try parsing the full body first. + try: + return json.loads(cleaned) + except Exception: + pass + + # Extract likely JSON block when model wraps with extra text. + match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned) + candidate = match.group(1) if match else cleaned + + # Normalize common model output issues: + # - smart quotes + # - trailing commas + normalized = ( + candidate + .replace("“", "\"") + .replace("”", "\"") + .replace("‘", "'") + .replace("’", "'") + ) + normalized = re.sub(r",\s*([}\]])", r"\1", normalized) + try: + return json.loads(normalized) + except Exception: + pass + + # Last fallback: Python-like literal payloads with single quotes / True/False/None. 
+ py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE) + py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE) + py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE) + parsed = ast.literal_eval(py_like) + if isinstance(parsed, (dict, list)): + return parsed + raise ValueError("Parsed payload is neither dict nor list") + + +async def _repair_broken_json_with_llm(broken_text: str) -> str: + """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics.""" + prompt = ( + "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n" + "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n" + "硬性要求:\n" + "1) 只能输出JSON,不要Markdown,不要解释。\n" + "2) 保留原有字段和语义,不新增未出现的信息。\n" + "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n" + "4) 结果必须是对象或数组。\n\n" + "损坏文本如下:\n" + f"{broken_text}" + ) + return await _call_text_llm(prompt, max_tokens=1200) + async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]: - import base64 full_path = settings.upload_path / image_path if not full_path.exists(): return SourceApp.other, PageType.unknown @@ -43,79 +186,64 @@ async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]: image_b64 = base64.b64encode(full_path.read_bytes()).decode() prompt = ( "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)" - "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。" - "只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}" + "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n" + '只返回JSON,格式: {"source_app": "...", "page_type": "..."}' ) try: - async with httpx.AsyncClient(timeout=30) as client: - resp = await client.post( - settings.LLM_API_URL, - headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, - json={ - "model": settings.LLM_MODEL, - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}, - ], - } - ], - "max_tokens": 200, - 
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
    """Classify a screenshot's source app and page type via the vision LLM.

    Args:
        image_path: Path relative to ``settings.upload_path``.

    Returns:
        ``(SourceApp, PageType)``; falls back to ``(other, unknown)`` when the
        file is missing or any step (API call, JSON parsing, enum coercion)
        fails — invalid enum values raise ``ValueError`` into the broad
        ``except`` below, which is the intended degradation path.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return SourceApp.other, PageType.unknown
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
        "以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n"
        '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
    )
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=600)
        data = _parse_json_response(text)
        # _parse_json_response may legitimately return a list, and models
        # sometimes wrap the object in a one-element array.  Unwrap instead of
        # letting `.get` raise AttributeError into the generic except (which
        # would discard an otherwise-usable classification).
        if isinstance(data, list):
            data = data[0] if data and isinstance(data[0], dict) else {}
        return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
    except Exception as e:
        logger.warning("classify_page API failed: %s", e)
        return SourceApp.other, PageType.unknown
async def _extract_via_api(
    image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
    """Extract structured transaction fields from a screenshot via the vision LLM.

    Args:
        image_path: Path relative to ``settings.upload_path``.
        source_app / page_type: Prior classification, embedded in the prompt.

    Returns:
        ``(parsed, raw_text)`` where ``parsed`` is the dict (single transaction)
        or list (multiple transactions) decoded from the model output, and
        ``raw_text`` is the model's raw reply for debugging/audit.  On any
        failure returns ``({}, raw_text_so_far)`` — never raises.
    """
    full_path = settings.upload_path / image_path
    if not full_path.exists():
        return {}, ""
    image_b64 = base64.b64encode(full_path.read_bytes()).decode()
    prompt = (
        "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析,任何非JSON字符都会导致任务失败。\n"
        f"输入图片来自 {source_app.value} / {page_type.value}。\n\n"
        "【硬性输出规则】\n"
        "1) 只能输出一个合法JSON值:对象或数组;禁止Markdown代码块、禁止解释文字、禁止注释。\n"
        "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
        "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
        "4) 若只有1笔交易,输出对象;若>=2笔交易,输出数组。\n"
        "5) 每笔交易字段固定为:\n"
        ' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
        '"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
        "6) 字段约束:\n"
        "- trade_time: 必须为 YYYY-MM-DD HH:MM:SS;无法识别时填空字符串\"\"。\n"
        "- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
        "- direction: 仅允许 in 或 out;无法判断默认 out。\n"
        "- confidence: 0~1 的数字;无法判断默认 0.5。\n"
        "- 其他文本字段无法识别时填空字符串\"\"。\n\n"
        "【输出前自检】\n"
        "- 检查是否是严格JSON(可被 json.loads 直接解析)。\n"
        "- 检查每条记录都包含全部9个字段。\n"
        "- 检查 amount/confidence 为数字类型,direction 仅为 in/out。\n"
        "现在只输出最终JSON,不要输出任何额外文本。"
    )
    # Pre-initialize so the failure path below never needs to introspect
    # locals() to learn whether _call_vision got as far as producing a reply.
    text = ""
    try:
        text = await _call_vision(prompt, image_b64, max_tokens=6000)
        try:
            parsed = _parse_json_response(text)
        except Exception:
            # Optional second chance: ask a text LLM to repair the JSON syntax
            # (no new semantics), then parse again.  If repair is disabled or
            # unavailable, surface the original parse failure.
            if not ENABLE_LLM_REPAIR or not _llm_available():
                raise
            repaired_text = await _repair_broken_json_with_llm(text)
            parsed = _parse_json_response(repaired_text)
        return parsed, text
    except Exception as e:
        logger.warning("extract_transaction_fields API failed: %s", e)
        return {}, text
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
def process_image_ocr(self, image_id: str):
    """Celery entry point: run the full OCR pipeline for one uploaded image.

    Thin synchronous shim — the actual work (page classification, field
    extraction, persistence) lives in ``process_image_ocr_async``; this task
    only drives that coroutine to completion via ``_run_async``.

    ``bind=True`` and ``max_retries=3`` make ``self.retry`` available.
    NOTE(review): whether any failure path actually invokes retry is not
    visible from this block — confirm in the async body.
    """
    _run_async(process_image_ocr_async(image_id))
app.models.evidence_image import EvidenceImage, OcrStatus from app.models.ocr_block import OcrBlock + from app.models.transaction import TransactionRecord from app.services.ocr_service import classify_page, extract_transaction_fields from app.services.parser_service import parse_extracted_fields image_id = UUID(image_id_str) - async with async_session_factory() as db: image = await db.get(EvidenceImage, image_id) if not image: @@ -40,23 +42,39 @@ async def _process(image_id_str: str): image.ocr_status = OcrStatus.processing await db.flush() + # Re-run OCR for this image should replace old OCR blocks/records. + await db.execute(delete(OcrBlock).where(OcrBlock.image_id == image.id)) + await db.execute(delete(TransactionRecord).where(TransactionRecord.evidence_image_id == image.id)) try: source_app, page_type = await classify_page(image.file_path) image.source_app = source_app image.page_type = page_type - raw_fields = await extract_transaction_fields(image.file_path, source_app, page_type) + raw_fields, raw_ocr_text = await extract_transaction_fields(image.file_path, source_app, page_type) - # save raw OCR block - block = OcrBlock( - image_id=image.id, - content=str(raw_fields), - bbox={}, - seq_order=0, - confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, - ) - db.add(block) + is_empty_extract = raw_fields is None or raw_fields == {} or raw_fields == [] + + # save raw OCR block (direct model output for debugging) + raw_block_content = (raw_ocr_text or "").strip() + if raw_block_content: + block = OcrBlock( + image_id=image.id, + content=raw_block_content, + bbox={}, + seq_order=0, + confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, + ) + db.add(block) + elif not is_empty_extract: + block = OcrBlock( + image_id=image.id, + content=json.dumps(raw_fields, ensure_ascii=False), + bbox={}, + seq_order=0, + confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, + ) + 
db.add(block) # parse into transaction records records = parse_extracted_fields(raw_fields, image.case_id, image.id, source_app) diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..65265bc --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,18 @@ +fastapi==0.135.1 +uvicorn[standard]==0.41.0 +sqlalchemy[asyncio]==2.0.48 +asyncpg==0.31.0 +alembic==1.18.4 +celery[redis]==5.6.2 +redis==6.4.0 +pydantic-settings>=2.0.0 +python-multipart==0.0.22 +Pillow==12.1.0 +httpx==0.28.1 +openpyxl==3.1.5 +python-docx==1.2.0 +psycopg2-binary==2.9.11 + +# dev dependencies +pytest==9.0.2 +pytest-asyncio>=0.24.0