Files
fund-tracer/backend/app/services/llm/openai_vision.py
2026-03-10 14:25:21 +08:00

73 lines
2.6 KiB
Python

"""OpenAI Vision provider (GPT-4o)."""
import base64
import json
import logging
import re

from openai import AsyncOpenAI

from app.config import get_settings
from app.prompts.extract_transaction import get_extract_messages
from app.schemas.transaction import TransactionExtractItem
from app.services.llm.base import BaseLLMProvider
class OpenAIVisionProvider(BaseLLMProvider):
    """OpenAI chat-completions provider used for image extraction and plain chat.

    Every request builds its own ``AsyncOpenAI`` client from the current
    settings and closes it (``async with``) once the request finishes, instead
    of leaving the client for the garbage collector.
    """

    # Upper bound on generated tokens for every request this provider makes.
    _MAX_TOKENS = 4096

    def _get_model(self) -> str:
        """Return the model to call: the instance override if set, else ``settings.openai_model``."""
        # NOTE(review): ``_model_override`` is presumably initialised by BaseLLMProvider — confirm.
        return self._model_override or get_settings().openai_model

    @staticmethod
    def _open_client() -> AsyncOpenAI:
        """Create an ``AsyncOpenAI`` client from the current settings.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not configured.
        """
        settings = get_settings()
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is not set")
        return AsyncOpenAI(api_key=settings.openai_api_key)

    @staticmethod
    def _reply_text(response) -> str:
        """Return the first choice's message content, or ``""`` when there is none."""
        # Guard against an empty ``choices`` list instead of raising IndexError.
        if not response.choices:
            return ""
        return response.choices[0].message.content or ""

    async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
        """Ask the model to extract transactions from an image.

        Args:
            image_bytes: Raw image contents; base64-encoded and wrapped into chat
                messages by ``get_extract_messages``.

        Returns:
            Transactions parsed from the model's JSON reply; an empty list when
            the reply has no content.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not configured.
            Errors raised by the OpenAI SDK propagate unchanged.
        """
        async with self._open_client() as client:
            b64 = base64.standard_b64encode(image_bytes).decode("ascii")
            response = await client.chat.completions.create(
                model=self._get_model(),
                messages=get_extract_messages(b64),
                max_tokens=self._MAX_TOKENS,
            )
        # An empty reply means "no transactions", i.e. an empty JSON array.
        return _parse_json_array(self._reply_text(response) or "[]")

    async def chat(self, system: str, user: str) -> str:
        """Run one system + user chat completion and return the reply text.

        Args:
            system: System prompt.
            user: User message.

        Returns:
            The assistant's reply, or ``""`` if the model returned no content.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not configured.
            Errors raised by the OpenAI SDK propagate unchanged.
        """
        async with self._open_client() as client:
            response = await client.chat.completions.create(
                model=self._get_model(),
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": user},
                ],
                max_tokens=self._MAX_TOKENS,
            )
        return self._reply_text(response)
def _load_json_array(text: str):
    """Return the JSON array contained in *text*, or ``None`` if none can be decoded.

    Handles a reply wrapped in a leading markdown fence (```` ```json ````), and
    falls back to the outermost ``[...]`` span so surrounding prose or fences
    elsewhere in the reply do not defeat parsing. A decoded value that is not a
    list (e.g. a JSON object) yields ``None``.
    """
    candidate = text.strip()
    if candidate.startswith("```"):
        candidate = re.sub(r"^```(?:json)?\s*", "", candidate)
        candidate = re.sub(r"\s*```\s*$", "", candidate)
    try:
        data = json.loads(candidate)
    except json.JSONDecodeError:
        # e.g. "Here are the transactions:\n[...]" or trailing text after a fence.
        start, end = candidate.find("["), candidate.rfind("]")
        if start == -1 or end <= start:
            return None
        try:
            data = json.loads(candidate[start : end + 1])
        except json.JSONDecodeError:
            return None
    return data if isinstance(data, list) else None


def _parse_json_array(text: str) -> list[TransactionExtractItem]:
    """Parse an LLM reply into a list of ``TransactionExtractItem``.

    The reply should contain a JSON array of objects; markdown fences and
    surrounding prose are tolerated. Parsing is best-effort per row: elements
    that are not objects or that fail date parsing / validation are skipped and
    logged instead of failing the whole batch.

    Args:
        text: Raw model output.

    Returns:
        The validated items in their original order; an empty list when no JSON
        array can be recovered from ``text``.
    """
    logger = logging.getLogger(__name__)
    data = _load_json_array(text)
    if data is None:
        logger.warning(
            "LLM reply contained no JSON array (%d chars); no transactions extracted",
            len(text),
        )
        return []

    # Imported once, outside the per-row try, so a missing dependency is reported
    # instead of silently dropping every row that has a transaction_time.
    from dateutil import parser as date_parser

    result: list[TransactionExtractItem] = []
    for index, item in enumerate(data):
        if not isinstance(item, dict):
            logger.warning("Skipping extracted transaction #%d: not a JSON object", index)
            continue
        try:
            raw_time = item.get("transaction_time")
            if isinstance(raw_time, str) and raw_time:
                item["transaction_time"] = date_parser.isoparse(raw_time)
            result.append(TransactionExtractItem.model_validate(item))
        except Exception as exc:  # deliberate best-effort: one bad row must not drop the rest
            logger.warning("Skipping extracted transaction #%d: %s", index, exc)
    return result