fix ocr
This commit is contained in:
@@ -1,8 +1,16 @@
|
|||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import engine_from_config, pool
|
||||||
from alembic import context
|
from alembic import context
|
||||||
|
|
||||||
|
# Ensure `backend/` is on sys.path so `import app...` works
|
||||||
|
# no matter where `alembic` is executed from.
|
||||||
|
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(BACKEND_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(BACKEND_ROOT))
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
import app.models # noqa: F401 – ensure all models are imported
|
import app.models # noqa: F401 – ensure all models are imported
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Sequence, Union
|
|||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import Text
|
||||||
${imports if imports else ""}
|
${imports if imports else ""}
|
||||||
|
|
||||||
revision: str = ${repr(up_revision)}
|
revision: str = ${repr(up_revision)}
|
||||||
|
|||||||
169
backend/alembic/versions/be562b8079e3_init_sqlite_test.py
Normal file
169
backend/alembic/versions/be562b8079e3_init_sqlite_test.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""init_sqlite_test
|
||||||
|
|
||||||
|
Revision ID: be562b8079e3
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-11 17:33:30.695730
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = 'be562b8079e3'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('cases',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_no', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('title', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('victim_name', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('handler', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('status', sa.Enum('pending', 'uploading', 'analyzing', 'reviewing', 'completed', name='casestatus'), nullable=False),
|
||||||
|
sa.Column('image_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('total_amount', sa.Numeric(precision=14, scale=2), nullable=False),
|
||||||
|
sa.Column('created_by', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_cases_case_no'), 'cases', ['case_no'], unique=True)
|
||||||
|
op.create_table('evidence_images',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('file_path', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('thumb_path', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False),
|
||||||
|
sa.Column('page_type', sa.Enum('bill_list', 'bill_detail', 'transfer_receipt', 'sms_notice', 'balance', 'unknown', name='pagetype'), nullable=False),
|
||||||
|
sa.Column('ocr_status', sa.Enum('pending', 'processing', 'done', 'failed', name='ocrstatus'), nullable=False),
|
||||||
|
sa.Column('file_hash', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('file_size', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_evidence_images_case_id'), 'evidence_images', ['case_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_evidence_images_file_hash'), 'evidence_images', ['file_hash'], unique=True)
|
||||||
|
op.create_table('export_reports',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('report_type', sa.Enum('pdf', 'excel', 'word', name='reporttype'), nullable=False),
|
||||||
|
sa.Column('file_path', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('version', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('content_snapshot', postgresql.JSONB(astext_type=Text()), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_export_reports_case_id'), 'export_reports', ['case_id'], unique=False)
|
||||||
|
op.create_table('fund_flow_edges',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('source_node', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('target_node', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('source_type', sa.String(length=32), nullable=False),
|
||||||
|
sa.Column('target_type', sa.String(length=32), nullable=False),
|
||||||
|
sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False),
|
||||||
|
sa.Column('tx_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('earliest_time', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_fund_flow_edges_case_id'), 'fund_flow_edges', ['case_id'], unique=False)
|
||||||
|
op.create_table('transaction_clusters',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('primary_tx_id', sa.UUID(), nullable=True),
|
||||||
|
sa.Column('match_reason', sa.String(length=512), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_transaction_clusters_case_id'), 'transaction_clusters', ['case_id'], unique=False)
|
||||||
|
op.create_table('ocr_blocks',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('image_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('content', sa.Text(), nullable=False),
|
||||||
|
sa.Column('bbox', postgresql.JSONB(astext_type=Text()), nullable=False),
|
||||||
|
sa.Column('seq_order', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('confidence', sa.Float(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['image_id'], ['evidence_images.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_ocr_blocks_image_id'), 'ocr_blocks', ['image_id'], unique=False)
|
||||||
|
op.create_table('transaction_records',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('evidence_image_id', sa.UUID(), nullable=True),
|
||||||
|
sa.Column('cluster_id', sa.UUID(), nullable=True),
|
||||||
|
sa.Column('source_app', sa.Enum('wechat', 'alipay', 'bank', 'digital_wallet', 'other', name='sourceapp'), nullable=False),
|
||||||
|
sa.Column('trade_time', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column('amount', sa.Numeric(precision=14, scale=2), nullable=False),
|
||||||
|
sa.Column('direction', sa.Enum('in_', 'out', name='direction'), nullable=False),
|
||||||
|
sa.Column('counterparty_name', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('counterparty_account', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('self_account_tail_no', sa.String(length=32), nullable=False),
|
||||||
|
sa.Column('order_no', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('remark', sa.Text(), nullable=False),
|
||||||
|
sa.Column('confidence', sa.Float(), nullable=False),
|
||||||
|
sa.Column('is_duplicate', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_transit', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['cluster_id'], ['transaction_clusters.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['evidence_image_id'], ['evidence_images.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_transaction_records_case_id'), 'transaction_records', ['case_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_transaction_records_order_no'), 'transaction_records', ['order_no'], unique=False)
|
||||||
|
op.create_index(op.f('ix_transaction_records_trade_time'), 'transaction_records', ['trade_time'], unique=False)
|
||||||
|
op.create_table('fraud_assessments',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('case_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('transaction_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('confidence_level', sa.Enum('high', 'medium', 'low', name='confidencelevel'), nullable=False),
|
||||||
|
sa.Column('assessed_amount', sa.Numeric(precision=14, scale=2), nullable=False),
|
||||||
|
sa.Column('reason', sa.Text(), nullable=False),
|
||||||
|
sa.Column('exclude_reason', sa.Text(), nullable=False),
|
||||||
|
sa.Column('review_status', sa.Enum('pending', 'confirmed', 'rejected', 'needs_info', name='reviewstatus'), nullable=False),
|
||||||
|
sa.Column('review_note', sa.Text(), nullable=False),
|
||||||
|
sa.Column('reviewed_by', sa.String(length=128), nullable=False),
|
||||||
|
sa.Column('reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['case_id'], ['cases.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['transaction_id'], ['transaction_records.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_fraud_assessments_case_id'), 'fraud_assessments', ['case_id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_fraud_assessments_case_id'), table_name='fraud_assessments')
|
||||||
|
op.drop_table('fraud_assessments')
|
||||||
|
op.drop_index(op.f('ix_transaction_records_trade_time'), table_name='transaction_records')
|
||||||
|
op.drop_index(op.f('ix_transaction_records_order_no'), table_name='transaction_records')
|
||||||
|
op.drop_index(op.f('ix_transaction_records_case_id'), table_name='transaction_records')
|
||||||
|
op.drop_table('transaction_records')
|
||||||
|
op.drop_index(op.f('ix_ocr_blocks_image_id'), table_name='ocr_blocks')
|
||||||
|
op.drop_table('ocr_blocks')
|
||||||
|
op.drop_index(op.f('ix_transaction_clusters_case_id'), table_name='transaction_clusters')
|
||||||
|
op.drop_table('transaction_clusters')
|
||||||
|
op.drop_index(op.f('ix_fund_flow_edges_case_id'), table_name='fund_flow_edges')
|
||||||
|
op.drop_table('fund_flow_edges')
|
||||||
|
op.drop_index(op.f('ix_export_reports_case_id'), table_name='export_reports')
|
||||||
|
op.drop_table('export_reports')
|
||||||
|
op.drop_index(op.f('ix_evidence_images_file_hash'), table_name='evidence_images')
|
||||||
|
op.drop_index(op.f('ix_evidence_images_case_id'), table_name='evidence_images')
|
||||||
|
op.drop_table('evidence_images')
|
||||||
|
op.drop_index(op.f('ix_cases_case_no'), table_name='cases')
|
||||||
|
op.drop_table('cases')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -2,10 +2,12 @@ from uuid import UUID
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.assessment import ConfidenceLevel
|
from app.models.assessment import FraudAssessment, ConfidenceLevel
|
||||||
from app.repositories.assessment_repo import AssessmentRepository
|
from app.repositories.assessment_repo import AssessmentRepository
|
||||||
from app.schemas.assessment import (
|
from app.schemas.assessment import (
|
||||||
AssessmentOut,
|
AssessmentOut,
|
||||||
@@ -46,6 +48,14 @@ async def review_assessment(
|
|||||||
"reviewed_by": body.reviewed_by,
|
"reviewed_by": body.reviewed_by,
|
||||||
"reviewed_at": datetime.now(timezone.utc),
|
"reviewed_at": datetime.now(timezone.utc),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# eager-load the transaction relationship to avoid lazy-load in async context
|
||||||
|
result = await db.execute(
|
||||||
|
select(FraudAssessment)
|
||||||
|
.options(selectinload(FraudAssessment.transaction))
|
||||||
|
.where(FraudAssessment.id == assessment_id)
|
||||||
|
)
|
||||||
|
assessment = result.scalar_one()
|
||||||
return assessment
|
return assessment
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query
|
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.evidence_image import EvidenceImage, SourceApp, PageType
|
from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
|
||||||
from app.repositories.image_repo import ImageRepository
|
from app.repositories.image_repo import ImageRepository
|
||||||
from app.repositories.case_repo import CaseRepository
|
from app.repositories.case_repo import CaseRepository
|
||||||
from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection
|
from app.schemas.image import ImageOut, ImageDetailOut, OcrFieldCorrection, CaseOcrStartIn
|
||||||
from app.utils.hash import sha256_file
|
from app.utils.hash import sha256_file
|
||||||
from app.utils.file_storage import save_upload
|
from app.utils.file_storage import save_upload
|
||||||
|
|
||||||
@@ -32,9 +33,11 @@ async def upload_images(
|
|||||||
|
|
||||||
for f in files:
|
for f in files:
|
||||||
data = await f.read()
|
data = await f.read()
|
||||||
file_hash = sha256_file(data)
|
raw_hash = sha256_file(data)
|
||||||
|
# Scope hash by case to avoid cross-case unique conflicts while still deduplicating inside one case.
|
||||||
|
scoped_hash = f"{raw_hash}:{case_id}"
|
||||||
|
|
||||||
existing = await img_repo.find_by_hash(file_hash)
|
existing = await img_repo.find_by_hash_in_case(case_id, [raw_hash, scoped_hash])
|
||||||
if existing:
|
if existing:
|
||||||
results.append(existing)
|
results.append(existing)
|
||||||
continue
|
continue
|
||||||
@@ -44,7 +47,7 @@ async def upload_images(
|
|||||||
case_id=case_id,
|
case_id=case_id,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
thumb_path=thumb_path,
|
thumb_path=thumb_path,
|
||||||
file_hash=file_hash,
|
file_hash=scoped_hash,
|
||||||
file_size=len(data),
|
file_size=len(data),
|
||||||
)
|
)
|
||||||
image = await img_repo.create(image)
|
image = await img_repo.create(image)
|
||||||
@@ -53,14 +56,11 @@ async def upload_images(
|
|||||||
case.image_count = await img_repo.count_by_case(case_id)
|
case.image_count = await img_repo.count_by_case(case_id)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
# trigger OCR tasks (non-blocking)
|
# trigger OCR tasks in-process background (non-blocking for API response)
|
||||||
from app.workers.ocr_tasks import process_image_ocr
|
from app.workers.ocr_tasks import process_image_ocr_async
|
||||||
for img in results:
|
for img in results:
|
||||||
if img.ocr_status.value == "pending":
|
if img.ocr_status.value == "pending":
|
||||||
try:
|
asyncio.create_task(process_image_ocr_async(str(img.id)))
|
||||||
process_image_ocr.delay(str(img.id))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -73,7 +73,21 @@ async def list_images(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
repo = ImageRepository(db)
|
repo = ImageRepository(db)
|
||||||
return await repo.list_by_case(case_id, source_app=source_app, page_type=page_type)
|
images = await repo.list_by_case(case_id, source_app=source_app, page_type=page_type)
|
||||||
|
return [
|
||||||
|
ImageOut(
|
||||||
|
id=img.id,
|
||||||
|
case_id=img.case_id,
|
||||||
|
url=f"/api/v1/images/{img.id}/file",
|
||||||
|
thumb_url=f"/api/v1/images/{img.id}/file",
|
||||||
|
source_app=img.source_app,
|
||||||
|
page_type=img.page_type,
|
||||||
|
ocr_status=img.ocr_status,
|
||||||
|
file_hash=img.file_hash,
|
||||||
|
uploaded_at=img.uploaded_at,
|
||||||
|
)
|
||||||
|
for img in images
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/images/{image_id}", response_model=ImageDetailOut)
|
@router.get("/images/{image_id}", response_model=ImageDetailOut)
|
||||||
@@ -128,3 +142,43 @@ async def get_image_file(image_id: UUID, db: AsyncSession = Depends(get_db)):
|
|||||||
if not full_path.exists():
|
if not full_path.exists():
|
||||||
raise HTTPException(404, "文件不存在")
|
raise HTTPException(404, "文件不存在")
|
||||||
return FileResponse(full_path)
|
return FileResponse(full_path)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cases/{case_id}/ocr/start")
|
||||||
|
async def start_case_ocr(
|
||||||
|
case_id: UUID,
|
||||||
|
payload: CaseOcrStartIn | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
case_repo = CaseRepository(db)
|
||||||
|
case = await case_repo.get(case_id)
|
||||||
|
if not case:
|
||||||
|
raise HTTPException(404, "案件不存在")
|
||||||
|
|
||||||
|
repo = ImageRepository(db)
|
||||||
|
include_done = payload.include_done if payload else False
|
||||||
|
image_ids = payload.image_ids if payload else []
|
||||||
|
if image_ids:
|
||||||
|
images = await repo.list_by_ids_in_case(case_id, image_ids)
|
||||||
|
# For explicit re-run, mark selected images as processing immediately
|
||||||
|
# so frontend can reflect state transition without full page refresh.
|
||||||
|
for img in images:
|
||||||
|
img.ocr_status = OcrStatus.processing
|
||||||
|
await db.flush()
|
||||||
|
await db.commit()
|
||||||
|
else:
|
||||||
|
images = await repo.list_for_ocr(case_id, include_done=include_done)
|
||||||
|
|
||||||
|
from app.workers.ocr_tasks import process_image_ocr_async
|
||||||
|
|
||||||
|
submitted = 0
|
||||||
|
for img in images:
|
||||||
|
asyncio.create_task(process_image_ocr_async(str(img.id)))
|
||||||
|
submitted += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"caseId": str(case_id),
|
||||||
|
"submitted": submitted,
|
||||||
|
"totalCandidates": len(images),
|
||||||
|
"message": f"已提交 {submitted} 张截图的 OCR 任务",
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
OCR_API_KEY: str = ""
|
OCR_API_KEY: str = ""
|
||||||
OCR_API_URL: str = ""
|
OCR_API_URL: str = ""
|
||||||
|
OCR_MODEL: str = ""
|
||||||
LLM_API_KEY: str = ""
|
LLM_API_KEY: str = ""
|
||||||
LLM_API_URL: str = ""
|
LLM_API_URL: str = ""
|
||||||
LLM_MODEL: str = ""
|
LLM_MODEL: str = ""
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from uuid import UUID
|
|||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models.evidence_image import EvidenceImage, SourceApp, PageType
|
from app.models.evidence_image import EvidenceImage, SourceApp, PageType, OcrStatus
|
||||||
from app.repositories.base import BaseRepository
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +17,17 @@ class ImageRepository(BaseRepository[EvidenceImage]):
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def find_by_hash_in_case(self, case_id: UUID, file_hashes: list[str]) -> EvidenceImage | None:
|
||||||
|
if not file_hashes:
|
||||||
|
return None
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(EvidenceImage).where(
|
||||||
|
EvidenceImage.case_id == case_id,
|
||||||
|
EvidenceImage.file_hash.in_(file_hashes),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def list_by_case(
|
async def list_by_case(
|
||||||
self,
|
self,
|
||||||
case_id: UUID,
|
case_id: UUID,
|
||||||
@@ -37,3 +48,20 @@ class ImageRepository(BaseRepository[EvidenceImage]):
|
|||||||
select(func.count()).select_from(EvidenceImage).where(EvidenceImage.case_id == case_id)
|
select(func.count()).select_from(EvidenceImage).where(EvidenceImage.case_id == case_id)
|
||||||
)
|
)
|
||||||
return result.scalar() or 0
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def list_for_ocr(self, case_id: UUID, include_done: bool = False) -> list[EvidenceImage]:
|
||||||
|
query = select(EvidenceImage).where(EvidenceImage.case_id == case_id)
|
||||||
|
if not include_done:
|
||||||
|
query = query.where(EvidenceImage.ocr_status != OcrStatus.done)
|
||||||
|
result = await self.session.execute(query.order_by(EvidenceImage.uploaded_at.desc()))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def list_by_ids_in_case(self, case_id: UUID, image_ids: list[UUID]) -> list[EvidenceImage]:
|
||||||
|
if not image_ids:
|
||||||
|
return []
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(EvidenceImage)
|
||||||
|
.where(EvidenceImage.case_id == case_id, EvidenceImage.id.in_(image_ids))
|
||||||
|
.order_by(EvidenceImage.uploaded_at.desc())
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
from app.schemas.base import CamelModel
|
||||||
|
|
||||||
|
|
||||||
class AnalysisStatusOut(BaseModel):
|
class AnalysisStatusOut(CamelModel):
|
||||||
case_id: str
|
case_id: str
|
||||||
status: str
|
status: str
|
||||||
progress: int = 0
|
progress: int = 0
|
||||||
@@ -9,6 +9,6 @@ class AnalysisStatusOut(BaseModel):
|
|||||||
message: str = ""
|
message: str = ""
|
||||||
|
|
||||||
|
|
||||||
class AnalysisTriggerOut(BaseModel):
|
class AnalysisTriggerOut(CamelModel):
|
||||||
task_id: str
|
task_id: str
|
||||||
message: str
|
message: str
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.models.assessment import ConfidenceLevel, ReviewStatus
|
from app.models.assessment import ConfidenceLevel, ReviewStatus
|
||||||
|
from app.schemas.base import CamelModel
|
||||||
from app.schemas.transaction import TransactionOut
|
from app.schemas.transaction import TransactionOut
|
||||||
|
|
||||||
|
|
||||||
class AssessmentOut(BaseModel):
|
class AssessmentOut(CamelModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
case_id: UUID
|
case_id: UUID
|
||||||
transaction_id: UUID
|
transaction_id: UUID
|
||||||
@@ -21,19 +20,17 @@ class AssessmentOut(BaseModel):
|
|||||||
reviewed_by: str
|
reviewed_by: str
|
||||||
reviewed_at: datetime | None = None
|
reviewed_at: datetime | None = None
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
class AssessmentListOut(CamelModel):
|
||||||
class AssessmentListOut(BaseModel):
|
|
||||||
items: list[AssessmentOut]
|
items: list[AssessmentOut]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class ReviewSubmit(BaseModel):
|
class ReviewSubmit(CamelModel):
|
||||||
review_status: ReviewStatus
|
review_status: ReviewStatus
|
||||||
review_note: str = ""
|
review_note: str = ""
|
||||||
reviewed_by: str = "demo_user"
|
reviewed_by: str = "demo_user"
|
||||||
|
|
||||||
|
|
||||||
class InquirySuggestionOut(BaseModel):
|
class InquirySuggestionOut(CamelModel):
|
||||||
suggestions: list[str]
|
suggestions: list[str]
|
||||||
|
|||||||
15
backend/app/schemas/base.py
Normal file
15
backend/app/schemas/base.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Base schema with camelCase JSON serialization."""
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
def to_camel(s: str) -> str:
|
||||||
|
parts = s.split("_")
|
||||||
|
return parts[0] + "".join(p.capitalize() for p in parts[1:])
|
||||||
|
|
||||||
|
|
||||||
|
class CamelModel(BaseModel):
|
||||||
|
model_config = ConfigDict(
|
||||||
|
from_attributes=True,
|
||||||
|
alias_generator=to_camel,
|
||||||
|
populate_by_name=True,
|
||||||
|
)
|
||||||
@@ -1,26 +1,25 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.models.case import CaseStatus
|
from app.models.case import CaseStatus
|
||||||
|
from app.schemas.base import CamelModel
|
||||||
|
|
||||||
|
|
||||||
class CaseCreate(BaseModel):
|
class CaseCreate(CamelModel):
|
||||||
case_no: str
|
case_no: str
|
||||||
title: str
|
title: str
|
||||||
victim_name: str
|
victim_name: str
|
||||||
handler: str = ""
|
handler: str = ""
|
||||||
|
|
||||||
|
|
||||||
class CaseUpdate(BaseModel):
|
class CaseUpdate(CamelModel):
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
victim_name: str | None = None
|
victim_name: str | None = None
|
||||||
handler: str | None = None
|
handler: str | None = None
|
||||||
status: CaseStatus | None = None
|
status: CaseStatus | None = None
|
||||||
|
|
||||||
|
|
||||||
class CaseOut(BaseModel):
|
class CaseOut(CamelModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
case_no: str
|
case_no: str
|
||||||
title: str
|
title: str
|
||||||
@@ -32,9 +31,7 @@ class CaseOut(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
class CaseListOut(CamelModel):
|
||||||
class CaseListOut(BaseModel):
|
|
||||||
items: list[CaseOut]
|
items: list[CaseOut]
|
||||||
total: int
|
total: int
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.models.evidence_image import SourceApp, PageType, OcrStatus
|
from app.models.evidence_image import SourceApp, PageType, OcrStatus
|
||||||
|
from app.schemas.base import CamelModel
|
||||||
|
|
||||||
|
|
||||||
class ImageOut(BaseModel):
|
class ImageOut(CamelModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
case_id: UUID
|
case_id: UUID
|
||||||
url: str = ""
|
url: str = ""
|
||||||
@@ -17,24 +16,25 @@ class ImageOut(BaseModel):
|
|||||||
file_hash: str
|
file_hash: str
|
||||||
uploaded_at: datetime
|
uploaded_at: datetime
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
class OcrBlockOut(CamelModel):
|
||||||
class OcrBlockOut(BaseModel):
|
|
||||||
id: UUID
|
id: UUID
|
||||||
content: str
|
content: str
|
||||||
bbox: dict
|
bbox: dict
|
||||||
seq_order: int
|
seq_order: int
|
||||||
confidence: float
|
confidence: float
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDetailOut(ImageOut):
|
class ImageDetailOut(ImageOut):
|
||||||
ocr_blocks: list[OcrBlockOut] = []
|
ocr_blocks: list[OcrBlockOut] = []
|
||||||
|
|
||||||
|
|
||||||
class OcrFieldCorrection(BaseModel):
|
class OcrFieldCorrection(CamelModel):
|
||||||
field_name: str
|
field_name: str
|
||||||
old_value: str
|
old_value: str
|
||||||
new_value: str
|
new_value: str
|
||||||
|
|
||||||
|
|
||||||
|
class CaseOcrStartIn(CamelModel):
|
||||||
|
include_done: bool = False
|
||||||
|
image_ids: list[UUID] = []
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.models.report import ReportType
|
from app.models.report import ReportType
|
||||||
|
from app.schemas.base import CamelModel
|
||||||
|
|
||||||
|
|
||||||
class ReportCreate(BaseModel):
|
class ReportCreate(CamelModel):
|
||||||
report_type: ReportType
|
report_type: ReportType
|
||||||
include_summary: bool = True
|
include_summary: bool = True
|
||||||
include_transactions: bool = True
|
include_transactions: bool = True
|
||||||
@@ -17,7 +16,7 @@ class ReportCreate(BaseModel):
|
|||||||
include_screenshots: bool = False
|
include_screenshots: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ReportOut(BaseModel):
|
class ReportOut(CamelModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
case_id: UUID
|
case_id: UUID
|
||||||
report_type: ReportType
|
report_type: ReportType
|
||||||
@@ -25,9 +24,7 @@ class ReportOut(BaseModel):
|
|||||||
version: int
|
version: int
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
class ReportListOut(CamelModel):
|
||||||
class ReportListOut(BaseModel):
|
|
||||||
items: list[ReportOut]
|
items: list[ReportOut]
|
||||||
total: int
|
total: int
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.models.evidence_image import SourceApp
|
from app.models.evidence_image import SourceApp
|
||||||
from app.models.transaction import Direction
|
from app.models.transaction import Direction
|
||||||
|
from app.schemas.base import CamelModel
|
||||||
|
|
||||||
|
|
||||||
class TransactionOut(BaseModel):
|
class TransactionOut(CamelModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
case_id: UUID
|
case_id: UUID
|
||||||
source_app: SourceApp
|
source_app: SourceApp
|
||||||
@@ -25,21 +24,19 @@ class TransactionOut(BaseModel):
|
|||||||
is_duplicate: bool
|
is_duplicate: bool
|
||||||
is_transit: bool
|
is_transit: bool
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
class TransactionListOut(CamelModel):
|
||||||
class TransactionListOut(BaseModel):
|
|
||||||
items: list[TransactionOut]
|
items: list[TransactionOut]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class FlowNodeOut(BaseModel):
|
class FlowNodeOut(CamelModel):
|
||||||
id: str
|
id: str
|
||||||
label: str
|
label: str
|
||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
class FlowEdgeOut(BaseModel):
|
class FlowEdgeOut(CamelModel):
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
amount: float
|
amount: float
|
||||||
@@ -47,6 +44,6 @@ class FlowEdgeOut(BaseModel):
|
|||||||
trade_time: str
|
trade_time: str
|
||||||
|
|
||||||
|
|
||||||
class FlowGraphOut(BaseModel):
|
class FlowGraphOut(CamelModel):
|
||||||
nodes: list[FlowNodeOut]
|
nodes: list[FlowNodeOut]
|
||||||
edges: list[FlowEdgeOut]
|
edges: list[FlowEdgeOut]
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
"""OCR and multimodal extraction service.
|
"""OCR and multimodal extraction service.
|
||||||
|
|
||||||
Wraps calls to cloud OCR / multimodal APIs with a provider-agnostic interface.
|
Both classify_page and extract_transaction_fields use OpenAI-compatible
|
||||||
When API keys are not configured, falls back to a mock implementation that
|
multimodal chat completion APIs. OCR and LLM can point to different
|
||||||
returns placeholder data (sufficient for demo / competition).
|
providers / models via separate env vars:
|
||||||
|
|
||||||
|
OCR_API_URL / OCR_API_KEY / OCR_MODEL — for page classification & field extraction
|
||||||
|
LLM_API_URL / LLM_API_KEY / LLM_MODEL — for reasoning tasks (assessment, suggestions)
|
||||||
|
|
||||||
|
When OCR keys are not set, falls back to LLM keys.
|
||||||
|
When neither is set, returns mock data (sufficient for demo).
|
||||||
"""
|
"""
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
import ast
|
||||||
|
import re
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -14,28 +22,163 @@ from app.core.config import settings
|
|||||||
from app.models.evidence_image import SourceApp, PageType
|
from app.models.evidence_image import SourceApp, PageType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
ENABLE_LLM_REPAIR = False # temporary: disabled per debugging request
|
||||||
|
|
||||||
|
|
||||||
|
def _ocr_config() -> tuple[str, str, str]:
|
||||||
|
"""Return (api_url, api_key, model) for OCR, falling back to LLM config."""
|
||||||
|
url = settings.OCR_API_URL or settings.LLM_API_URL
|
||||||
|
key = settings.OCR_API_KEY or settings.LLM_API_KEY
|
||||||
|
model = settings.OCR_MODEL or settings.LLM_MODEL
|
||||||
|
return url, key, model
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_config() -> tuple[str, str, str]:
|
||||||
|
"""Return (api_url, api_key, model) for text-only LLM repair."""
|
||||||
|
return settings.LLM_API_URL, settings.LLM_API_KEY, settings.LLM_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
def _ocr_available() -> bool:
|
||||||
|
url, key, model = _ocr_config()
|
||||||
|
return bool(url and key and model)
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_available() -> bool:
|
||||||
|
url, key, model = _llm_config()
|
||||||
|
return bool(url and key and model)
|
||||||
|
|
||||||
|
|
||||||
# ── provider-agnostic interface ──────────────────────────────────────────
|
# ── provider-agnostic interface ──────────────────────────────────────────
|
||||||
|
|
||||||
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
|
async def classify_page(image_path: str) -> tuple[SourceApp, PageType]:
|
||||||
"""Identify the source app and page type of a screenshot."""
|
"""Identify the source app and page type of a screenshot."""
|
||||||
if settings.LLM_API_KEY and settings.LLM_API_URL:
|
if _ocr_available():
|
||||||
return await _classify_via_api(image_path)
|
return await _classify_via_api(image_path)
|
||||||
return _classify_mock(image_path)
|
return _classify_mock(image_path)
|
||||||
|
|
||||||
|
|
||||||
async def extract_transaction_fields(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
|
async def extract_transaction_fields(
|
||||||
|
image_path: str, source_app: SourceApp, page_type: PageType
|
||||||
|
) -> tuple[dict | list, str]:
|
||||||
"""Extract structured transaction fields from a screenshot."""
|
"""Extract structured transaction fields from a screenshot."""
|
||||||
if settings.LLM_API_KEY and settings.LLM_API_URL:
|
if _ocr_available():
|
||||||
return await _extract_via_api(image_path, source_app, page_type)
|
return await _extract_via_api(image_path, source_app, page_type)
|
||||||
return _extract_mock(image_path, source_app, page_type)
|
mock_data = _extract_mock(image_path, source_app, page_type)
|
||||||
|
return mock_data, json.dumps(mock_data, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
# ── real API implementation ──────────────────────────────────────────────
|
# ── OpenAI-compatible API implementation ─────────────────────────────────
|
||||||
|
|
||||||
|
async def _call_vision(prompt: str, image_b64: str, max_tokens: int = 2000) -> str:
|
||||||
|
"""Send a vision request to the OCR model endpoint."""
|
||||||
|
url, key, model = _ocr_config()
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
url,
|
||||||
|
headers={"Authorization": f"Bearer {key}"},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
body = resp.json()
|
||||||
|
choice0 = (body.get("choices") or [{}])[0]
|
||||||
|
msg = choice0.get("message") or {}
|
||||||
|
content = msg.get("content", "")
|
||||||
|
finish_reason = choice0.get("finish_reason")
|
||||||
|
usage = body.get("usage") or {}
|
||||||
|
return content if isinstance(content, str) else str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_text_llm(prompt: str, max_tokens: int = 1200) -> str:
|
||||||
|
"""Send a text-only request to the reasoning LLM endpoint."""
|
||||||
|
url, key, model = _llm_config()
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
url,
|
||||||
|
headers={"Authorization": f"Bearer {key}"},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_response(text: str):
|
||||||
|
"""Strip markdown fences and parse JSON from LLM output."""
|
||||||
|
cleaned = text.strip()
|
||||||
|
if cleaned.startswith("```"):
|
||||||
|
cleaned = cleaned.split("\n", 1)[-1]
|
||||||
|
if cleaned.endswith("```"):
|
||||||
|
cleaned = cleaned.rsplit("```", 1)[0]
|
||||||
|
cleaned = cleaned.strip().removeprefix("json").strip()
|
||||||
|
# Try parsing the full body first.
|
||||||
|
try:
|
||||||
|
return json.loads(cleaned)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Extract likely JSON block when model wraps with extra text.
|
||||||
|
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", cleaned)
|
||||||
|
candidate = match.group(1) if match else cleaned
|
||||||
|
|
||||||
|
# Normalize common model output issues:
|
||||||
|
# - smart quotes
|
||||||
|
# - trailing commas
|
||||||
|
normalized = (
|
||||||
|
candidate
|
||||||
|
.replace("“", "\"")
|
||||||
|
.replace("”", "\"")
|
||||||
|
.replace("‘", "'")
|
||||||
|
.replace("’", "'")
|
||||||
|
)
|
||||||
|
normalized = re.sub(r",\s*([}\]])", r"\1", normalized)
|
||||||
|
try:
|
||||||
|
return json.loads(normalized)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Last fallback: Python-like literal payloads with single quotes / True/False/None.
|
||||||
|
py_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE)
|
||||||
|
py_like = re.sub(r"\bfalse\b", "False", py_like, flags=re.IGNORECASE)
|
||||||
|
py_like = re.sub(r"\bnull\b", "None", py_like, flags=re.IGNORECASE)
|
||||||
|
parsed = ast.literal_eval(py_like)
|
||||||
|
if isinstance(parsed, (dict, list)):
|
||||||
|
return parsed
|
||||||
|
raise ValueError("Parsed payload is neither dict nor list")
|
||||||
|
|
||||||
|
|
||||||
|
async def _repair_broken_json_with_llm(broken_text: str) -> str:
|
||||||
|
"""Use LLM_MODEL to repair malformed OCR JSON text without adding new semantics."""
|
||||||
|
prompt = (
|
||||||
|
"你是JSON修复器。下面是OCR模型返回的可能损坏JSON文本。\n"
|
||||||
|
"任务:仅修复语法并输出一个可被 json.loads 直接解析的JSON。\n"
|
||||||
|
"硬性要求:\n"
|
||||||
|
"1) 只能输出JSON,不要Markdown,不要解释。\n"
|
||||||
|
"2) 保留原有字段和语义,不新增未出现的信息。\n"
|
||||||
|
"3) 若出现截断,允许最小闭合修复(补齐缺失括号/引号)。\n"
|
||||||
|
"4) 结果必须是对象或数组。\n\n"
|
||||||
|
"损坏文本如下:\n"
|
||||||
|
f"{broken_text}"
|
||||||
|
)
|
||||||
|
return await _call_text_llm(prompt, max_tokens=1200)
|
||||||
|
|
||||||
|
|
||||||
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
|
async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
|
||||||
import base64
|
|
||||||
full_path = settings.upload_path / image_path
|
full_path = settings.upload_path / image_path
|
||||||
if not full_path.exists():
|
if not full_path.exists():
|
||||||
return SourceApp.other, PageType.unknown
|
return SourceApp.other, PageType.unknown
|
||||||
@@ -43,79 +186,64 @@ async def _classify_via_api(image_path: str) -> tuple[SourceApp, PageType]:
|
|||||||
image_b64 = base64.b64encode(full_path.read_bytes()).decode()
|
image_b64 = base64.b64encode(full_path.read_bytes()).decode()
|
||||||
prompt = (
|
prompt = (
|
||||||
"请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
|
"请分析这张手机截图,判断它来自哪个APP(wechat/alipay/bank/digital_wallet/other)"
|
||||||
"以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。"
|
"以及页面类型(bill_list/bill_detail/transfer_receipt/sms_notice/balance/unknown)。\n"
|
||||||
"只返回JSON: {\"source_app\": \"...\", \"page_type\": \"...\"}"
|
'只返回JSON,格式: {"source_app": "...", "page_type": "..."}'
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
text = await _call_vision(prompt, image_b64, max_tokens=600)
|
||||||
resp = await client.post(
|
data = _parse_json_response(text)
|
||||||
settings.LLM_API_URL,
|
|
||||||
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
|
|
||||||
json={
|
|
||||||
"model": settings.LLM_MODEL,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": prompt},
|
|
||||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 200,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
text = resp.json()["choices"][0]["message"]["content"]
|
|
||||||
data = json.loads(text.strip().strip("`").removeprefix("json").strip())
|
|
||||||
return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
|
return SourceApp(data.get("source_app", "other")), PageType(data.get("page_type", "unknown"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("classify_page API failed: %s", e)
|
logger.warning("classify_page API failed: %s", e)
|
||||||
return SourceApp.other, PageType.unknown
|
return SourceApp.other, PageType.unknown
|
||||||
|
|
||||||
|
|
||||||
async def _extract_via_api(image_path: str, source_app: SourceApp, page_type: PageType) -> dict:
|
async def _extract_via_api(
|
||||||
import base64
|
image_path: str, source_app: SourceApp, page_type: PageType
|
||||||
|
) -> tuple[dict | list, str]:
|
||||||
full_path = settings.upload_path / image_path
|
full_path = settings.upload_path / image_path
|
||||||
if not full_path.exists():
|
if not full_path.exists():
|
||||||
return {}
|
return {}, ""
|
||||||
|
|
||||||
image_b64 = base64.b64encode(full_path.read_bytes()).decode()
|
image_b64 = base64.b64encode(full_path.read_bytes()).decode()
|
||||||
prompt = (
|
prompt = (
|
||||||
f"这是一张来自{source_app.value}的{page_type.value}截图。"
|
"你是账单OCR结构化引擎。你的输出会被程序直接 json.loads 解析,任何非JSON字符都会导致任务失败。\n"
|
||||||
"请提取其中的交易信息,返回JSON格式,字段包括:"
|
f"输入图片来自 {source_app.value} / {page_type.value}。\n\n"
|
||||||
"trade_time(交易时间,格式YYYY-MM-DD HH:MM:SS), amount(金额,数字), "
|
"【硬性输出规则】\n"
|
||||||
"direction(in或out), counterparty_name(对方名称), counterparty_account(对方账号), "
|
"1) 只能输出一个合法JSON值:对象或数组;禁止Markdown代码块、禁止解释文字、禁止注释。\n"
|
||||||
"self_account_tail_no(本方账户尾号), order_no(订单号), remark(备注), confidence(0-1)。"
|
"2) 键名必须使用英文双引号;字符串值必须使用英文双引号;禁止单引号、禁止中文引号。\n"
|
||||||
"如果截图包含多笔交易,返回JSON数组。否则返回单个JSON对象。"
|
"3) 禁止尾逗号,禁止 NaN/Infinity/undefined。\n"
|
||||||
|
"4) 若只有1笔交易,输出对象;若>=2笔交易,输出数组。\n"
|
||||||
|
"5) 每笔交易字段固定为:\n"
|
||||||
|
' {"trade_time":"YYYY-MM-DD HH:MM:SS","amount":123.45,"direction":"in|out",'
|
||||||
|
'"counterparty_name":"","counterparty_account":"","self_account_tail_no":"","order_no":"","remark":"","confidence":0.0}\n'
|
||||||
|
"6) 字段约束:\n"
|
||||||
|
"- trade_time: 必须为 YYYY-MM-DD HH:MM:SS;无法识别时填空字符串\"\"。\n"
|
||||||
|
"- amount: 必须是数字(不要货币符号);无法识别时填 0。\n"
|
||||||
|
"- direction: 仅允许 in 或 out;无法判断默认 out。\n"
|
||||||
|
"- confidence: 0~1 的数字;无法判断默认 0.5。\n"
|
||||||
|
"- 其他文本字段无法识别时填空字符串\"\"。\n\n"
|
||||||
|
"【输出前自检】\n"
|
||||||
|
"- 检查是否是严格JSON(可被 json.loads 直接解析)。\n"
|
||||||
|
"- 检查每条记录都包含全部9个字段。\n"
|
||||||
|
"- 检查 amount/confidence 为数字类型,direction 仅为 in/out。\n"
|
||||||
|
"现在只输出最终JSON,不要输出任何额外文本。"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
text = await _call_vision(prompt, image_b64, max_tokens=6000)
|
||||||
resp = await client.post(
|
try:
|
||||||
settings.LLM_API_URL,
|
parsed = _parse_json_response(text)
|
||||||
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
|
except Exception as parse_err:
|
||||||
json={
|
if not ENABLE_LLM_REPAIR or not _llm_available():
|
||||||
"model": settings.LLM_MODEL,
|
raise
|
||||||
"messages": [
|
repaired_text = await _repair_broken_json_with_llm(text)
|
||||||
{
|
parsed = _parse_json_response(repaired_text)
|
||||||
"role": "user",
|
return parsed, text
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": prompt},
|
|
||||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 2000,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
text = resp.json()["choices"][0]["message"]["content"]
|
|
||||||
return json.loads(text.strip().strip("`").removeprefix("json").strip())
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("extract_transaction_fields API failed: %s", e)
|
logger.warning("extract_transaction_fields API failed: %s", e)
|
||||||
return {}
|
return {}, text if "text" in locals() else ""
|
||||||
|
|
||||||
|
|
||||||
# ── mock fallback ────────────────────────────────────────────────────────
|
# ── mock fallback ────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -16,31 +16,42 @@ def parse_extracted_fields(
|
|||||||
items = raw if isinstance(raw, list) else [raw]
|
items = raw if isinstance(raw, list) else [raw]
|
||||||
records: list[TransactionRecord] = []
|
records: list[TransactionRecord] = []
|
||||||
|
|
||||||
for item in items:
|
for idx, item in enumerate(items):
|
||||||
if not item or not item.get("amount"):
|
if not item or not item.get("amount"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
raw_trade_time = item.get("trade_time")
|
||||||
try:
|
try:
|
||||||
trade_time = datetime.fromisoformat(item["trade_time"])
|
if isinstance(raw_trade_time, datetime):
|
||||||
except (ValueError, KeyError):
|
trade_time = raw_trade_time
|
||||||
|
elif isinstance(raw_trade_time, str):
|
||||||
|
trade_time = datetime.fromisoformat(raw_trade_time)
|
||||||
|
else:
|
||||||
|
raise TypeError("trade_time is not str/datetime")
|
||||||
|
except (ValueError, KeyError, TypeError):
|
||||||
trade_time = datetime.now()
|
trade_time = datetime.now()
|
||||||
|
|
||||||
direction_str = item.get("direction", "out")
|
direction_str = item.get("direction", "out")
|
||||||
direction = Direction.in_ if direction_str == "in" else Direction.out
|
direction = Direction.in_ if direction_str == "in" else Direction.out
|
||||||
|
try:
|
||||||
|
amount = float(item.get("amount", 0))
|
||||||
|
confidence = float(item.get("confidence", 0.5))
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
record = TransactionRecord(
|
record = TransactionRecord(
|
||||||
case_id=case_id,
|
case_id=case_id,
|
||||||
evidence_image_id=image_id,
|
evidence_image_id=image_id,
|
||||||
source_app=source_app,
|
source_app=source_app,
|
||||||
trade_time=trade_time,
|
trade_time=trade_time,
|
||||||
amount=float(item.get("amount", 0)),
|
amount=amount,
|
||||||
direction=direction,
|
direction=direction,
|
||||||
counterparty_name=str(item.get("counterparty_name", "")),
|
counterparty_name=str(item.get("counterparty_name", "")),
|
||||||
counterparty_account=str(item.get("counterparty_account", "")),
|
counterparty_account=str(item.get("counterparty_account", "")),
|
||||||
self_account_tail_no=str(item.get("self_account_tail_no", "")),
|
self_account_tail_no=str(item.get("self_account_tail_no", "")),
|
||||||
order_no=str(item.get("order_no", "")),
|
order_no=str(item.get("order_no", "")),
|
||||||
remark=str(item.get("remark", "")),
|
remark=str(item.get("remark", "")),
|
||||||
confidence=float(item.get("confidence", 0.5)),
|
confidence=confidence,
|
||||||
)
|
)
|
||||||
records.append(record)
|
records.append(record)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Celery tasks for OCR processing of uploaded screenshots."""
|
"""Celery tasks for OCR processing of uploaded screenshots."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from app.workers.celery_app import celery_app
|
from app.workers.celery_app import celery_app
|
||||||
@@ -20,18 +21,19 @@ def _run_async(coro):
|
|||||||
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
|
@celery_app.task(name="app.workers.ocr_tasks.process_image_ocr", bind=True, max_retries=3)
|
||||||
def process_image_ocr(self, image_id: str):
|
def process_image_ocr(self, image_id: str):
|
||||||
"""Process a single image: classify page, extract fields, save to DB."""
|
"""Process a single image: classify page, extract fields, save to DB."""
|
||||||
_run_async(_process(image_id))
|
_run_async(process_image_ocr_async(image_id))
|
||||||
|
|
||||||
|
|
||||||
async def _process(image_id_str: str):
|
async def process_image_ocr_async(image_id_str: str):
|
||||||
from app.core.database import async_session_factory
|
from app.core.database import async_session_factory
|
||||||
|
from sqlalchemy import delete
|
||||||
from app.models.evidence_image import EvidenceImage, OcrStatus
|
from app.models.evidence_image import EvidenceImage, OcrStatus
|
||||||
from app.models.ocr_block import OcrBlock
|
from app.models.ocr_block import OcrBlock
|
||||||
|
from app.models.transaction import TransactionRecord
|
||||||
from app.services.ocr_service import classify_page, extract_transaction_fields
|
from app.services.ocr_service import classify_page, extract_transaction_fields
|
||||||
from app.services.parser_service import parse_extracted_fields
|
from app.services.parser_service import parse_extracted_fields
|
||||||
|
|
||||||
image_id = UUID(image_id_str)
|
image_id = UUID(image_id_str)
|
||||||
|
|
||||||
async with async_session_factory() as db:
|
async with async_session_factory() as db:
|
||||||
image = await db.get(EvidenceImage, image_id)
|
image = await db.get(EvidenceImage, image_id)
|
||||||
if not image:
|
if not image:
|
||||||
@@ -40,18 +42,34 @@ async def _process(image_id_str: str):
|
|||||||
|
|
||||||
image.ocr_status = OcrStatus.processing
|
image.ocr_status = OcrStatus.processing
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
# Re-run OCR for this image should replace old OCR blocks/records.
|
||||||
|
await db.execute(delete(OcrBlock).where(OcrBlock.image_id == image.id))
|
||||||
|
await db.execute(delete(TransactionRecord).where(TransactionRecord.evidence_image_id == image.id))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source_app, page_type = await classify_page(image.file_path)
|
source_app, page_type = await classify_page(image.file_path)
|
||||||
image.source_app = source_app
|
image.source_app = source_app
|
||||||
image.page_type = page_type
|
image.page_type = page_type
|
||||||
|
|
||||||
raw_fields = await extract_transaction_fields(image.file_path, source_app, page_type)
|
raw_fields, raw_ocr_text = await extract_transaction_fields(image.file_path, source_app, page_type)
|
||||||
|
|
||||||
# save raw OCR block
|
is_empty_extract = raw_fields is None or raw_fields == {} or raw_fields == []
|
||||||
|
|
||||||
|
# save raw OCR block (direct model output for debugging)
|
||||||
|
raw_block_content = (raw_ocr_text or "").strip()
|
||||||
|
if raw_block_content:
|
||||||
block = OcrBlock(
|
block = OcrBlock(
|
||||||
image_id=image.id,
|
image_id=image.id,
|
||||||
content=str(raw_fields),
|
content=raw_block_content,
|
||||||
|
bbox={},
|
||||||
|
seq_order=0,
|
||||||
|
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,
|
||||||
|
)
|
||||||
|
db.add(block)
|
||||||
|
elif not is_empty_extract:
|
||||||
|
block = OcrBlock(
|
||||||
|
image_id=image.id,
|
||||||
|
content=json.dumps(raw_fields, ensure_ascii=False),
|
||||||
bbox={},
|
bbox={},
|
||||||
seq_order=0,
|
seq_order=0,
|
||||||
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,
|
confidence=raw_fields.get("confidence", 0.5) if isinstance(raw_fields, dict) else 0.5,
|
||||||
|
|||||||
18
backend/requirements.txt
Normal file
18
backend/requirements.txt
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
fastapi==0.135.1
|
||||||
|
uvicorn[standard]==0.41.0
|
||||||
|
sqlalchemy[asyncio]==2.0.48
|
||||||
|
asyncpg==0.31.0
|
||||||
|
alembic==1.18.4
|
||||||
|
celery[redis]==5.6.2
|
||||||
|
redis==6.4.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
python-multipart==0.0.22
|
||||||
|
Pillow==12.1.0
|
||||||
|
httpx==0.28.1
|
||||||
|
openpyxl==3.1.5
|
||||||
|
python-docx==1.2.0
|
||||||
|
psycopg2-binary==2.9.11
|
||||||
|
|
||||||
|
# dev dependencies
|
||||||
|
pytest==9.0.2
|
||||||
|
pytest-asyncio>=0.24.0
|
||||||
Reference in New Issue
Block a user