777 lines
30 KiB
Python
777 lines
30 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
批量图片识别脚本
|
|||
|
|
|
|||
|
|
功能:
|
|||
|
|
1. 遍历指定目录下的图片文件
|
|||
|
|
2. 将图片上传至大模型 API(OpenAI、Anthropic,或本地模拟模式)
|
|||
|
|
3. 输出包含图片中文本、物品描述、风险概述的 Excel 报表
|
|||
|
|
|
|||
|
|
使用示例:
|
|||
|
|
cd scripts
|
|||
|
|
python3 image_batch_recognizer.py --limit 5 --api-type openai
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import base64
|
|||
|
|
import json
|
|||
|
|
import mimetypes
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from datetime import datetime
|
|||
|
|
from pathlib import Path
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import requests # type: ignore
|
|||
|
|
# 禁用 SSL 警告(用于自签名证书)
|
|||
|
|
import urllib3
|
|||
|
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||
|
|
except ImportError: # pragma: no cover - requests 是可选依赖
|
|||
|
|
requests = None
|
|||
|
|
|
|||
|
|
try:
    from tqdm import tqdm  # type: ignore
except ImportError:
    # tqdm is optional: provide a minimal drop-in replacement that exposes
    # only the methods this script actually calls.
    class tqdm:
        def __init__(self, iterable=None, total=None, desc=None, **kwargs):
            self.iterable = iterable
            self.desc = desc
            self.n = 0
            if total:
                self.total = total
            else:
                self.total = len(iterable) if iterable else 0

        def __iter__(self):
            for element in self.iterable:
                yield element
                self.n += 1

        def update(self, n=1):
            self.n += n

        def set_postfix_str(self, s):
            # Simplified fallback: no-op.
            pass

        def close(self):
            pass
|
|||
|
|
|
|||
|
|
# Delimiter used to join/split multi-valued fields in Excel cells.
SEPARATOR = "|||"
# Default pause between sequential API calls (seconds); 0 disables throttling.
DEFAULT_SLEEP_SECONDS = 0.0
# Default HTTP request timeout (seconds).
DEFAULT_TIMEOUT = 90
# Default token cap for model responses (used by OpenAI/Anthropic payloads).
DEFAULT_MAX_TOKENS = 800
# Image file extensions accepted by the directory scan (matched lowercased).
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


# Default analysis prompt sent to the vision model (Chinese; instructs the
# model to return a single JSON object with exactly the fields consumed by
# BaseVisionClient.analyze_image). Do not translate: this is runtime text.
DEFAULT_ANALYSIS_PROMPT = """
你是一名涉毒风控场景下的多模态分析师。请识别我提供的图片内容,并输出 JSON。
JSON 字段要求:
{
"detected_text": ["按行列出的文字内容,若无文字返回空数组"],
"detected_objects": ["出现的主要物品或场景要素,中文描述"],
"sensitive_items": ["可疑化学品、药物、工具等,如无则为空数组"],
"summary": "2-3 句话概括图片与潜在风险,中文",
"confidence": "High | Medium | Low"
}
只返回 JSON,不要额外说明。
""".strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- 数据结构 -----------
|
|||
|
|
@dataclass
class RecognitionResult:
    """Outcome of analyzing a single image, plus request metadata.

    Produced by every client's ``analyze_image``; one instance becomes
    one row of the Excel report via :meth:`to_row`.
    """

    image_name: str
    image_path: str
    # Lines of text the model read out of the image.
    detected_text: List[str] = field(default_factory=list)
    # Main objects / scene elements described by the model.
    detected_objects: List[str] = field(default_factory=list)
    # Suspicious chemicals, drugs, tools, etc. (empty when none).
    sensitive_items: List[str] = field(default_factory=list)
    summary: str = ""
    confidence: str = ""
    # Unparsed model reply, kept for debugging/auditing.
    raw_response: str = ""
    api_type: str = ""
    model_name: str = ""
    latency_seconds: float = 0.0
    analyzed_at: str = ""
    # None on success; otherwise the request or JSON-parse error message.
    error: Optional[str] = None

    def to_row(self) -> Dict[str, Any]:
        """Convert this result into a row dict for the Excel report."""
        return {
            "image_name": self.image_name,
            "image_path": self.image_path,
            "detected_text": "\n".join(self.detected_text).strip(),
            "detected_objects": SEPARATOR.join(self.detected_objects),
            "sensitive_items": SEPARATOR.join(self.sensitive_items),
            "summary": self.summary,
            "confidence": self.confidence,
            "latency_seconds": round(self.latency_seconds, 2),
            "api_type": self.api_type,
            "model": self.model_name,
            "analyzed_at": self.analyzed_at,
            "error": self.error or "",
            "raw_response": self.raw_response,
        }

    @property
    def is_success(self) -> bool:
        # Success is defined as "no error recorded", including parse errors.
        return self.error is None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- 工具函数 -----------
