This commit is contained in:
2026-03-12 12:32:29 +08:00
parent c0f9ddabbf
commit 470446fa6f
18 changed files with 591 additions and 142 deletions

View File

@@ -1,8 +1,16 @@
from logging.config import fileConfig from logging.config import fileConfig
import sys
from pathlib import Path
from sqlalchemy import engine_from_config, pool from sqlalchemy import engine_from_config, pool
from alembic import context from alembic import context
# Ensure `backend/` is on sys.path so `import app...` works
# no matter where `alembic` is executed from.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
from app.core.config import settings from app.core.config import settings
from app.core.database import Base from app.core.database import Base
import app.models # noqa: F401 ensure all models are imported import app.models # noqa: F401 ensure all models are imported

View File

@@ -8,6 +8,7 @@ from typing import Sequence, Union
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import Text
${imports if imports else ""} ${imports if imports else ""}
revision: str = ${repr(up_revision)} revision: str = ${repr(up_revision)}

View File

@@ -0,0 +1,169 @@
"""init_sqlite_test
Revision ID: be562b8079e3
Revises:
Create Date: 2026-03-11 17:33:30.695730
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
revision: str = 'be562b8079e3'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the initial schema for the SQLite test database.

    Tables (in dependency order): cases -> evidence_images, export_reports,
    fund_flow_edges, transaction_clusters -> ocr_blocks, transaction_records
    -> fraud_assessments.  Note: JSONB columns use the postgresql dialect
    type; on SQLite these degrade to a generic JSON-ish storage —
    NOTE(review): confirm this is acceptable for the test environment.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Root entity: one row per investigation case.
    op.create_table('cases',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_no', sa.String(length=64), nullable=False),
    sa.Column('title', sa.String(length=256), nullable=False),
    sa.Column('victim_name', sa.String(length=128), nullable=False),
    sa.Column('handler', sa.String(length=128), nullable=False),
    sa.Column('status', sa.Enum('pending', 'uploading', 'analyzing', 'reviewing', 'completed', name='casestatus'), nullable=False),
    sa.Column('image_count', sa.Integer(), nullable=False),
    sa.Column('total_amount', sa.Numeric(precision=14, scale=2), nullable=False),
    sa.Column('created_by', sa.String(length=128), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    # Case numbers are business keys, hence the unique index.
    op.create_index(op.f('ix_cases_case_no'), 'cases', ['case_no'], unique=True)
    # Uploaded screenshots; file_hash is unique for de-duplication.
    op.create_table('evidence_images',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('file_path', sa.String(length=512), nullable=False),
    sa.Column('thumb_path', sa.String(length=512), nullable=False),
    sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False),
    sa.Column('page_type', sa.Enum('bill_list', 'bill_detail', 'transfer_receipt', 'sms_notice', 'balance', 'unknown', name='pagetype'), nullable=False),
    sa.Column('ocr_status', sa.Enum('pending', 'processing', 'done', 'failed', name='ocrstatus'), nullable=False),
    sa.Column('file_hash', sa.String(length=128), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_evidence_images_case_id'), 'evidence_images', ['case_id'], unique=False)
    op.create_index(op.f('ix_evidence_images_file_hash'), 'evidence_images', ['file_hash'], unique=True)
    # Generated report artifacts with a JSON snapshot of their content.
    op.create_table('export_reports',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('report_type', sa.Enum('pdf', 'excel', 'word', name='reporttype'), nullable=False),
    sa.Column('file_path', sa.String(length=512), nullable=False),
    sa.Column('version', sa.Integer(), nullable=False),
    sa.Column('content_snapshot', postgresql.JSONB(astext_type=Text()), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_export_reports_case_id'), 'export_reports', ['case_id'], unique=False)
    # Aggregated money-flow graph edges between named accounts.
    op.create_table('fund_flow_edges',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('source_node', sa.String(length=256), nullable=False),
    sa.Column('target_node', sa.String(length=256), nullable=False),
    sa.Column('source_type', sa.String(length=32), nullable=False),
    sa.Column('target_type', sa.String(length=32), nullable=False),
    sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False),
    sa.Column('tx_count', sa.Integer(), nullable=False),
    sa.Column('earliest_time', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_fund_flow_edges_case_id'), 'fund_flow_edges', ['case_id'], unique=False)
    # Groups of duplicate transactions; primary_tx_id is intentionally not a
    # foreign key here to avoid a circular reference with transaction_records.
    op.create_table('transaction_clusters',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('primary_tx_id', sa.UUID(), nullable=True),
    sa.Column('match_reason', sa.String(length=512), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_transaction_clusters_case_id'), 'transaction_clusters', ['case_id'], unique=False)
    # Raw OCR text blocks per image, ordered by seq_order.
    op.create_table('ocr_blocks',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('image_id', sa.UUID(), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('bbox', postgresql.JSONB(astext_type=Text()), nullable=False),
    sa.Column('seq_order', sa.Integer(), nullable=False),
    sa.Column('confidence', sa.Float(), nullable=False),
    sa.ForeignKeyConstraint(['image_id'], ['evidence_images.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_ocr_blocks_image_id'), 'ocr_blocks', ['image_id'], unique=False)
    # Structured transactions extracted from screenshots.
    op.create_table('transaction_records',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('evidence_image_id', sa.UUID(), nullable=True),
    sa.Column('cluster_id', sa.UUID(), nullable=True),
    sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False),
    sa.Column('trade_time', sa.DateTime(timezone=True), nullable=False),
    sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False),
    sa.Column('direction', sa.Enum('in_', 'out', name='direction'), nullable=False),
    sa.Column('counterparty_name', sa.String(length=256), nullable=False),
    sa.Column('counterparty_account', sa.String(length=256), nullable=False),
    sa.Column('self_account_tail_no', sa.String(length=32), nullable=False),
    sa.Column('order_no', sa.String(length=128), nullable=False),
    sa.Column('remark', sa.Text(), nullable=False),
    sa.Column('confidence', sa.Float(), nullable=False),
    sa.Column('is_duplicate', sa.Boolean(), nullable=False),
    sa.Column('is_transit', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.ForeignKeyConstraint(['cluster_id'], ['transaction_clusters.id'], ),
    sa.ForeignKeyConstraint(['evidence_image_id'], ['evidence_images.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_transaction_records_case_id'), 'transaction_records', ['case_id'], unique=False)
    op.create_index(op.f('ix_transaction_records_order_no'), 'transaction_records', ['order_no'], unique=False)
    op.create_index(op.f('ix_transaction_records_trade_time'), 'transaction_records', ['trade_time'], unique=False)
    # Per-transaction fraud assessment plus human review state.
    op.create_table('fraud_assessments',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('case_id', sa.UUID(), nullable=False),
    sa.Column('transaction_id', sa.UUID(), nullable=False),
    sa.Column('confidence_level', sa.Enum('high', 'medium', 'low', name='confidencelevel'), nullable=False),
    sa.Column('assessed_amount', sa.Numeric(precision=14, scale=2), nullable=False),
    sa.Column('reason', sa.Text(), nullable=False),
    sa.Column('exclude_reason', sa.Text(), nullable=False),
    sa.Column('review_status', sa.Enum('pending', 'confirmed', 'rejected', 'needs_info', name='reviewstatus'), nullable=False),
    sa.Column('review_note', sa.Text(), nullable=False),
    sa.Column('reviewed_by', sa.String(length=128), nullable=False),
    sa.Column('reviewed_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
    sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
    sa.ForeignKeyConstraint(['transaction_id'], ['transaction_records.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_fraud_assessments_case_id'), 'fraud_assessments', ['case_id'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop every table created by upgrade(), children before parents."""
    # ### commands auto generated by Alembic - please adjust! ###
    # (table, [indexes dropped before the table]) in reverse dependency order.
    teardown_plan = [
        ("fraud_assessments", ["ix_fraud_assessments_case_id"]),
        ("transaction_records", [
            "ix_transaction_records_trade_time",
            "ix_transaction_records_order_no",
            "ix_transaction_records_case_id",
        ]),
        ("ocr_blocks", ["ix_ocr_blocks_image_id"]),
        ("transaction_clusters", ["ix_transaction_clusters_case_id"]),
        ("fund_flow_edges", ["ix_fund_flow_edges_case_id"]),
        ("export_reports", ["ix_export_reports_case_id"]),
        ("evidence_images", [
            "ix_evidence_images_file_hash",
            "ix_evidence_images_case_id",
        ]),
        ("cases", ["ix_cases_case_no"]),
    ]
    for table_name, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_table(table_name)
    # ### end Alembic commands ###

