update: upload fix
This commit is contained in:
@@ -6,13 +6,16 @@ from app.schemas.transaction import TransactionExtractItem
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Abstract base for LLM vision providers. Each provider implements extract_from_image."""
|
||||
"""Abstract base for LLM providers. Supports optional model override."""
|
||||
|
||||
def __init__(self, model_override: str | None = None):
|
||||
self._model_override = model_override
|
||||
|
||||
@abstractmethod
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
"""
|
||||
Analyze a billing screenshot and return structured transaction list.
|
||||
:param image_bytes: Raw image file content (PNG/JPEG)
|
||||
:return: List of extracted transactions (may be empty or partial on failure)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
"""Plain text chat (for inference/reasoning tasks like report generation)."""
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Anthropic Claude Vision provider."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.config import get_settings
|
||||
@@ -13,6 +11,9 @@ from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
class ClaudeVisionProvider(BaseLLMProvider):
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().anthropic_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
settings = get_settings()
|
||||
if not settings.anthropic_api_key:
|
||||
@@ -20,14 +21,12 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
# Claude API: user message with content block list
|
||||
user_content = messages[1]["content"]
|
||||
content_blocks = []
|
||||
for block in user_content:
|
||||
if block["type"] == "text":
|
||||
content_blocks.append({"type": "text", "text": block["text"]})
|
||||
elif block["type"] == "image_url":
|
||||
# Claude expects base64 without data URL prefix
|
||||
content_blocks.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
@@ -37,7 +36,7 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
},
|
||||
})
|
||||
response = await client.messages.create(
|
||||
model=settings.anthropic_model,
|
||||
model=self._get_model(),
|
||||
max_tokens=4096,
|
||||
system=messages[0]["content"],
|
||||
messages=[{"role": "user", "content": content_blocks}],
|
||||
@@ -47,3 +46,20 @@ class ClaudeVisionProvider(BaseLLMProvider):
|
||||
if hasattr(block, "text"):
|
||||
text += block.text
|
||||
return _parse_json_array(text or "[]")
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
settings = get_settings()
|
||||
if not settings.anthropic_api_key:
|
||||
raise ValueError("ANTHROPIC_API_KEY is not set")
|
||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
response = await client.messages.create(
|
||||
model=self._get_model(),
|
||||
max_tokens=4096,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": user}],
|
||||
)
|
||||
text = ""
|
||||
for block in response.content:
|
||||
if hasattr(block, "text"):
|
||||
text += block.text
|
||||
return text or ""
|
||||
|
||||
@@ -11,22 +11,44 @@ from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
class CustomOpenAICompatibleProvider(BaseLLMProvider):
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
settings = get_settings()
|
||||
if not settings.custom_openai_api_key:
|
||||
raise ValueError("CUSTOM_OPENAI_API_KEY is not set")
|
||||
if not settings.custom_openai_base_url:
|
||||
raise ValueError("CUSTOM_OPENAI_BASE_URL is not set")
|
||||
client = AsyncOpenAI(
|
||||
return AsyncOpenAI(
|
||||
api_key=settings.custom_openai_api_key,
|
||||
base_url=settings.custom_openai_base_url,
|
||||
timeout=45.0,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().custom_openai_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
client = self._get_client()
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.custom_openai_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
timeout=45.0,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
client = self._get_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
timeout=45.0,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
@@ -9,26 +9,39 @@ from app.services.llm.base import BaseLLMProvider
|
||||
from app.prompts.extract_transaction import get_extract_messages
|
||||
from app.services.llm.openai_vision import _parse_json_array
|
||||
|
||||
|
||||
# DeepSeek vision endpoint (OpenAI-compatible)
|
||||
DEEPSEEK_BASE = "https://api.deepseek.com"
|
||||
|
||||
|
||||
class DeepSeekVisionProvider(BaseLLMProvider):
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
settings = get_settings()
|
||||
if not settings.deepseek_api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY is not set")
|
||||
client = AsyncOpenAI(
|
||||
api_key=settings.deepseek_api_key,
|
||||
base_url=DEEPSEEK_BASE,
|
||||
)
|
||||
return AsyncOpenAI(api_key=settings.deepseek_api_key, base_url=DEEPSEEK_BASE)
|
||||
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().deepseek_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
client = self._get_client()
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.deepseek_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
client = self._get_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
@@ -12,6 +12,9 @@ from app.prompts.extract_transaction import get_extract_messages
|
||||
|
||||
|
||||
class OpenAIVisionProvider(BaseLLMProvider):
|
||||
def _get_model(self) -> str:
|
||||
return self._model_override or get_settings().openai_model
|
||||
|
||||
async def extract_from_image(self, image_bytes: bytes) -> list[TransactionExtractItem]:
|
||||
settings = get_settings()
|
||||
if not settings.openai_api_key:
|
||||
@@ -20,18 +23,32 @@ class OpenAIVisionProvider(BaseLLMProvider):
|
||||
b64 = base64.standard_b64encode(image_bytes).decode("ascii")
|
||||
messages = get_extract_messages(b64)
|
||||
response = await client.chat.completions.create(
|
||||
model=settings.openai_model,
|
||||
model=self._get_model(),
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
text = response.choices[0].message.content or "[]"
|
||||
return _parse_json_array(text)
|
||||
|
||||
async def chat(self, system: str, user: str) -> str:
|
||||
settings = get_settings()
|
||||
if not settings.openai_api_key:
|
||||
raise ValueError("OPENAI_API_KEY is not set")
|
||||
client = AsyncOpenAI(api_key=settings.openai_api_key)
|
||||
response = await client.chat.completions.create(
|
||||
model=self._get_model(),
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
max_tokens=4096,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
|
||||
def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
||||
"""Parse LLM response into list of TransactionExtractItem. Tolerates markdown and extra text."""
|
||||
text = text.strip()
|
||||
# Remove optional markdown code block
|
||||
if text.startswith("```"):
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"\s*```\s*$", "", text)
|
||||
@@ -46,7 +63,6 @@ def _parse_json_array(text: str) -> list[TransactionExtractItem]:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
try:
|
||||
# Normalize transaction_time: allow string -> datetime
|
||||
if isinstance(item.get("transaction_time"), str) and item["transaction_time"]:
|
||||
from dateutil import parser as date_parser
|
||||
item["transaction_time"] = date_parser.isoparse(item["transaction_time"])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""LLM provider factory - returns provider by config."""
|
||||
"""LLM provider factory - returns provider by config, split by role (ocr / inference)."""
|
||||
|
||||
from app.config import get_settings
|
||||
from app.config import get_settings, get_ocr_model, get_inference_model
|
||||
from app.services.llm.base import BaseLLMProvider
|
||||
from app.services.llm.openai_vision import OpenAIVisionProvider
|
||||
from app.services.llm.claude_vision import ClaudeVisionProvider
|
||||
@@ -8,15 +8,25 @@ from app.services.llm.deepseek_vision import DeepSeekVisionProvider
|
||||
from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
|
||||
|
||||
|
||||
def get_llm_provider() -> BaseLLMProvider:
|
||||
def _make_provider(provider_name: str, model_override: str | None = None) -> BaseLLMProvider:
|
||||
name = (provider_name or "openai").lower()
|
||||
if name == "openai":
|
||||
return OpenAIVisionProvider(model_override=model_override)
|
||||
if name == "anthropic":
|
||||
return ClaudeVisionProvider(model_override=model_override)
|
||||
if name == "deepseek":
|
||||
return DeepSeekVisionProvider(model_override=model_override)
|
||||
if name == "custom_openai":
|
||||
return CustomOpenAICompatibleProvider(model_override=model_override)
|
||||
return OpenAIVisionProvider(model_override=model_override)
|
||||
|
||||
|
||||
def get_llm_provider(role: str = "ocr") -> BaseLLMProvider:
|
||||
"""
|
||||
role="ocr" -> uses ocr_provider + ocr_model
|
||||
role="inference" -> uses inference_provider + inference_model
|
||||
"""
|
||||
settings = get_settings()
|
||||
provider = (settings.llm_provider or "openai").lower()
|
||||
if provider == "openai":
|
||||
return OpenAIVisionProvider()
|
||||
if provider == "anthropic":
|
||||
return ClaudeVisionProvider()
|
||||
if provider == "deepseek":
|
||||
return DeepSeekVisionProvider()
|
||||
if provider == "custom_openai":
|
||||
return CustomOpenAICompatibleProvider()
|
||||
return OpenAIVisionProvider()
|
||||
if role == "inference":
|
||||
return _make_provider(settings.inference_provider, get_inference_model())
|
||||
return _make_provider(settings.ocr_provider, get_ocr_model())
|
||||
|
||||
Reference in New Issue
Block a user