57 lines
2.1 KiB
Python
57 lines
2.1 KiB
Python
|
|
"""OpenAI Vision provider (GPT-4o)."""
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
from openai import AsyncOpenAI
|
||
|
|
|
||
|
|
from app.config import get_settings
|
||
|
|
from app.schemas.transaction import TransactionExtractItem
|
||
|
|
from app.services.llm.base import BaseLLMProvider
|
||
|
|
from app.prompts.extract_transaction import get_extract_messages
|
||
|
|
|
||
|
|
|
||
|
|
class OpenAIVisionProvider(BaseLLMProvider):
    """LLM provider that extracts transactions from an image via OpenAI chat completions.

    The API key and model name are read from ``get_settings()`` on every call.
    """

    async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
        """Send an image to the configured OpenAI model and parse the transactions it returns.

        Args:
            image_bytes: Raw image bytes. They are base64-encoded (standard
                alphabet, ASCII) and passed to ``get_extract_messages`` to build
                the chat prompt.

        Returns:
            The transaction items parsed from the model's reply by
            ``_parse_json_array``. Returns an empty list when the reply has no
            choices, has empty content, or contains no usable JSON array.

        Raises:
            ValueError: If ``settings.openai_api_key`` is empty or unset.
        """
        settings = get_settings()
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is not set")

        encoded_image = base64.standard_b64encode(image_bytes).decode("ascii")
        messages = get_extract_messages(encoded_image)

        # A new client is built on every call, so close it deterministically
        # instead of leaving its HTTP connection pool to the garbage collector.
        async with AsyncOpenAI(api_key=settings.openai_api_key) as client:
            response = await client.chat.completions.create(
                model=settings.openai_model,
                messages=messages,
                max_tokens=4096,
            )

        # No choices is treated like empty content: no transactions, not a crash.
        if not response.choices:
            return []
        text = response.choices[0].message.content or "[]"
        return _parse_json_array(text)
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_json_array(text: str) -> list[TransactionExtractItem]:
    """Parse an LLM response into a list of TransactionExtractItem.

    The response should contain a JSON array of objects. To tolerate common
    model formatting, these candidates are tried in order, and the first one
    that decodes as JSON is used:

    1. the whole response (whitespace-stripped);
    2. the contents of the first Markdown code fence (```json ... ```), even
       when the fence is surrounded by prose;
    3. the span from the first ``[`` to the last ``]`` in the response.

    Args:
        text: Raw text content returned by the model.

    Returns:
        One validated item per array element that is a JSON object and passes
        ``TransactionExtractItem.model_validate``. Returns an empty list when no
        candidate decodes or the decoded value is not a list. Elements that are
        not objects or fail validation are skipped. Skipped elements are logged
        by index and error type only, so transaction contents are not written
        to the logs.
    """
    import logging

    logger = logging.getLogger(__name__)

    raw = text.strip()
    candidates = [raw]
    fence = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL | re.IGNORECASE)
    if fence:
        candidates.append(fence.group(1))
    start, end = raw.find("["), raw.rfind("]")
    if 0 <= start < end:
        candidates.append(raw[start : end + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        break
    else:
        logger.warning("LLM response contained no valid JSON; no transactions extracted")
        return []

    if not isinstance(data, list):
        logger.warning(
            "LLM response JSON is a %s, not a list; no transactions extracted",
            type(data).__name__,
        )
        return []

    result: list[TransactionExtractItem] = []
    for index, item in enumerate(data):
        if not isinstance(item, dict):
            logger.warning(
                "Skipping transaction item %d: expected an object, got %s",
                index,
                type(item).__name__,
            )
            continue
        try:
            # Normalize transaction_time: a non-empty string is parsed into a datetime.
            raw_time = item.get("transaction_time")
            if isinstance(raw_time, str) and raw_time:
                # Imported only on this path, so dateutil is not loaded unless
                # a string timestamp is actually present.
                from dateutil import parser as date_parser

                item["transaction_time"] = date_parser.isoparse(raw_time)
            result.append(TransactionExtractItem.model_validate(item))
        except (ValueError, TypeError, OverflowError) as exc:
            # Best effort: one malformed item must not discard the rest. Pydantic's
            # ValidationError and dateutil's parse errors are ValueError subclasses;
            # anything else (e.g. a missing dependency) is a real bug and propagates.
            logger.warning(
                "Skipping invalid transaction item %d (%s)", index, type(exc).__name__
            )
    return result
|