|
|||
|
|
def encode_image(image_path: Path) -> Tuple[str, str]:
    """Read an image file as base64 text and guess its MIME type.

    Falls back to "image/jpeg" when the type cannot be guessed from the
    file name.
    """
    guessed, _encoding = mimetypes.guess_type(image_path.name)
    raw_bytes = image_path.read_bytes()
    return base64.b64encode(raw_bytes).decode("utf-8"), guessed or "image/jpeg"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_json_from_response(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Pull a JSON value out of a model reply.

    Tries a direct parse first; on failure, retries on the substring
    between the first "{" and the last "}". Returns ``(parsed, error)``
    where ``error`` is None on success.
    """
    if not text:
        return None, "空响应"

    stripped = text.strip()
    try:
        return json.loads(stripped), None
    except json.JSONDecodeError:
        pass

    # Direct parse failed: carve out the outermost brace-delimited span.
    left = stripped.find("{")
    right = stripped.rfind("}")
    if left == -1 or right <= left:
        return None, "响应中未找到 JSON 对象"

    try:
        return json.loads(stripped[left:right + 1]), None
    except json.JSONDecodeError as exc:
        return None, f"JSON 解析失败: {exc}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_list(value: Any) -> List[str]:
    """Coerce a model-returned field into a clean list of strings.

    Strings are split on SEPARATOR; collections are stringified
    element-wise; anything else becomes a one-element list. Empty
    (whitespace-only) entries are dropped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        pieces = value.split(SEPARATOR)
        return [piece.strip() for piece in pieces if piece.strip()]
    if isinstance(value, (list, tuple, set)):
        cleaned: List[str] = []
        for item in value:
            text = str(item).strip()
            if text:
                cleaned.append(text)
        return cleaned
    return [str(value).strip()]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def http_post_json(url: str, headers: Dict[str, str], payload: Dict[str, Any], timeout: int) -> Tuple[int, str]:
    """POST a JSON payload and return ``(status_code, response_body)``.

    Uses ``requests`` when available (SSL verification disabled for
    self-signed certificates), otherwise falls back to ``urllib``.
    """
    serialized = json.dumps(payload)
    merged_headers = dict(headers)
    merged_headers["Content-Type"] = "application/json"

    if requests:
        # SSL certificate verification disabled (self-signed certificates).
        resp = requests.post(url, headers=merged_headers, data=serialized,
                             timeout=timeout, verify=False)
        return resp.status_code, resp.text

    # urllib fallback when requests is not installed.
    import urllib.error
    import urllib.request

    request = urllib.request.Request(
        url, data=serialized.encode("utf-8"), headers=merged_headers, method="POST"
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return response.getcode(), response.read().decode("utf-8")
    except urllib.error.HTTPError as exc:  # pragma: no cover - network error path
        return exc.code, exc.read().decode("utf-8")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def http_post_multipart(url: str, headers: Dict[str, str],
                        data: Dict[str, str], files: Dict[str, Any],
                        timeout: int) -> Tuple[int, str]:
    """POST a multipart/form-data request (used for Dify file uploads).

    Requires the ``requests`` library; raises RuntimeError otherwise.
    """
    if requests is None:
        raise RuntimeError(
            "Dify 文件上传需要 requests 库,请安装: pip install requests"
        )

    # SSL certificate verification disabled (self-signed certificates).
    resp = requests.post(url, headers=headers, data=data,
                         files=files, timeout=timeout, verify=False)
    return resp.status_code, resp.text
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_endpoint(base_url: Optional[str], path: str, default_base: str) -> str:
    """Join a base URL and a path into a full API endpoint.

    Falls back to *default_base* when *base_url* is empty/None; ensures
    exactly one slash between the two parts.
    """
    root = (base_url if base_url else default_base).rstrip("/")
    suffix = path if path.startswith("/") else "/" + path
    return root + suffix
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_env_file(env_file: Optional[str]) -> Optional[Path]:
    """Load KEY=VALUE pairs from a .env file into ``os.environ``.

    Blank lines and "#" comments are skipped, a leading "export " prefix
    is allowed, and surrounding single/double quotes are stripped from
    values. Returns the expanded path on success, or None when no file
    was given or it does not exist.
    """
    if not env_file:
        return None

    env_path = Path(env_file).expanduser()
    if not env_path.exists():
        return None

    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        if entry.startswith("export "):
            entry = entry[len("export "):].strip()
        if "=" not in entry:
            continue
        name, _, raw_value = entry.partition("=")
        name = name.strip()
        if not name:
            continue
        # Strip optional surrounding quotes, then export to the process env.
        os.environ[name] = raw_value.strip().strip('"').strip("'")

    print(f"已从 {env_path} 加载环境变量")
    return env_path
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- 基础客户端 -----------
|
|||
|
|
class BaseVisionClient:
    """Shared request/parse pipeline for all vision API clients.

    Subclasses implement ``_send_request`` and return the model's raw
    text reply; this class handles image encoding, timing, error capture
    and JSON field extraction.
    """

    def __init__(self, model: str, api_type: str, timeout: int = DEFAULT_TIMEOUT):
        self.model = model
        self.api_type = api_type
        self.timeout = timeout

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Analyze one image and always return a RecognitionResult.

        Request errors are captured in ``result.error`` instead of being
        raised, so a batch run keeps going after individual failures.
        """
        encoded, mime_type = encode_image(image_path)
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            raw_text = self._send_request(
                image_base64=encoded,
                mime_type=mime_type,
                prompt=prompt,
                image_name=image_path.name,
            )
        except Exception as exc:  # pragma: no cover - network errors are hard to reproduce
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=time.time() - start,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)
        # BUGFIX: the model may return valid JSON that is not an object
        # (e.g. a top-level array). The original code then crashed with an
        # AttributeError on ``parsed.get`` outside any try/except; treat it
        # as a parse error so the result is recorded instead.
        if parsed is not None and not isinstance(parsed, dict):
            parsed, parse_error = None, "响应 JSON 不是对象"

        fields = parsed or {}
        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=normalize_list(fields.get("detected_text")),
            detected_objects=normalize_list(fields.get("detected_objects")),
            sensitive_items=normalize_list(fields.get("sensitive_items")),
            summary=str(fields.get("summary", "")),
            confidence=str(fields.get("confidence", "")),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Return the model's raw text reply. Implemented by subclasses."""
        raise NotImplementedError
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- OpenAI 客户端 -----------
|
|||
|
|
class OpenAIVisionClient(BaseVisionClient):
    """Vision client for the OpenAI chat-completions API."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="openai", timeout=timeout)
        if not api_key:
            raise ValueError("OPENAI_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://api.openai.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        # Image is inlined as a data: URL inside the user message.
        data_url = f"data:{mime_type};base64,{image_base64}"
        user_content = [
            {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
        payload = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a multimodal analyst that returns structured JSON responses in Chinese.",
                },
                {"role": "user", "content": user_content},
            ],
        }
        status_code, reply = http_post_json(
            self.endpoint,
            {"Authorization": f"Bearer {self.api_key}"},
            payload,
            self.timeout,
        )
        if status_code >= 400:
            raise RuntimeError(f"OpenAI API 调用失败 ({status_code}): {reply}")

        message_content = json.loads(reply)["choices"][0]["message"]["content"]
        if isinstance(message_content, list):
            # Some gateways return content as a list of typed parts.
            pieces = [part.get("text", "") for part in message_content if isinstance(part, dict)]
            return "\n".join(pieces).strip()
        return message_content.strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- DMX (文档上传方式) 客户端 -----------
|
|||
|
|
class DMXVisionClient(BaseVisionClient):
    """Vision client for the DMX OpenAI-compatible chat endpoint."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="dmx", timeout=timeout)
        if not api_key:
            raise ValueError("DMX_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://www.dmxapi.cn")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        data_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0.1,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
        }
        status_code, reply = http_post_json(
            self.endpoint,
            {"Authorization": f"Bearer {self.api_key}"},
            payload,
            self.timeout,
        )
        if status_code >= 400:
            raise RuntimeError(f"DMX API 调用失败 ({status_code}): {reply}")

        message_content = json.loads(reply)["choices"][0]["message"]["content"]
        if isinstance(message_content, list):
            pieces = [part.get("text", "") for part in message_content if isinstance(part, dict)]
            return "\n".join(pieces).strip()
        return message_content.strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- Anthropic 客户端 -----------
|
|||
|
|
class AnthropicVisionClient(BaseVisionClient):
    """Vision client for the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="anthropic", timeout=timeout)
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/messages", "https://api.anthropic.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        # Anthropic takes the image as an inline base64 source block.
        image_block = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": mime_type,
                "data": image_base64,
            },
        }
        payload = {
            "model": self.model,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "system": "You are a multimodal analyst that responds with JSON only.",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        image_block,
                    ],
                }
            ],
        }
        request_headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
        }
        status_code, reply = http_post_json(self.endpoint, request_headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"Anthropic API 调用失败 ({status_code}): {reply}")

        parsed = json.loads(reply)
        texts = [block.get("text", "") for block in parsed.get("content", []) if block.get("type") == "text"]
        return "\n".join(texts).strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- 模拟客户端(无 API 时使用) -----------
|
|||
|
|
class MockVisionClient(BaseVisionClient):
    """Offline stand-in that fabricates a reply without any API call."""

    def __init__(self):
        super().__init__(model="mock", api_type="mock", timeout=1)

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        # Only the file name influences the canned reply.
        del image_base64, mime_type, prompt
        canned = {
            "detected_text": ["(模拟模式)未调用真实 API"],
            "detected_objects": ["sample-object-from-filename"],
            "sensitive_items": [],
            "summary": f"Mock 模式:演示 {image_name} 的返回格式。",
            "confidence": "Low",
        }
        return json.dumps(canned, ensure_ascii=False)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- Dify Chatflow 客户端 -----------
|
|||
|
|
class DifyVisionClient(BaseVisionClient):
    """Dify workflow client: uploads the image, then runs the workflow.

    Overrides ``analyze_image`` entirely because the Dify flow is two
    HTTP calls (file upload + workflow run) rather than one chat call.
    """

    def __init__(self, api_key: str, model: str, timeout: int,
                 base_url: Optional[str] = None, user_id: Optional[str] = None):
        super().__init__(model=model, api_type="dify", timeout=timeout)
        if not api_key:
            raise ValueError("DIFY_API_KEY 未设置")
        self.api_key = api_key
        self.base_url = (base_url or "https://dify.example.com").rstrip("/")
        self.user_id = user_id or "default-user"

    def _upload_file(self, image_path: Path) -> str:
        """Upload the image to Dify and return its file id (UUID)."""
        url = f"{self.base_url}/v1/files/upload"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        data = {"user": self.user_id}

        # multipart/form-data upload.
        mime_type, _ = mimetypes.guess_type(image_path.name)
        with open(image_path, "rb") as f:
            files = {"file": (image_path.name, f, mime_type or "image/jpeg")}
            status_code, response_text = http_post_multipart(
                url, headers, data, files, self.timeout
            )

        if status_code >= 400:
            raise RuntimeError(f"Dify 文件上传失败 ({status_code}): {response_text}")

        return json.loads(response_text)["id"]

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Full pipeline override: upload file, then run the workflow."""
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            # Step 1: upload the image.
            file_id = self._upload_file(image_path)

            # Step 2: run the workflow (/v1/workflows/run).
            url = f"{self.base_url}/v1/workflows/run"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            inputs = {
                "prompt01": prompt,  # paragraph-type input, no length limit
                "pho01": {  # image parameter expected by the workflow
                    "type": "image",
                    "transfer_method": "local_file",
                    "upload_file_id": file_id,
                },
            }
            payload = {
                "inputs": inputs,
                "response_mode": "blocking",
                "user": self.user_id,
            }
            status_code, response_text = http_post_json(url, headers, payload, self.timeout)
            if status_code >= 400:
                raise RuntimeError(f"Dify 工作流失败 ({status_code}): {response_text}")

            # Step 3: workflow reply shape: {"data": {"outputs": {"text": "..."}}}
            response_json = json.loads(response_text)
            raw_text = response_json.get("data", {}).get("outputs", {}).get("text", "")

        except Exception as exc:
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=time.time() - start,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        # Step 4: parse the JSON fields out of the workflow's text output.
        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)
        # BUGFIX: valid JSON that is not an object (e.g. a top-level array)
        # previously crashed with an AttributeError on ``parsed.get`` outside
        # the try/except; record it as a parse error instead.
        if parsed is not None and not isinstance(parsed, dict):
            parsed, parse_error = None, "响应 JSON 不是对象"

        fields = parsed or {}
        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=normalize_list(fields.get("detected_text")),
            detected_objects=normalize_list(fields.get("detected_objects")),
            sensitive_items=normalize_list(fields.get("sensitive_items")),
            summary=str(fields.get("summary", "")),
            confidence=str(fields.get("confidence", "")),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------- 主执行逻辑 -----------