View File

@@ -2,10 +2,12 @@ from uuid import UUID
from datetime import datetime, timezone from datetime import datetime, timezone
from fastapi import APIRouter, Depends, Query, HTTPException from fastapi import APIRouter, Depends, Query, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.database import get_db from app.core.database import get_db
from app.models.assessment import ConfidenceLevel from app.models.assessment import FraudAssessment, ConfidenceLevel
from app.repositories.assessment_repo import AssessmentRepository from app.repositories.assessment_repo import AssessmentRepository
from app.schemas.assessment import ( from app.schemas.assessment import (
AssessmentOut, AssessmentOut,
@@ -46,6 +48,14 @@ async def review_assessment(
"reviewed_by": body.reviewed_by, "reviewed_by": body.reviewed_by,
"reviewed_at": datetime.now(timezone.utc), "reviewed_at": datetime.now(timezone.utc),
}) })
# eager-load the transaction relationship to avoid lazy-load in async context
result = await db.execute(
select(FraudAssessment)
.options(selectinload(FraudAssessment.transaction))
.where(FraudAssessment.id == assessment_id)
)
assessment = result.scalar_one()
return assessment return assessment

View File

@@ -1,15 +1,16 @@
from uuid import UUID from uuid import UUID
import asyncio
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.models.evidence_image import EvidenceImage, SourceApp, PageType from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
from app.repositories.image_repo import ImageRepository from app.repositories.image_repo import ImageRepository
from app.repositories.case_repo import CaseRepository from app.repositories.case_repo import CaseRepository
from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection, CaseOcrStartIn
from app.utils.hash import sha256_file from app.utils.hash import sha256_file
from app.utils.file_storage import save_upload from app.utils.file_storage import save_upload
@@ -32,9 +33,11 @@ async def upload_images(
for f in files: for f in files:
data = await f.read() data = await f.read()
file_hash = sha256_file(data) raw_hash = sha256_file(data)
# Scope hash by case to avoid cross-case unique conflicts while still deduplicating inside one case.
scoped_hash = f"{raw_hash}:{case_id}"
existing = await img_repo.find_by_hash(file_hash) existing = await img_repo.find_by_hash_in_case(case_id, [raw_hash, scoped_hash])
if existing: if existing:
results.append(existing) results.append(existing)
continue continue
@@ -44,7 +47,7 @@ async def upload_images(
case_id=case_id, case_id=case_id,
file_path=file_path, file_path=file_path,
thumb_path=thumb_path, thumb_path=thumb_path,
file_hash=file_hash, file_hash=scoped_hash,
file_size=len(data), file_size=len(data),
) )
image = await img_repo.create(image) image = await img_repo.create(image)
@@ -53,14 +56,11 @@ async def upload_images(
case.image_count = await img_repo.count_by_case(case_id) case.image_count = await img_repo.count_by_case(case_id)
await db.flush() await db.flush()
# trigger OCR tasks (non-blocking) # trigger OCR tasks in-process background (non-blocking for API response)
from app.workers.ocr_tasks import process_image_ocr from app.workers.ocr_tasks import process_image_ocr_async
for img in results: for img in results:
if img.ocr_status.value == "pending": if img.ocr_status.value == "pending":
try: asyncio.create_task(process_image_ocr_async(str(img.id)))
process_image_ocr.delay(str(img.id))
except Exception:
pass
return results return results
@@ -73,7 +73,21 @@ async def list_images(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
repo = ImageRepository(db) repo = ImageRepository(db)
return await repo.list_by_case(case_id, source_app=source_app, page_type=page_type) images = await repo.list_by_case(case_id, source_app=source_app, page_type=page_type)
return [
ImageOut(
id=img.id,
case_id=img.case_id,
url=f"/api/v1/images/{img.id}/file",
thumb_url=f"/api/v1/images/{img.id}/file",
source_app=img.source_app,
page_type=img.page_type,
ocr_status=img.ocr_status,
file_hash=img.file_hash,
uploaded_at=img.uploaded_at,
)
for img in images
]
@router.get("/images/{image_id}", response_model=ImageDetailOut) @router.get("/images/{image_id}", response_model=ImageDetailOut)
@@ -128,3 +142,43 @@ async def get_image_file(image_id: UUID, db: AsyncSession = Depends(get_db)):
if not full_path.exists(): if not full_path.exists():
raise HTTPException(404, "文件不存在") raise HTTPException(404, "文件不存在")
return FileResponse(full_path) return FileResponse(full_path)
@router.post("/cases/{case_id}/ocr/start")
async def start_case_ocr(
case_id: UUID,
payload: CaseOcrStartIn | None = None,
db: AsyncSession = Depends(get_db),
):
case_repo = CaseRepository(db)
case = await case_repo.get(case_id)
if not case:
raise HTTPException(404, "案件不存在")
repo = ImageRepository(db)
include_done = payload.include_done if payload else False
image_ids = payload.image_ids if payload else []
if image_ids:
images = await repo.list_by_ids_in_case(case_id, image_ids)
# For explicit re-run, mark selected images as processing immediately
# so frontend can reflect state transition without full page refresh.
for img in images:
img.ocr_status = OcrStatus.processing
await db.flush()
await db.commit()
else:
images = await repo.list_for_ocr(case_id, include_done=include_done)
from app.workers.ocr_tasks import process_image_ocr_async
submitted = 0
for img in images:
asyncio.create_task(process_image_ocr_async(str(img.id)))
submitted += 1
return {
"caseId": str(case_id),
"submitted": submitted,
"totalCandidates": len(images),
"message": f"已提交 {submitted} 张截图的 OCR 任务",
}

View File

@@ -17,6 +17,7 @@ class Settings(BaseSettings):
OCR_API_KEY: str = "" OCR_API_KEY: str = ""
OCR_API_URL: str = "" OCR_API_URL: str = ""
OCR_MODEL: str = ""
LLM_API_KEY: str = "" LLM_API_KEY: str = ""
LLM_API_URL: str = "" LLM_API_URL: str = ""
LLM_MODEL: str = "" LLM_MODEL: str = ""

View File

@@ -3,7 +3,7 @@ from uuid import UUID
from sqlalchemy import select, func from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.models.evidence_image import EvidenceImage, SourceApp, PageType from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
from app.repositories.base import BaseRepository from app.repositories.base import BaseRepository
@@ -17,6 +17,17 @@ class ImageRepository(BaseRepository[EvidenceImage]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
    async def find_by_hash_in_case(self, case_id: UUID, file_hashes: list[str]) -> EvidenceImage | None:
        """Return an image in *case_id* whose file_hash matches any candidate.

        Accepting a list of hashes lets callers probe several encodings of
        the same file (e.g. raw vs. case-scoped hash) in a single query.
        Returns None for an empty candidate list or when nothing matches.
        """
        # NOTE(review): scalar_one_or_none raises MultipleResultsFound when
        # more than one row matches — confirm hashes are unique per case.
        if not file_hashes:
            return None
        result = await self.session.execute(
            select(EvidenceImage).where(
                EvidenceImage.case_id == case_id,
                EvidenceImage.file_hash.in_(file_hashes),
            )
        )
        return result.scalar_one_or_none()
async def list_by_case( async def list_by_case(
self, self,
case_id: UUID, case_id: UUID,
@@ -37,3 +48,20 @@ class ImageRepository(BaseRepository[EvidenceImage]):
select(func.count()).select_from(EvidenceImage).where(EvidenceImage.case_id == case_id) select(func.count()).select_from(EvidenceImage).where(EvidenceImage.case_id == case_id)
) )
return result.scalar() or 0 return result.scalar() or 0
    async def list_for_ocr(self, case_id: UUID, include_done: bool = False) -> list[EvidenceImage]:
        """List a case's images that are OCR candidates, newest upload first.

        By default images whose OCR already completed are skipped; pass
        include_done=True to include them (e.g. for a full re-run).
        """
        query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
        if not include_done:
            query = query.where(EvidenceImage.ocr_status != OcrStatus.done)
        result = await self.session.execute(query.order_by(EvidenceImage.uploaded_at.desc()))
        return list(result.scalars().all())
    async def list_by_ids_in_case(self, case_id: UUID, image_ids: list[UUID]) -> list[EvidenceImage]:
        """Fetch the given images, restricted to *case_id*, newest first.

        Filtering by case as well as id prevents callers from touching
        images that belong to another case.  Unknown ids are silently
        omitted; an empty id list short-circuits to [].
        """
        if not image_ids:
            return []
        result = await self.session.execute(
            select(EvidenceImage)
            .where(EvidenceImage.case_id == case_id, EvidenceImage.id.in_(image_ids))
            .order_by(EvidenceImage.uploaded_at.desc())
        )
        return list(result.scalars().all())

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel from app.schemas.base import CamelModel
class AnalysisStatusOut(BaseModel): class AnalysisStatusOut(CamelModel):
case_id: str case_id: str
status: str status: str
progress: int = 0 progress: int = 0
@@ -9,6 +9,6 @@ class AnalysisStatusOut(BaseModel):
message: str = "" message: str = ""
class AnalysisTriggerOut(BaseModel): class AnalysisTriggerOut(CamelModel):
task_id: str task_id: str
message: str message: str

View File

@@ -1,13 +1,12 @@
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import BaseModel
from app.models.assessment import ConfidenceLevel, ReviewStatus from app.models.assessment import ConfidenceLevel, ReviewStatus
from app.schemas.base import CamelModel
from app.schemas.transaction import TransactionOut from app.schemas.transaction import TransactionOut
class AssessmentOut(BaseModel): class AssessmentOut(CamelModel):
id: UUID id: UUID
case_id: UUID case_id: UUID
transaction_id: UUID transaction_id: UUID
@@ -21,19 +20,17 @@ class AssessmentOut(BaseModel):
reviewed_by: str reviewed_by: str
reviewed_at: datetime | None = None reviewed_at: datetime | None = None
model_config = {"from_attributes": True}
class AssessmentListOut(CamelModel):
class AssessmentListOut(BaseModel):
items: list[AssessmentOut] items: list[AssessmentOut]
total: int total: int
class ReviewSubmit(BaseModel): class ReviewSubmit(CamelModel):
review_status: ReviewStatus review_status: ReviewStatus
review_note: str = "" review_note: str = ""
reviewed_by: str = "demo_user" reviewed_by: str = "demo_user"
class InquirySuggestionOut(BaseModel): class InquirySuggestionOut(CamelModel):
suggestions: list[str] suggestions: list[str]

View File

@@ -0,0 +1,15 @@
"""Base schema with camelCase JSON serialization."""
from pydantic import BaseModel, ConfigDict
def to_camel(s: str) -> str:
    """Convert a snake_case name to camelCase (e.g. "case_id" -> "caseId")."""
    head, *tail = s.split("_")
    return head + "".join(word.capitalize() for word in tail)
class CamelModel(BaseModel):
    """Base schema that speaks snake_case in Python and camelCase in JSON.

    - from_attributes: allow construction directly from ORM objects.
    - alias_generator: every field gets a camelCase alias for (de)serialization.
    - populate_by_name: the original snake_case field names remain accepted
      on input, so both spellings work.
    """
    model_config = ConfigDict(
        from_attributes=True,
        alias_generator=to_camel,
        populate_by_name=True,
    )

View File

@@ -1,26 +1,25 @@
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import BaseModel
from app.models.case import CaseStatus from app.models.case import CaseStatus
from app.schemas.base import CamelModel
class CaseCreate(BaseModel): class CaseCreate(CamelModel):
case_no: str case_no: str
title: str title: str
victim_name: str victim_name: str
handler: str = "" handler: str = ""
class CaseUpdate(BaseModel): class CaseUpdate(CamelModel):
title: str | None = None title: str | None = None
victim_name: str | None = None victim_name: str | None = None
handler: str | None = None handler: str | None = None
status: CaseStatus | None = None status: CaseStatus | None = None
class CaseOut(BaseModel): class CaseOut(CamelModel):
id: UUID id: UUID
case_no: str case_no: str
title: str title: str
@@ -32,9 +31,7 @@ class CaseOut(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
model_config = {"from_attributes": True}
class CaseListOut(CamelModel):
class CaseListOut(BaseModel):
items: list[CaseOut] items: list[CaseOut]
total: int total: int

View File

@@ -1,12 +1,11 @@
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import BaseModel
from app.models.evidence_image import SourceApp, PageType, OcrStatus from app.models.evidence_image import SourceApp, PageType, OcrStatus
from app.schemas.base import CamelModel
class ImageOut(BaseModel): class ImageOut(CamelModel):
id: UUID id: UUID
case_id: UUID case_id: UUID
url: str = "" url: str = ""
@@ -17,24 +16,25 @@ class ImageOut(BaseModel):
file_hash: str file_hash: str
uploaded_at: datetime uploaded_at: datetime
model_config = {"from_attributes": True}
class OcrBlockOut(CamelModel):
class OcrBlockOut(BaseModel):
id: UUID id: UUID
content: str content: str
bbox: dict bbox: dict
seq_order: int seq_order: int
confidence: float confidence: float
model_config = {"from_attributes": True}
class ImageDetailOut(ImageOut): class ImageDetailOut(ImageOut):
ocr_blocks: list[OcrBlockOut] = [] ocr_blocks: list[OcrBlockOut] = []
class OcrFieldCorrection(BaseModel): class OcrFieldCorrection(CamelModel):
field_name: str field_name: str
old_value: str old_value: str
new_value: str new_value: str
class CaseOcrStartIn(CamelModel):
include_done: bool = False
image_ids: list[UUID] = []

View File

@@ -1,12 +1,11 @@
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import BaseModel
from app.models.report import ReportType from app.models.report import ReportType
from app.schemas.base import CamelModel
class ReportCreate(BaseModel): class ReportCreate(CamelModel):
report_type: ReportType report_type: ReportType
include_summary: bool = True include_summary: bool = True
include_transactions: bool = True include_transactions: bool = True
@@ -17,7 +16,7 @@ class ReportCreate(BaseModel):
include_screenshots: bool = False include_screenshots: bool = False
class ReportOut(BaseModel): class ReportOut(CamelModel):
id: UUID id: UUID
case_id: UUID case_id: UUID
report_type: ReportType report_type: ReportType
@@ -25,9 +24,7 @@ class ReportOut(BaseModel):
version: int version: int
created_at: datetime created_at: datetime
model_config = {"from_attributes": True}
class ReportListOut(CamelModel):
class ReportListOut(BaseModel):
items: list[ReportOut] items: list[ReportOut]
total: int total: int

View File

@@ -1,13 +1,12 @@
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import BaseModel
from app.models.evidence_image import SourceApp from app.models.evidence_image import SourceApp
from app.models.transaction import Direction from app.models.transaction import Direction
from app.schemas.base import CamelModel
class TransactionOut(BaseModel): class TransactionOut(CamelModel):
id: UUID id: UUID
case_id: UUID case_id: UUID
source_app: SourceApp source_app: SourceApp
@@ -25,21 +24,19 @@ class TransactionOut(BaseModel):
is_duplicate: bool is_duplicate: bool
is_transit: bool is_transit: bool
model_config = {"from_attributes": True}
class TransactionListOut(CamelModel):
class TransactionListOut(BaseModel):
items: list[TransactionOut] items: list[TransactionOut]
total: int total: int
class FlowNodeOut(BaseModel): class FlowNodeOut(CamelModel):
id: str id: str
label: str label: str
type: str type: str
class FlowEdgeOut(BaseModel): class FlowEdgeOut(CamelModel):
source: str source: str
target: str target: str
amount: float amount: float
@@ -47,6 +44,6 @@ class FlowEdgeOut(BaseModel):
trade_time: str trade_time: str
class FlowGraphOut(BaseModel): class FlowGraphOut(CamelModel):
nodes: list[FlowNodeOut] nodes: list[FlowNodeOut]
edges: list[FlowEdgeOut] edges: list[FlowEdgeOut]

View File

@@ -1,12 +1,20 @@
"""OCR and multimodal extraction service. """OCR and multimodal extraction service.
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface. Both classify_page and extract_transaction_fields use OpenAI-compatible
When API keys are not configured, falls back to a mock implementation that multimodal chat completion APIs. OCR and LLM can point to different
returns placeholder data (sufficient for demo / competition). providers / models via separate env vars:
OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)
When OCR keys are not set, falls back to LLM keys.
When neither is set, returns mock data (sufficient for demo).
""" """
import base64
import json import json
import logging import logging
from pathlib import Path import ast
import re
import httpx import httpx
@@ -14,28 +22,163 @@ from app.core.config import settings
from app.models.evidence_image import SourceApp, PageType from app.models.evidence_image import SourceApp, PageType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
def _ocr_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for OCR, falling back to LLM config."""
    # Each field falls back independently, so partial OCR config is honored.
    return (
        settings.OCR_API_URL or settings.LLM_API_URL,
        settings.OCR_API_KEY or settings.LLM_API_KEY,
        settings.OCR_MODEL or settings.LLM_MODEL,
    )
def _llm_config() -> tuple[str, str, str]:
    """Return (api_url, api_key, model) for text-only LLM calls (no OCR fallback)."""
    return settings.LLM_API_URL, settings.LLM_API_KEY, settings.LLM_MODEL
def _ocr_available() -> bool:
    """True when url, key and model are all configured for the OCR endpoint."""
    return all(_ocr_config())
def _llm_available() -> bool:
    """True when url, key and model are all configured for the text LLM."""
    return all(_llm_config())
# ── provider-agnostic interface ────────────────────────────────────────── # ── provider-agnostic interface ──────────────────────────────────────────
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]: async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
"""Identify the source app and page type of a screenshot.""" """Identify the source app and page type of a screenshot."""
if settings.LLM_API_KEY and settings.LLM_API_URL: if _ocr_available():
return await _classify_via_api(image_path) return await _classify_via_api(image_path)
return _classify_mock(image_path) return _classify_mock(image_path)
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict: async def extract_transaction_fields(
image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
"""Extract structured transaction fields from a screenshot.""" """Extract structured transaction fields from a screenshot."""
if settings.LLM_API_KEY and settings.LLM_API_URL: if _ocr_available():
return await _extract_via_api(image_path, source_app, page_type) return await _extract_via_api(image_path, source_app, page_type)
return _extract_mock(image_path, source_app, page_type) mock_data = _extract_mock(image_path, source_app, page_type)
return mock_data, json.dumps(mock_data, ensure_ascii=False)
# ── real API implementation ────────────────────────────────────────────── # ── OpenAI-compatible API implementation ─────────────────────────────────
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
    """Send a single-turn vision request to the OCR model endpoint.

    Args:
        prompt: Text instruction for the model.
        image_b64: Base64-encoded image, embedded as a PNG data URL.
        max_tokens: Completion-token budget for the response.

    Returns:
        The assistant message content as a string ("" when absent).

    Raises:
        httpx.HTTPStatusError: on a non-2xx response (via raise_for_status).
    """
    url, key, model = _ocr_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                        ],
                    }
                ],
                "max_tokens": max_tokens,
                # temperature 0: deterministic output keeps downstream JSON
                # parsing as reliable as possible.
                "temperature": 0,
            },
        )
        resp.raise_for_status()
        body = resp.json()
        # Defensive unpacking: tolerate missing "choices"/"message" keys.
        choice0 = (body.get("choices") or [{}])[0]
        msg = choice0.get("message") or {}
        content = msg.get("content", "")
        # Some providers return structured (non-str) content; coerce to str.
        return content if isinstance(content, str) else str(content)
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
    """Send a text-only request to the reasoning LLM endpoint.

    Args:
        prompt: The full user message to send.
        max_tokens: Completion-token budget for the response.

    Returns:
        The raw assistant message content.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        KeyError / IndexError: if the provider returns an unexpected payload.
    """
    url, key, model = _llm_config()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {key}"},
            json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
            },
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]
def _parse_json_response(text: str):
"""Strip markdown fences and parse JSON from LLM output."""
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[-1]
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip().removeprefix("json").strip()
# Try parsing the full body first.
try:
return json.loads(cleaned)
except Exception:
pass
# Extract likely JSON block when model wraps with extra text.
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
candidate = match.group(1) if match else cleaned
# Normalize common model output issues:
# - smart quotes
# - trailing commas
normalized = (
candidate
.replace("", "\"")
.replace("", "\"")
.replace("", "'")
.replace("", "'")
)
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
try:
return json.loads(normalized)
except Exception:
pass
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
parsed = ast.literal_eval(py_like)
if isinstance(parsed, (dict, list)):
return parsed
raise ValueError("Parsed payload is neither dict nor list")
async def _repair_broken_json_with_llm(broken_text: str) -> str:
    """Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics.

    The (Chinese) prompt instructs the model to emit exactly one JSON value that
    json.loads can parse directly: no Markdown, no commentary, no new fields,
    and only minimal closure fixes (missing brackets/quotes) for truncation.

    Args:
        broken_text: Raw, possibly malformed JSON text from the OCR model.

    Returns:
        The model's repair attempt as plain text — NOT yet validated; callers
        must still run it through the JSON parser.
    """
    prompt = (
        "你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
        "任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
        "硬性要求:\n"
        "1) 只能输出JSON不要Markdown不要解释。\n"
        "2) 保留原有字段和语义,不新增未出现的信息。\n"
        "3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
        "4) 结果必须是对象或数组。\n\n"
        "损坏文本如下:\n"
        f"{broken_text}"
    )
    # 1200 tokens matches the text-LLM default budget; repairs should not grow the payload.
    return await _call_text_llm(prompt, max_tokens=1200)
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]: async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
import base64
full_path = settings.upload_path / image_path full_path = settings.upload_path / image_path
if not full_path.exists(): if not full_path.exists():
return SourceApp.other, PageType.unknown return SourceApp.other, PageType.unknown
@@ -43,79 +186,64 @@ async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
image_b64 = base64.b64encode(full_path.read_bytes()).decode() image_b64 = base64.b64encode(full_path.read_bytes()).decode()
prompt = ( prompt = (
"请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other" "请分析这张手机截图判断它来自哪个APPwechat/alipay/bank/digital_wallet/other"
"以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown" "以及页面类型bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown\n"
"只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}" '只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
) )
try: try:
async with httpx.AsyncClient(timeout=30) as client: text = await _call_vision(prompt, image_b64, max_tokens=600)
resp = await client.post( data = _parse_json_response(text)
settings.LLM_API_URL,
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
json={
"model": settings.LLM_MODEL,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
],
}
],
"max_tokens": 200,
},
)
resp.raise_for_status()
text = resp.json()["choices"][0]["message"]["content"]
data = json.loads(text.strip().strip("`").removeprefix("json").strip())
return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown")) return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
except Exception as e: except Exception as e:
logger.warning("classify_page API failed: %s", e) logger.warning("classify_page API failed: %s", e)
return SourceApp.other, PageType.unknown return SourceApp.other, PageType.unknown
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict: async def _extract_via_api(
import base64 image_path: str, source_app: SourceApp, page_type: PageType
) -> tuple[dict | list, str]:
full_path = settings.upload_path / image_path full_path = settings.upload_path / image_path
if not full_path.exists(): if not full_path.exists():
return {} return {}, ""
image_b64 = base64.b64encode(full_path.read_bytes()).decode() image_b64 = base64.b64encode(full_path.read_bytes()).decode()
prompt = ( prompt = (
f"这是一张来自{source_app.value}{page_type.value}截图。" "你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析任何非JSON字符都会导致任务失败。\n"
"请提取其中的交易信息返回JSON格式字段包括" f"输入图片来自 {source_app.value} / {page_type.value}\n\n"
"trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), " "【硬性输出规则】\n"
"direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), " "1) 只能输出一个合法JSON值对象或数组禁止Markdown代码块、禁止解释文字、禁止注释。\n"
"self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。" "2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
"如果截图包含多笔交易返回JSON数组。否则返回单个JSON对象。" "3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
"4) 若只有1笔交易输出对象若>=2笔交易输出数组。\n"
"5) 每笔交易字段固定为:\n"
' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
'"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
"6) 字段约束:\n"
"- trade_time: 必须为 YYYY-MM-DD HH:MM:SS无法识别时填空字符串\"\"\n"
"- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
"- direction: 仅允许 in 或 out无法判断默认 out。\n"
"- confidence: 0~1 的数字;无法判断默认 0.5。\n"
"- 其他文本字段无法识别时填空字符串\"\"\n\n"
"【输出前自检】\n"
"- 检查是否是严格JSON可被 json.loads 直接解析)。\n"
"- 检查每条记录都包含全部9个字段。\n"
"- 检查 amount/confidence 为数字类型direction 仅为 in/out。\n"
"现在只输出最终JSON不要输出任何额外文本。"
) )
try: try:
async with httpx.AsyncClient(timeout=60) as client: text = await _call_vision(prompt, image_b64, max_tokens=6000)
resp = await client.post( try:
settings.LLM_API_URL, parsed = _parse_json_response(text)
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, except Exception as parse_err:
json={ if not ENABLE_LLM_REPAIR or not _llm_available():
"model": settings.LLM_MODEL, raise
"messages": [ repaired_text = await _repair_broken_json_with_llm(text)
{ parsed = _parse_json_response(repaired_text)
"role": "user", return parsed, text
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
],
}
],
"max_tokens": 2000,
},
)
resp.raise_for_status()
text = resp.json()["choices"][0]["message"]["content"]
return json.loads(text.strip().strip("`").removeprefix("json").strip())
except Exception as e: except Exception as e:
logger.warning("extract_transaction_fields API failed: %s", e) logger.warning("extract_transaction_fields API failed: %s", e)
return {} return {}, text if "text" in locals() else ""
# ── mock fallback ──────────────────────────────────────────────────────── # ── mock fallback ────────────────────────────────────────────────────────

View File

@@ -16,31 +16,42 @@ def parse_extracted_fields(
items = raw if isinstance(raw, list) else [raw] items = raw if isinstance(raw, list) else [raw]
records: list[TransactionRecord] = [] records: list[TransactionRecord] = []
for item in items: for idx, item in enumerate(items):
if not item or not item.get("amount"): if not item or not item.get("amount"):
continue continue
raw_trade_time = item.get("trade_time")
try: try:
trade_time = datetime.fromisoformat(item["trade_time"]) if isinstance(raw_trade_time, datetime):
except (ValueError, KeyError): trade_time = raw_trade_time
elif isinstance(raw_trade_time, str):
trade_time = datetime.fromisoformat(raw_trade_time)
else:
raise TypeError("trade_time is not str/datetime")
except (ValueError, KeyError, TypeError):
trade_time = datetime.now() trade_time = datetime.now()
direction_str = item.get("direction", "out") direction_str = item.get("direction", "out")
direction = Direction.in_ if direction_str == "in" else Direction.out direction = Direction.in_ if direction_str == "in" else Direction.out
try:
amount = float(item.get("amount", 0))
confidence = float(item.get("confidence", 0.5))
except Exception:
raise
record = TransactionRecord( record = TransactionRecord(
case_id=case_id, case_id=case_id,
evidence_image_id=image_id, evidence_image_id=image_id,
source_app=source_app, source_app=source_app,
trade_time=trade_time, trade_time=trade_time,
amount=float(item.get("amount", 0)), amount=amount,
direction=direction, direction=direction,
counterparty_name=str(item.get("counterparty_name", "")), counterparty_name=str(item.get("counterparty_name", "")),
counterparty_account=str(item.get("counterparty_account", "")), counterparty_account=str(item.get("counterparty_account", "")),
self_account_tail_no=str(item.get("self_account_tail_no", "")), self_account_tail_no=str(item.get("self_account_tail_no", "")),
order_no=str(item.get("order_no", "")), order_no=str(item.get("order_no", "")),
remark=str(item.get("remark", "")), remark=str(item.get("remark", "")),
confidence=float(item.get("confidence", 0.5)), confidence=confidence,
) )
records.append(record) records.append(record)

View File

@@ -1,6 +1,7 @@
"""Celery tasks for OCR processing of uploaded screenshots.""" """Celery tasks for OCR processing of uploaded screenshots."""
import asyncio import asyncio
import logging import logging
import json
from uuid import UUID from uuid import UUID
from app.workers.celery_app import celery_app from app.workers.celery_app import celery_app
@@ -20,18 +21,19 @@ def _run_async(coro):
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3) @celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
def process_image_ocr(self, image_id: str): def process_image_ocr(self, image_id: str):
"""Process a single image: classify page, extract fields, save to DB.""" """Process a single image: classify page, extract fields, save to DB."""
_run_async(_process(image_id)) _run_async(process_image_ocr_async(image_id))
async def _process(image_id_str: str): async def process_image_ocr_async(image_id_str: str):
from app.core.database import async_session_factory from app.core.database import async_session_factory
from sqlalchemy import delete
from app.models.evidence_image import EvidenceImage, OcrStatus from app.models.evidence_image import EvidenceImage, OcrStatus
from app.models.ocr_block import OcrBlock from app.models.ocr_block import OcrBlock
from app.models.transaction import TransactionRecord
from app.services.ocr_service import classify_page, extract_transaction_fields from app.services.ocr_service import classify_page, extract_transaction_fields
from app.services.parser_service import parse_extracted_fields from app.services.parser_service import parse_extracted_fields
image_id = UUID(image_id_str) image_id = UUID(image_id_str)
async with async_session_factory() as db: async with async_session_factory() as db:
image = await db.get(EvidenceImage, image_id) image = await db.get(EvidenceImage, image_id)
if not image: if not image:
@@ -40,18 +42,34 @@ async def _process(image_id_str: str):
image.ocr_status = OcrStatus.processing image.ocr_status = OcrStatus.processing
await db.flush() await db.flush()
# Re-run OCR for this image should replace old OCR blocks/records.
await db.execute(delete(OcrBlock).where(OcrBlock.image_id == image.id))
await db.execute(delete(TransactionRecord).where(TransactionRecord.evidence_image_id == image.id))
try: try:
source_app, page_type = await classify_page(image.file_path) source_app, page_type = await classify_page(image.file_path)
image.source_app = source_app image.source_app = source_app
image.page_type = page_type image.page_type = page_type
raw_fields = await extract_transaction_fields(image.file_path, source_app, page_type) raw_fields, raw_ocr_text = await extract_transaction_fields(image.file_path, source_app, page_type)
# save raw OCR block is_empty_extract = raw_fields is None or raw_fields == {} or raw_fields == []
# save raw OCR block (direct model output for debugging)
raw_block_content = (raw_ocr_text or "").strip()
if raw_block_content:
block = OcrBlock( block = OcrBlock(
image_id=image.id, image_id=image.id,
content=str(raw_fields), content=raw_block_content,
bbox={},
seq_order=0,
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,
)
db.add(block)
elif not is_empty_extract:
block = OcrBlock(
image_id=image.id,
content=json.dumps(raw_fields, ensure_ascii=False),
bbox={}, bbox={},
seq_order=0, seq_order=0,
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5, confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,

18
backend/requirements.txt Normal file
View File

@@ -0,0 +1,18 @@
fastapi==0.135.1
uvicorn[standard]==0.41.0
sqlalchemy[asyncio]==2.0.48
asyncpg==0.31.0
alembic==1.18.4
celery[redis]==5.6.2
redis==6.4.0
pydantic-settings>=2.0.0
python-multipart==0.0.22
Pillow==12.1.0
httpx==0.28.1
openpyxl==3.1.5
python-docx==1.2.0
psycopg2-binary==2.9.11
# dev dependencies
pytest==9.0.2
pytest-asyncio>=0.24.0