update: upload fix

This commit is contained in:
2026-03-10 14:25:21 +08:00
parent a3d928e697
commit fd2b574d5a
19 changed files with 575 additions and 156 deletions

View File

@@ -1,6 +1,6 @@
"""LLM provider factory - returns provider by config."""
"""LLM provider factory - returns provider by config, split by role (ocr / inference)."""
from app.config import get_settings
from app.config import get_settings, get_ocr_model, get_inference_model
from app.services.llm.base import BaseLLMProvider
from app.services.llm.openai_vision import OpenAIVisionProvider
from app.services.llm.claude_vision import ClaudeVisionProvider
@@ -8,15 +8,25 @@ from app.services.llm.deepseek_vision import DeepSeekVisionProvider
from app.services.llm.custom_openai_vision import CustomOpenAICompatibleProvider
def get_llm_provider() -> BaseLLMProvider:
def _make_provider(provider_name: str, model_override: str | None = None) -> BaseLLMProvider:
name = (provider_name or "openai").lower()
if name == "openai":
return OpenAIVisionProvider(model_override=model_override)
if name == "anthropic":
return ClaudeVisionProvider(model_override=model_override)
if name == "deepseek":
return DeepSeekVisionProvider(model_override=model_override)
if name == "custom_openai":
return CustomOpenAICompatibleProvider(model_override=model_override)
return OpenAIVisionProvider(model_override=model_override)
def get_llm_provider(role: str = "ocr") -> BaseLLMProvider:
"""
role="ocr" -> uses ocr_provider + ocr_model
role="inference" -> uses inference_provider + inference_model
"""
settings = get_settings()
provider = (settings.llm_provider or "openai").lower()
if provider == "openai":
return OpenAIVisionProvider()
if provider == "anthropic":
return ClaudeVisionProvider()
if provider == "deepseek":
return DeepSeekVisionProvider()
if provider == "custom_openai":
return CustomOpenAICompatibleProvider()
return OpenAIVisionProvider()
if role == "inference":
return _make_provider(settings.inference_provider, get_inference_model())
return _make_provider(settings.ocr_provider, get_ocr_model())