|
|||
|
|
def list_image_files(image_dir: Path, recursive: bool = False) -> List[Path]:
    """Collect supported image files under *image_dir*, sorted by path."""
    walker = image_dir.rglob("*") if recursive else image_dir.iterdir()
    return sorted(
        entry
        for entry in walker
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_EXTENSIONS
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_client(args: argparse.Namespace) -> BaseVisionClient:
    """Build the vision client selected via CLI args / environment.

    API type resolution order: ``--api-type``, then the ``LLM_API_TYPE``
    env var, then "openai". ``--mock`` always forces the offline client.
    Each branch resolves its own model name from ``--model`` or the
    provider-specific env var. (The original code also computed an
    unused OpenAI model default before dispatching; that dead assignment
    has been removed.)

    Raises:
        ValueError: for an unrecognized api_type.
    """
    api_type = (args.api_type or os.getenv("LLM_API_TYPE") or "openai").lower()

    if api_type == "mock" or args.mock:
        return MockVisionClient()

    if api_type == "openai":
        model = args.model or os.getenv("OPENAI_MODEL") or "gpt-4o-mini"
        return OpenAIVisionClient(
            api_key=os.getenv("OPENAI_API_KEY", ""),
            model=model,
            timeout=args.timeout,
            base_url=os.getenv("OPENAI_BASE_URL"),
        )

    if api_type == "dmx":
        model = args.model or os.getenv("DMX_MODEL") or "gpt-5-mini"
        return DMXVisionClient(
            api_key=os.getenv("DMX_API_KEY", ""),
            model=model,
            timeout=args.timeout,
            base_url=os.getenv("DMX_BASE_URL"),
        )

    if api_type == "anthropic":
        model = args.model or os.getenv("ANTHROPIC_MODEL") or "claude-3-5-sonnet-20241022"
        return AnthropicVisionClient(
            api_key=os.getenv("ANTHROPIC_API_KEY", ""),
            model=model,
            timeout=args.timeout,
            base_url=os.getenv("ANTHROPIC_BASE_URL"),
        )

    if api_type == "dify":
        model = args.model or os.getenv("DIFY_MODEL") or "dify-chatflow"
        return DifyVisionClient(
            api_key=os.getenv("DIFY_API_KEY", ""),
            model=model,
            timeout=args.timeout,
            base_url=os.getenv("DIFY_BASE_URL"),
            user_id=os.getenv("DIFY_USER_ID") or "default-user",
        )

    raise ValueError(f"不支持的 api_type: {api_type}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_prompt(prompt_file: Optional[str]) -> str:
    """Resolve the analysis prompt: file > env var > built-in default."""
    if prompt_file:
        return Path(prompt_file).read_text(encoding="utf-8").strip()
    override = os.getenv("VISION_ANALYSIS_PROMPT")
    if override:
        return override.strip()
    return DEFAULT_ANALYSIS_PROMPT.strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_batch_recognition(args: argparse.Namespace) -> List[RecognitionResult]:
    """Run the batch pipeline: scan images, recognize (serial or threaded), save.

    Raises:
        FileNotFoundError: if the configured image directory does not exist.

    Returns the results in input (scan) order, even in concurrent mode.
    """
    image_dir = Path(args.image_dir).resolve()
    if not image_dir.exists():
        raise FileNotFoundError(f"图片目录不存在: {image_dir}")

    output_path = Path(args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    prompt = load_prompt(args.prompt_file)
    client = create_client(args)

    # Apply --offset / --limit windowing on the sorted file list.
    images = list_image_files(image_dir, recursive=args.recursive)
    if args.offset:
        images = images[args.offset :]
    if args.limit:
        images = images[: args.limit]

    if not images:
        print("未找到任何图片文件。")
        return []

    max_workers = max(1, args.max_workers)
    debug = args.debug

    # Concise one-line header by default; full banner in debug mode.
    if not debug:
        print(f"- 开始处理 {len(images)} 张图片 | API: {client.api_type} ({client.model})")
    else:
        print("=" * 60)
        print(f"图片目录: {image_dir}")
        print(f"输出文件: {output_path}")
        print(f"匹配到 {len(images)} 张图片,将使用 {client.api_type} ({client.model}) 执行识别")
        print(f"并发 worker 数: {max_workers}")
        print("=" * 60)

    results: List[RecognitionResult] = []
    total_images = len(images)

    if max_workers == 1:
        # Serial processing with a progress bar (disabled in debug mode,
        # where per-image lines are printed instead).
        pbar = tqdm(images, desc="处理进度", unit="图", disable=debug)
        for idx, image_path in enumerate(pbar, 1):
            if debug:
                print(f"[{idx}/{total_images}] 识别 {image_path.name} ...")

            result = client.analyze_image(image_path, prompt)
            results.append(result)

            if debug:
                if result.is_success:
                    print(
                        f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                        f"耗时 {result.latency_seconds:.1f}s"
                    )
                else:
                    print(f" ✗ 失败: {result.error}")
            else:
                # Compact mode: surface per-image status in the bar postfix.
                status = "✓" if result.is_success else "✗"
                pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")

            # Optional throttling between sequential API calls.
            if args.sleep > 0:
                time.sleep(args.sleep)
        pbar.close()
    else:
        if args.sleep > 0:
            print("提示: 并发模式下忽略 --sleep 节流参数。")

        # Futures complete out of order; remember results per path so the
        # final list can be re-assembled in input order below.
        result_lookup: Dict[Path, RecognitionResult] = {}

        # Concurrent processing with a progress bar.
        pbar = tqdm(total=total_images, desc="处理进度", unit="图", disable=debug)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(client.analyze_image, image_path, prompt): image_path
                for image_path in images
            }

            for completed_idx, future in enumerate(as_completed(future_map), 1):
                image_path = future_map[future]
                result = future.result()
                result_lookup[image_path] = result

                if debug:
                    print(f"[{completed_idx}/{total_images}] 识别 {image_path.name} ...")
                    if result.is_success:
                        print(
                            f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                            f"耗时 {result.latency_seconds:.1f}s"
                        )
                    else:
                        print(f" ✗ 失败: {result.error}")
                else:
                    # Compact mode: update the progress bar postfix.
                    status = "✓" if result.is_success else "✗"
                    pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")
                pbar.update(1)

        pbar.close()
        # Restore input order regardless of completion order.
        results = [result_lookup[path] for path in images]

    save_results(results, output_path)
    return results
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_results(results: Sequence[RecognitionResult], output_path: Path) -> None:
    """Write the results to an Excel file and print a summary line."""
    if not results:
        print("没有可保存的结果。")
        return

    rows = [item.to_row() for item in results]
    pd.DataFrame(rows).to_excel(output_path, index=False, engine="openpyxl")
    succeeded = sum(1 for item in results if item.is_success)
    print(f"\n识别完成,成功 {succeeded}/{len(results)},结果已写入 {output_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the CLI argument parser and parse *argv* (None = sys.argv)."""
    parser = argparse.ArgumentParser(
        description="批量图片多模态识别(文字 + 物品)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Defaults are resolved relative to the project root (parent of scripts/).
    base_dir = Path(__file__).resolve().parent.parent
    default_image_dir = base_dir / "data" / "images"
    default_output = base_dir / "data" / "output" / "image_recognition_results.xlsx"
    default_env_file = base_dir / ".env"

    parser.add_argument("--image-dir", type=str, default=str(default_image_dir), help="图片目录")
    parser.add_argument("--output", type=str, default=str(default_output), help="识别结果输出文件(Excel)")
    parser.add_argument("--api-type", choices=["openai", "anthropic", "dmx", "dify", "mock"], help="选择使用的 API 类型")
    parser.add_argument("--model", type=str, help="指定模型名称(默认读取环境变量)")
    parser.add_argument("--prompt-file", type=str, help="自定义提示词文件")
    parser.add_argument("--recursive", action="store_true", help="递归搜索子目录")
    parser.add_argument("--limit", type=int, help="限制最大处理图片数")
    parser.add_argument("--offset", type=int, default=0, help="跳过前 N 张图片")
    parser.add_argument("--sleep", type=float, default=DEFAULT_SLEEP_SECONDS, help="每次请求后的等待秒数")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="HTTP 请求超时")
    parser.add_argument("--mock", action="store_true", help="强制使用模拟模式(无需 API)")
    parser.add_argument("--env-file", type=str, default=str(default_env_file), help="包含 API Key / Base URL / Model 配置的 .env 文件")
    parser.add_argument("--max-workers", type=int, default=1, help="并行识别线程数")
    parser.add_argument("--debug", action="store_true", help="启用详细输出(默认为精简模式)")

    return parser.parse_args(argv)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: load env config, run the batch, exit 1 on failure."""
    args = parse_args(argv)
    # Load .env first so create_client() sees the configured keys/URLs.
    load_env_file(args.env_file)
    try:
        run_batch_recognition(args)
    except Exception as exc:
        # Top-level boundary: report and exit non-zero instead of a traceback.
        print(f"执行失败: {exc}")
        sys.exit(1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|