init: first upload
This commit is contained in:
776
scripts/image_batch_recognizer.py
Normal file
776
scripts/image_batch_recognizer.py
Normal file
@@ -0,0 +1,776 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
批量图片识别脚本
|
||||
|
||||
功能:
|
||||
1. 遍历指定目录下的图片文件
|
||||
2. 将图片上传至大模型 API(OpenAI、Anthropic,或本地模拟模式)
|
||||
3. 输出包含图片中文本、物品描述、风险概述的 Excel 报表
|
||||
|
||||
使用示例:
|
||||
cd scripts
|
||||
python3 image_batch_recognizer.py --limit 5 --api-type openai
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import requests # type: ignore
|
||||
# 禁用 SSL 警告(用于自签名证书)
|
||||
import urllib3
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
except ImportError: # pragma: no cover - requests 是可选依赖
|
||||
requests = None
|
||||
|
||||
try:
|
||||
from tqdm import tqdm # type: ignore
|
||||
except ImportError:
|
||||
# 如果没有安装 tqdm,提供一个简单的替代
|
||||
class tqdm:
    """Minimal stand-in used when the real ``tqdm`` package is unavailable.

    Mirrors only the subset of the tqdm API this script relies on
    (iteration, ``update``, ``set_postfix_str``, ``close``) and prints
    nothing.
    """

    def __init__(self, iterable=None, total=None, desc=None, **kwargs):
        self.iterable = iterable
        if total is not None:
            # Respect an explicit total, including total=0 (the original
            # truthiness test would silently recompute it).
            self.total = total
        else:
            # len() is undefined for generators; fall back to 0 rather
            # than crashing.
            try:
                self.total = len(iterable) if iterable is not None else 0
            except TypeError:
                self.total = 0
        self.desc = desc
        self.n = 0

    def __iter__(self):
        # Guard: the bar may be constructed with total= only (no iterable),
        # in which case iterating it yields nothing instead of raising.
        if self.iterable is None:
            return
        for item in self.iterable:
            yield item
            self.n += 1

    def update(self, n=1):
        self.n += n

    def set_postfix_str(self, s):
        # No-op in the fallback implementation.
        pass

    def close(self):
        pass
|
||||
|
||||
# Delimiter used when packing list fields into a single spreadsheet cell.
SEPARATOR = "|||"
# Default pause between sequential API calls (seconds); 0 disables throttling.
DEFAULT_SLEEP_SECONDS = 0.0
# Default HTTP request timeout (seconds).
DEFAULT_TIMEOUT = 90
# Upper bound on tokens requested from the model per image.
DEFAULT_MAX_TOKENS = 800
# Image file extensions picked up from the input directory (lower-cased match).
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}

# Default analysis prompt sent to the vision model.  The text below is
# runtime data delivered to the API and is intentionally kept in Chinese.
DEFAULT_ANALYSIS_PROMPT = """
你是一名涉毒风控场景下的多模态分析师。请识别我提供的图片内容,并输出 JSON。
JSON 字段要求:
{
"detected_text": ["按行列出的文字内容,若无文字返回空数组"],
"detected_objects": ["出现的主要物品或场景要素,中文描述"],
"sensitive_items": ["可疑化学品、药物、工具等,如无则为空数组"],
"summary": "2-3 句话概括图片与潜在风险,中文",
"confidence": "High | Medium | Low"
}
只返回 JSON,不要额外说明。
""".strip()
|
||||
|
||||
|
||||
# ----------- 数据结构 -----------
|
||||
@dataclass
class RecognitionResult:
    """Outcome of analyzing one image, shaped for the Excel report."""

    image_name: str
    image_path: str
    detected_text: List[str] = field(default_factory=list)
    detected_objects: List[str] = field(default_factory=list)
    sensitive_items: List[str] = field(default_factory=list)
    summary: str = ""
    confidence: str = ""
    raw_response: str = ""
    api_type: str = ""
    model_name: str = ""
    latency_seconds: float = 0.0
    analyzed_at: str = ""
    error: Optional[str] = None

    def to_row(self) -> Dict[str, Any]:
        """Flatten this result into a single spreadsheet row (column -> value)."""
        row: Dict[str, Any] = {}
        row["image_name"] = self.image_name
        row["image_path"] = self.image_path
        # Multi-line text fields are collapsed into one cell each.
        row["detected_text"] = "\n".join(self.detected_text).strip()
        row["detected_objects"] = SEPARATOR.join(self.detected_objects)
        row["sensitive_items"] = SEPARATOR.join(self.sensitive_items)
        row["summary"] = self.summary
        row["confidence"] = self.confidence
        row["latency_seconds"] = round(self.latency_seconds, 2)
        row["api_type"] = self.api_type
        row["model"] = self.model_name
        row["analyzed_at"] = self.analyzed_at
        row["error"] = self.error or ""
        row["raw_response"] = self.raw_response
        return row

    @property
    def is_success(self) -> bool:
        """True when no transport or parse error was recorded."""
        return self.error is None
|
||||
|
||||
|
||||
# ----------- 工具函数 -----------
|
||||
def encode_image(image_path: Path) -> Tuple[str, str]:
    """Read *image_path* and return ``(base64_payload, mime_type)``.

    The MIME type is guessed from the file name; "image/jpeg" is used as
    the fallback when it cannot be determined.
    """
    raw_bytes = image_path.read_bytes()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    guessed, _encoding = mimetypes.guess_type(image_path.name)
    if guessed is None:
        guessed = "image/jpeg"
    return payload, guessed
|
||||
|
||||
|
||||
def extract_json_from_response(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Pull a JSON object out of a model reply.

    Tries a direct ``json.loads`` first; on failure, retries on the
    substring between the first '{' and the last '}'.

    Returns:
        ``(parsed_dict, None)`` on success, ``(None, error_message)`` on failure.
    """
    if not text:
        return None, "空响应"

    stripped = text.strip()
    try:
        return json.loads(stripped), None
    except json.JSONDecodeError:
        pass

    first = stripped.find("{")
    last = stripped.rfind("}")
    if first == -1 or last == -1 or last <= first:
        return None, "响应中未找到 JSON 对象"

    # Second attempt on the outermost brace-delimited slice.
    candidate = stripped[first:last + 1]
    try:
        return json.loads(candidate), None
    except json.JSONDecodeError as exc:
        return None, f"JSON 解析失败: {exc}"
|
||||
|
||||
|
||||
def normalize_list(value: Any) -> List[str]:
    """Coerce a model-returned field into a clean list of trimmed strings.

    ``None`` becomes an empty list, strings are split on ``SEPARATOR``,
    containers are flattened to stripped strings, and any other scalar is
    stringified into a single-element list.  Empty entries are dropped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        pieces = value.split(SEPARATOR)
        return [piece.strip() for piece in pieces if piece.strip()]
    if isinstance(value, (list, tuple, set)):
        trimmed = (str(entry).strip() for entry in value)
        return [entry for entry in trimmed if entry]
    return [str(value).strip()]
|
||||
|
||||
|
||||
def http_post_json(url: str, headers: Dict[str, str], payload: Dict[str, Any], timeout: int) -> Tuple[int, str]:
    """POST *payload* as JSON to *url* and return ``(status_code, body_text)``.

    Uses ``requests`` when available, otherwise falls back to ``urllib``.
    NOTE(review): TLS certificate verification is disabled (verify=False)
    to tolerate self-signed certificates — this weakens transport security;
    confirm it is acceptable for the deployment environment.
    """
    body = json.dumps(payload)
    # Always send an explicit JSON content type, merging caller headers.
    headers = {**headers, "Content-Type": "application/json"}

    if requests:
        # Certificate verification disabled for self-signed certificates.
        response = requests.post(url, headers=headers, data=body, timeout=timeout, verify=False)
        return response.status_code, response.text

    # requests is unavailable — use the standard library instead.
    import urllib.error
    import urllib.request

    req = urllib.request.Request(url, data=body.encode("utf-8"), headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            response_body = resp.read().decode("utf-8")
            return resp.getcode(), response_body
    except urllib.error.HTTPError as exc:  # pragma: no cover - network error path
        # HTTP errors still carry a usable status code and body.
        return exc.code, exc.read().decode("utf-8")
|
||||
|
||||
|
||||
def http_post_multipart(url: str, headers: Dict[str, str],
                        data: Dict[str, str], files: Dict[str, Any],
                        timeout: int) -> Tuple[int, str]:
    """POST a multipart/form-data request (used for Dify file upload).

    Requires the ``requests`` library — raises RuntimeError when it is
    missing (no urllib fallback for multipart).  Returns
    ``(status_code, body_text)``.
    """
    if not requests:
        raise RuntimeError(
            "Dify 文件上传需要 requests 库,请安装: pip install requests"
        )

    # Certificate verification disabled for self-signed certificates.
    # NOTE(review): weakens transport security; confirm acceptable.
    response = requests.post(url, headers=headers, data=data,
                             files=files, timeout=timeout, verify=False)
    return response.status_code, response.text
|
||||
|
||||
|
||||
def build_endpoint(base_url: Optional[str], path: str, default_base: str) -> str:
    """Join *base_url* (or *default_base* when unset) with *path*.

    Trailing slashes on the base and a missing leading slash on the path
    are normalized so the result never has a doubled or missing '/'.
    """
    root = (base_url or default_base).rstrip("/")
    suffix = path if path.startswith("/") else f"/{path}"
    return root + suffix
|
||||
|
||||
|
||||
def load_env_file(env_file: Optional[str]) -> Optional[Path]:
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Blank lines and '#' comments are skipped; a leading ``export `` prefix
    and surrounding single/double quotes around the value are stripped.
    Existing environment variables are overwritten.

    Returns:
        The expanded path on success, or None when *env_file* is falsy or
        the file does not exist.
    """
    if not env_file:
        return None

    env_path = Path(env_file).expanduser()
    if not env_path.exists():
        return None

    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        if entry.startswith("export "):
            entry = entry[len("export "):].strip()
        if "=" not in entry:
            continue
        name, _sep, raw_value = entry.partition("=")
        name = name.strip()
        if not name:
            continue
        os.environ[name] = raw_value.strip().strip('"').strip("'")

    print(f"已从 {env_path} 加载环境变量")
    return env_path
|
||||
|
||||
|
||||
# ----------- 基础客户端 -----------
|
||||
class BaseVisionClient:
    """Shared analysis workflow for all vision API clients.

    Subclasses implement ``_send_request`` (provider payload + transport);
    this base class handles image encoding, timing, and turning the raw
    reply into a RecognitionResult.
    """

    def __init__(self, model: str, api_type: str, timeout: int = DEFAULT_TIMEOUT):
        self.model = model          # model identifier sent to the provider
        self.api_type = api_type    # short provider tag used in reports
        self.timeout = timeout      # per-request HTTP timeout (seconds)

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Encode one image, call the provider, and parse the JSON reply.

        Never raises: transport errors and JSON-parse failures are both
        captured in the returned RecognitionResult's ``error`` field.
        """
        encoded, mime_type = encode_image(image_path)
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            raw_text = self._send_request(
                image_base64=encoded,
                mime_type=mime_type,
                prompt=prompt,
                image_name=image_path.name,
            )
        except Exception as exc:  # pragma: no cover - network errors are hard to reproduce reliably
            latency = time.time() - start
            # Transport failure: return an error result instead of raising.
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=latency,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)

        # When parsing failed, fields fall back to empty values and the
        # parse error is recorded on the result.
        detected_text = normalize_list(parsed.get("detected_text")) if parsed else []
        detected_objects = normalize_list(parsed.get("detected_objects")) if parsed else []
        sensitive_items = normalize_list(parsed.get("sensitive_items")) if parsed else []
        summary = (parsed or {}).get("summary", "")
        confidence = (parsed or {}).get("confidence", "")

        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=detected_text,
            detected_objects=detected_objects,
            sensitive_items=sensitive_items,
            summary=str(summary),
            confidence=str(confidence),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Send one provider request and return the raw reply text."""
        raise NotImplementedError
|
||||
|
||||
|
||||
# ----------- OpenAI 客户端 -----------
|
||||
class OpenAIVisionClient(BaseVisionClient):
    """Client for the OpenAI Chat Completions API with inline image input."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        """Raises ValueError when *api_key* is empty."""
        super().__init__(model=model, api_type="openai", timeout=timeout)
        if not api_key:
            raise ValueError("OPENAI_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://api.openai.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """POST one chat-completions request and return the reply text.

        Raises RuntimeError for any HTTP status >= 400.
        """
        # The image is embedded inline as a base64 data URL.
        image_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a multimodal analyst that returns structured JSON responses in Chinese.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                },
            ],
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)

        if status_code >= 400:
            raise RuntimeError(f"OpenAI API 调用失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        content = response_json["choices"][0]["message"]["content"]
        # content may be a plain string or a list of typed parts; join the
        # text of dict parts in the latter case.
        if isinstance(content, list):
            text = "\n".join(part.get("text", "") for part in content if isinstance(part, dict))
        else:
            text = content
        return text.strip()
|
||||
|
||||
|
||||
# ----------- DMX (文档上传方式) 客户端 -----------
|
||||
class DMXVisionClient(BaseVisionClient):
    """Client for the DMX gateway (OpenAI-compatible chat-completions API)."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        """Raises ValueError when *api_key* is empty."""
        super().__init__(model=model, api_type="dmx", timeout=timeout)
        if not api_key:
            raise ValueError("DMX_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://www.dmxapi.cn")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """POST one chat-completions request and return the reply text.

        Same payload shape as the OpenAI client, but without a system
        message and with temperature 0.1.  Raises RuntimeError for any
        HTTP status >= 400.
        """
        # The image is embedded inline as a base64 data URL.
        image_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0.1,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                }
            ],
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"DMX API 调用失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        content = response_json["choices"][0]["message"]["content"]
        # content may be a plain string or a list of typed parts.
        if isinstance(content, list):
            text = "\n".join(part.get("text", "") for part in content if isinstance(part, dict))
        else:
            text = content
        return text.strip()
|
||||
|
||||
|
||||
# ----------- Anthropic 客户端 -----------
|
||||
class AnthropicVisionClient(BaseVisionClient):
    """Client for the Anthropic Messages API with base64 image input."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        """Raises ValueError when *api_key* is empty."""
        super().__init__(model=model, api_type="anthropic", timeout=timeout)
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/messages", "https://api.anthropic.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """POST one Messages request; return the concatenated text blocks.

        Raises RuntimeError for any HTTP status >= 400.
        """
        payload = {
            "model": self.model,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "system": "You are a multimodal analyst that responds with JSON only.",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": mime_type,
                                "data": image_base64,
                            },
                        },
                    ],
                }
            ],
        }
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
        }
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"Anthropic API 调用失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        # Keep only the "text"-typed content blocks from the reply.
        text_parts = [part.get("text", "") for part in response_json.get("content", []) if part.get("type") == "text"]
        return "\n".join(text_parts).strip()
|
||||
|
||||
|
||||
# ----------- 模拟客户端(无 API 时使用) -----------
|
||||
class MockVisionClient(BaseVisionClient):
    """Offline stand-in client that returns canned JSON without any API call."""

    def __init__(self):
        super().__init__(model="mock", api_type="mock", timeout=1)

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Return a fixed demo payload; image data and prompt are ignored."""
        del image_base64, mime_type, prompt  # deliberately unused
        canned = {
            "detected_text": ["(模拟模式)未调用真实 API"],
            "detected_objects": ["sample-object-from-filename"],
            "sensitive_items": [],
            "summary": f"Mock 模式:演示 {image_name} 的返回格式。",
            "confidence": "Low",
        }
        return json.dumps(canned, ensure_ascii=False)
|
||||
|
||||
|
||||
# ----------- Dify Chatflow 客户端 -----------
|
||||
class DifyVisionClient(BaseVisionClient):
    """Dify workflow API client: upload the image first, then run a workflow."""

    def __init__(self, api_key: str, model: str, timeout: int,
                 base_url: Optional[str] = None, user_id: Optional[str] = None):
        """Raises ValueError when *api_key* is empty.

        *user_id* identifies the caller to Dify; defaults to "default-user".
        """
        super().__init__(model=model, api_type="dify", timeout=timeout)
        if not api_key:
            raise ValueError("DIFY_API_KEY 未设置")
        self.api_key = api_key
        self.base_url = (base_url or "https://dify.example.com").rstrip("/")
        self.user_id = user_id or "default-user"

    def _upload_file(self, image_path: Path) -> str:
        """Upload the image to Dify and return the assigned file id (UUID).

        Raises RuntimeError for any HTTP status >= 400.
        """
        url = f"{self.base_url}/v1/files/upload"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        data = {"user": self.user_id}

        # multipart/form-data upload; MIME type guessed from the file name.
        mime_type, _ = mimetypes.guess_type(image_path.name)
        with open(image_path, "rb") as f:
            files = {"file": (image_path.name, f, mime_type or "image/jpeg")}
            status_code, response_text = http_post_multipart(
                url, headers, data, files, self.timeout
            )

        if status_code >= 400:
            raise RuntimeError(f"Dify 文件上传失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        return response_json["id"]  # file UUID

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Full pipeline override: upload file -> run workflow -> parse JSON.

        Overrides the base-class flow because Dify requires a separate
        upload step and a different request/response shape.  Never raises:
        failures are captured in the result's ``error`` field.
        """
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            # Step 1: upload the image.
            file_id = self._upload_file(image_path)

            # Step 2: run the workflow (POST /v1/workflows/run).
            url = f"{self.base_url}/v1/workflows/run"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # Build the workflow inputs.
            inputs = {
                "prompt01": prompt,  # paragraph-typed input, no length limit
                "pho01": {  # image parameter expected by the workflow
                    "type": "image",
                    "transfer_method": "local_file",
                    "upload_file_id": file_id
                }
            }

            payload = {
                "inputs": inputs,
                "response_mode": "blocking",
                "user": self.user_id
            }

            status_code, response_text = http_post_json(
                url, headers, payload, self.timeout
            )

            if status_code >= 400:
                raise RuntimeError(f"Dify 工作流失败 ({status_code}): {response_text}")

            # Step 3: extract the reply text (workflow returns data.outputs.text).
            response_json = json.loads(response_text)
            # Workflow API response shape: {"data": {"outputs": {"text": "..."}}}
            raw_text = response_json.get("data", {}).get("outputs", {}).get("text", "")

        except Exception as exc:
            latency = time.time() - start
            # Upload or workflow failure: return an error result.
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=latency,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        # Step 4: parse the JSON result (same logic as the base class).
        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)

        detected_text = normalize_list(parsed.get("detected_text")) if parsed else []
        detected_objects = normalize_list(parsed.get("detected_objects")) if parsed else []
        sensitive_items = normalize_list(parsed.get("sensitive_items")) if parsed else []
        summary = (parsed or {}).get("summary", "")
        confidence = (parsed or {}).get("confidence", "")

        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=detected_text,
            detected_objects=detected_objects,
            sensitive_items=sensitive_items,
            summary=str(summary),
            confidence=str(confidence),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )
|
||||
|
||||
|
||||
# ----------- 主执行逻辑 -----------
|
||||
def list_image_files(image_dir: Path, recursive: bool = False) -> List[Path]:
    """Return supported image files under *image_dir*, sorted by path.

    Subdirectories are scanned only when *recursive* is True; matching is
    by (case-insensitive) file extension against SUPPORTED_EXTENSIONS.
    """
    walker = image_dir.rglob("*") if recursive else image_dir.iterdir()
    matches = [
        entry for entry in walker
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_EXTENSIONS
    ]
    matches.sort()
    return matches
|
||||
|
||||
|
||||
def create_client(args: argparse.Namespace) -> BaseVisionClient:
    """Build the vision client selected via --api-type / LLM_API_TYPE.

    Model precedence in each branch: the --model flag, then the
    provider-specific environment variable, then a built-in default.
    (The original code also computed an OpenAI-flavored ``model`` before
    branching; that value was dead — every branch recomputed it — so it
    has been removed.)

    Raises:
        ValueError: for an unknown api type; the client constructors also
            raise ValueError when the provider API key is missing.
    """
    api_type = (args.api_type or os.getenv("LLM_API_TYPE") or "openai").lower()

    if api_type == "mock" or args.mock:
        return MockVisionClient()

    if api_type == "openai":
        model = args.model or os.getenv("OPENAI_MODEL") or "gpt-4o-mini"
        api_key = os.getenv("OPENAI_API_KEY", "")
        base_url = os.getenv("OPENAI_BASE_URL")
        return OpenAIVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "dmx":
        model = args.model or os.getenv("DMX_MODEL") or "gpt-5-mini"
        api_key = os.getenv("DMX_API_KEY", "")
        base_url = os.getenv("DMX_BASE_URL")
        return DMXVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "anthropic":
        model = args.model or os.getenv("ANTHROPIC_MODEL") or "claude-3-5-sonnet-20241022"
        api_key = os.getenv("ANTHROPIC_API_KEY", "")
        base_url = os.getenv("ANTHROPIC_BASE_URL")
        return AnthropicVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "dify":
        model = args.model or os.getenv("DIFY_MODEL") or "dify-chatflow"
        api_key = os.getenv("DIFY_API_KEY", "")
        base_url = os.getenv("DIFY_BASE_URL")
        user_id = os.getenv("DIFY_USER_ID") or "default-user"
        return DifyVisionClient(api_key=api_key, model=model,
                                timeout=args.timeout, base_url=base_url,
                                user_id=user_id)

    raise ValueError(f"不支持的 api_type: {api_type}")
|
||||
|
||||
|
||||
def load_prompt(prompt_file: Optional[str]) -> str:
    """Resolve the analysis prompt text.

    Priority: an explicit prompt file, then the VISION_ANALYSIS_PROMPT
    environment variable, then the built-in default prompt.  The result
    is always stripped of surrounding whitespace.
    """
    if prompt_file:
        return Path(prompt_file).read_text(encoding="utf-8").strip()
    from_env = os.getenv("VISION_ANALYSIS_PROMPT")
    if from_env:
        return from_env.strip()
    return DEFAULT_ANALYSIS_PROMPT.strip()
|
||||
|
||||
|
||||
def run_batch_recognition(args: argparse.Namespace) -> List[RecognitionResult]:
    """Analyze every selected image and write an Excel report.

    Flow: resolve directories -> build client -> list/slice images ->
    analyze serially (max_workers == 1) or with a thread pool ->
    save via save_results.  Results are returned in input-list order.

    Raises:
        FileNotFoundError: when the image directory does not exist.
    """
    image_dir = Path(args.image_dir).resolve()
    if not image_dir.exists():
        raise FileNotFoundError(f"图片目录不存在: {image_dir}")

    output_path = Path(args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    prompt = load_prompt(args.prompt_file)
    client = create_client(args)

    images = list_image_files(image_dir, recursive=args.recursive)
    # --offset is applied before --limit, so the limit counts from the
    # offset point.
    if args.offset:
        images = images[args.offset :]
    if args.limit:
        images = images[: args.limit]

    if not images:
        print("未找到任何图片文件。")
        return []

    max_workers = max(1, args.max_workers)
    debug = args.debug

    # Compact one-line header (full banner only in debug mode).
    if not debug:
        print(f"- 开始处理 {len(images)} 张图片 | API: {client.api_type} ({client.model})")
    else:
        print("=" * 60)
        print(f"图片目录: {image_dir}")
        print(f"输出文件: {output_path}")
        print(f"匹配到 {len(images)} 张图片,将使用 {client.api_type} ({client.model}) 执行识别")
        print(f"并发 worker 数: {max_workers}")
        print("=" * 60)

    results: List[RecognitionResult] = []
    total_images = len(images)

    if max_workers == 1:
        # Serial processing with a progress bar.
        pbar = tqdm(images, desc="处理进度", unit="图", disable=debug)
        for idx, image_path in enumerate(pbar, 1):
            if debug:
                print(f"[{idx}/{total_images}] 识别 {image_path.name} ...")

            result = client.analyze_image(image_path, prompt)
            results.append(result)

            if debug:
                if result.is_success:
                    print(
                        f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                        f"耗时 {result.latency_seconds:.1f}s"
                    )
                else:
                    print(f" ✗ 失败: {result.error}")
            else:
                # Compact mode: only update the progress-bar postfix.
                status = "✓" if result.is_success else "✗"
                pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")

            # Optional throttling between sequential requests.
            if args.sleep > 0:
                time.sleep(args.sleep)
        pbar.close()
    else:
        if args.sleep > 0:
            print("提示: 并发模式下忽略 --sleep 节流参数。")

        result_lookup: Dict[Path, RecognitionResult] = {}

        # Concurrent processing with a progress bar.
        pbar = tqdm(total=total_images, desc="处理进度", unit="图", disable=debug)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(client.analyze_image, image_path, prompt): image_path
                for image_path in images
            }

            for completed_idx, future in enumerate(as_completed(future_map), 1):
                image_path = future_map[future]
                result = future.result()
                result_lookup[image_path] = result

                if debug:
                    print(f"[{completed_idx}/{total_images}] 识别 {image_path.name} ...")
                    if result.is_success:
                        print(
                            f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                            f"耗时 {result.latency_seconds:.1f}s"
                        )
                    else:
                        print(f" ✗ 失败: {result.error}")
                else:
                    # Compact mode: only update the progress bar.
                    status = "✓" if result.is_success else "✗"
                    pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")
                pbar.update(1)

        pbar.close()
        # as_completed yields in completion order; restore input order here.
        results = [result_lookup[path] for path in images]

    save_results(results, output_path)
    return results
|
||||
|
||||
|
||||
def save_results(results: Sequence[RecognitionResult], output_path: Path) -> None:
    """Write all recognition rows to an Excel workbook at *output_path*.

    Prints a notice and writes nothing when *results* is empty.
    """
    if not results:
        print("没有可保存的结果。")
        return

    rows = [item.to_row() for item in results]
    frame = pd.DataFrame(rows)
    frame.to_excel(output_path, index=False, engine="openpyxl")

    succeeded = sum(1 for item in results if item.is_success)
    print(f"\n识别完成,成功 {succeeded}/{len(results)},结果已写入 {output_path}")
|
||||
|
||||
|
||||
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (defaults to sys.argv)."""
    project_root = Path(__file__).resolve().parent.parent
    default_image_dir = project_root / "data" / "images"
    default_output = project_root / "data" / "output" / "image_recognition_results.xlsx"
    default_env_file = project_root / ".env"

    parser = argparse.ArgumentParser(
        description="批量图片多模态识别(文字 + 物品)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Input/output locations.
    parser.add_argument("--image-dir", type=str, default=str(default_image_dir), help="图片目录")
    parser.add_argument("--output", type=str, default=str(default_output), help="识别结果输出文件(Excel)")
    # Provider/model selection.
    parser.add_argument("--api-type", choices=["openai", "anthropic", "dmx", "dify", "mock"], help="选择使用的 API 类型")
    parser.add_argument("--model", type=str, help="指定模型名称(默认读取环境变量)")
    parser.add_argument("--prompt-file", type=str, help="自定义提示词文件")
    # Selection and throttling of the image set.
    parser.add_argument("--recursive", action="store_true", help="递归搜索子目录")
    parser.add_argument("--limit", type=int, help="限制最大处理图片数")
    parser.add_argument("--offset", type=int, default=0, help="跳过前 N 张图片")
    parser.add_argument("--sleep", type=float, default=DEFAULT_SLEEP_SECONDS, help="每次请求后的等待秒数")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="HTTP 请求超时")
    # Runtime behavior.
    parser.add_argument("--mock", action="store_true", help="强制使用模拟模式(无需 API)")
    parser.add_argument("--env-file", type=str, default=str(default_env_file), help="包含 API Key / Base URL / Model 配置的 .env 文件")
    parser.add_argument("--max-workers", type=int, default=1, help="并行识别线程数")
    parser.add_argument("--debug", action="store_true", help="启用详细输出(默认为精简模式)")

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse args, load .env config, run the batch.

    Any failure during the batch run is printed and converted into exit
    code 1.
    """
    options = parse_args(argv)
    load_env_file(options.env_file)
    try:
        run_batch_recognition(options)
    except Exception as exc:
        print(f"执行失败: {exc}")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
778
scripts/keyword_matcher.py
Normal file
778
scripts/keyword_matcher.py
Normal file
@@ -0,0 +1,778 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
多模式关键词匹配工具(重构版)
|
||||
- CAS号识别:专注于 `CAS号` 列,支持多种格式(-, 空格, 无分隔符等)
|
||||
- 模糊识别:对所有候选文本(含CAS)进行容错匹配
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# 可选依赖
|
||||
try:
|
||||
import ahocorasick
|
||||
HAS_AC = True
|
||||
except ImportError:
|
||||
HAS_AC = False
|
||||
|
||||
|
||||
# ========== Constants ==========
# Delimiter used when a single keyword cell holds several values.
SEPARATOR = "|||"
# Separator used when joining matched keywords for display.
MATCH_RESULT_SEPARATOR = " | "
# Emit a progress line every N processed rows.
PROGRESS_INTERVAL = 1000
# Default similarity threshold (percent) for fuzzy matching.
DEFAULT_FUZZY_THRESHOLD = 85

# CAS registry-number regex.
# Matches: 2-7 digits + optional separator + 2 digits + optional separator + 1 digit.
# Accepted separators: - (hyphen), space, . (dot), _ (underscore), or none.
CAS_REGEX_PATTERN = r'\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b'

# Which keyword-table columns feed each matching mode.
MODE_KEYWORD_COLUMNS: Dict[str, List[str]] = {
    "cas": ["CAS号"],
    "exact": ["中文名", "英文名", "CAS号", "简称", "可能名称"],
}

# Human-readable labels per matching mode (used in log output).
MODE_LABELS = {
    "cas": "CAS号识别",
    "exact": "精确匹配",
}
|
||||
|
||||
|
||||
# ========== 数据类 ==========
|
||||
@dataclass
class MatchResult:
    """Summary of one matcher run over a keyword table."""

    matched_indices: List[int]   # row indices that matched
    matched_keywords: List[str]  # keywords that produced matches
    elapsed_time: float          # wall-clock seconds for the run
    total_rows: int              # number of rows scanned
    matcher_name: str            # human-readable matcher identifier

    @property
    def match_count(self) -> int:
        """Number of matched rows."""
        return len(self.matched_indices)

    @property
    def match_rate(self) -> float:
        """Matched rows as a percentage of scanned rows (0.0 when none scanned)."""
        if self.total_rows <= 0:
            return 0.0
        return self.match_count / self.total_rows * 100

    @property
    def speed(self) -> float:
        """Rows processed per second (0.0 when elapsed time is not positive)."""
        if self.elapsed_time <= 0:
            return 0.0
        return self.total_rows / self.elapsed_time
|
||||
|
||||
|
||||
# ========== 工具函数 ==========
|
||||
def normalize_cas(cas_str: str) -> str:
    """Normalize a CAS registry number to the canonical ``XXX...-XX-X`` form.

    Accepted input formats:
      - 123-45-6   (standard)
      - 123 45 6   (space separated)
      - 123.45.6   (dot separated)
      - 123_45_6   (underscore separated)
      - 12345 6 / 1234 56 (partially separated)
      - 123456     (no separator)

    Returns:
        The canonical form, or the input unchanged when it has fewer than
        five digits; non-string / falsy input is returned via ``str()``.
    """
    if not cas_str or not isinstance(cas_str, str):
        return str(cas_str)

    # Keep digits only; every separator variant is discarded.
    digits_only = re.sub(r'[^\d]', '', cas_str)

    # A CAS number needs at least 5 digits (shortest form: XX-XX-X).
    if len(digits_only) < 5:
        return cas_str

    # Rebuild canonical form: leading digits, then 2 digits, then 1 check digit.
    # e.g. 123456 -> 123-45-6, 12345 -> 12-34-5
    # (The original comment claimed "123456 -> 1234-5-6", which does not
    # match what this expression produces.)
    return f"{digits_only[:-3]}-{digits_only[-3:-1]}-{digits_only[-1]}"
|
||||
|
||||
|
||||
def extract_cas_numbers(text: str, pattern: str = CAS_REGEX_PATTERN) -> Set[str]:
    """Find every CAS-like number in *text*, normalized to canonical form.

    Args:
        text: text to scan; falsy input yields an empty set.
        pattern: regex used to locate CAS candidates.

    Returns:
        Set of CAS numbers in canonical XXX-XX-X form.
    """
    if not text:
        return set()

    found: Set[str] = set()
    for hit in re.finditer(pattern, str(text)):
        # Normalize each full match — separators vary across sources.
        found.add(normalize_cas(hit.group(0)))
    return found
|
||||
|
||||
|
||||
def split_value(value: str, separator: str) -> List[str]:
    """Split one cell into trimmed, non-empty candidate keywords."""
    if separator and separator in value:
        pieces = value.split(separator)
    else:
        pieces = [value]

    candidates: List[str] = []
    for piece in pieces:
        if not piece:
            continue
        trimmed = piece.strip()
        if trimmed:
            candidates.append(trimmed)
    return candidates
|
||||
|
||||
|
||||
def load_keywords_for_mode(
    df: pd.DataFrame,
    mode: str,
    separator: str = SEPARATOR
) -> Set[str]:
    """
    Collect the keyword set for *mode* from the configured columns of *df*.

    Args:
        df: Keyword spreadsheet as a DataFrame.
        mode: Detection mode key (case-insensitive; must be in MODE_KEYWORD_COLUMNS).
        separator: Token separator used inside a single cell.

    Returns:
        Set of candidate keywords (CAS mode entries are normalized).

    Raises:
        ValueError: Unknown mode, or none of the mode's columns exist in *df*.
    """
    mode_key = mode.lower()
    if mode_key not in MODE_KEYWORD_COLUMNS:
        raise ValueError(f"不支持的模式: {mode}")

    wanted = MODE_KEYWORD_COLUMNS[mode_key]
    present = [col for col in wanted if col in df.columns]
    absent = [col for col in wanted if col not in df.columns]

    if not present:
        raise ValueError(
            f"模式 '{mode_key}' 需要的列 {wanted} 均不存在,"
            f"当前可用列: {df.columns.tolist()}"
        )

    if absent:
        print(f"警告: 以下列在关键词文件中缺失: {absent}")

    keywords: Set[str] = set()
    for col in present:
        for raw in df[col].dropna():
            cell = str(raw).strip()
            # Skip blanks and common Excel placeholder strings.
            if not cell or cell in ['#N/A#', 'nan', 'None']:
                continue
            for token in split_value(cell, separator):
                # CAS mode stores keywords in canonical XXX-XX-X form.
                if mode_key == "cas":
                    token = normalize_cas(token)
                keywords.add(token)

    label = MODE_LABELS.get(mode_key, mode_key)
    print(f"模式「{label}」共加载 {len(keywords)} 个候选关键词,来源列: {present}")

    if not keywords:
        print(f"警告: 模式 '{mode_key}' 未加载到任何关键词!")

    return keywords
|
||||
|
||||
|
||||
def show_progress(
    current: int,
    total: int,
    start_time: float,
    matched_count: int,
    interval: int = PROGRESS_INTERVAL
) -> None:
    """Print a throughput line once every *interval* processed rows."""
    done = current + 1
    if done % interval:
        return  # only report on interval boundaries
    rate = done / (time.time() - start_time)
    print(f"已处理 {done}/{total} 行,速度: {rate:.1f} 行/秒,匹配到 {matched_count} 行")
|
||||
|
||||
|
||||
# ========== 匹配器基类(策略模式) ==========
|
||||
class KeywordMatcher(ABC):
    """Abstract base for keyword matching strategies (template-method pattern).

    Subclasses implement ``_match_single_text`` and may override ``_prepare``
    (one-time setup) and ``_format_matches`` (result formatting).
    """

    def __init__(self, name: str):
        # Human-readable strategy name, shown in logs and in MatchResult.
        self.name = name

    def match(
        self,
        df: pd.DataFrame,
        keywords: Set[str],
        text_column: str
    ) -> MatchResult:
        """Run the full matching pass over ``df[text_column]``.

        Returns a MatchResult with the matched row positions, the formatted
        keyword hits per row, and timing statistics.
        """
        print(f"开始匹配(使用{self.name})...")
        self._prepare(keywords)

        hit_rows = []
        hit_keywords = []
        t0 = time.time()

        for row_idx, cell in enumerate(df[text_column]):
            # NaN cells are skipped entirely (no progress tick either).
            if pd.isna(cell):
                continue

            found = self._match_single_text(str(cell), keywords)
            if found:
                hit_rows.append(row_idx)
                hit_keywords.append(self._format_matches(found))

            show_progress(row_idx, len(df), t0, len(hit_rows))

        return MatchResult(
            matched_indices=hit_rows,
            matched_keywords=hit_keywords,
            elapsed_time=time.time() - t0,
            total_rows=len(df),
            matcher_name=self.name
        )

    def _prepare(self, keywords: Set[str]) -> None:
        """One-time setup hook; default is a no-op."""
        pass

    @abstractmethod
    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return the subset of *keywords* found in *text* (subclass-defined)."""
        pass

    def _format_matches(self, matches: Set[str]) -> str:
        """Join the matched keywords into one display string (overridable)."""
        return MATCH_RESULT_SEPARATOR.join(sorted(matches))
|
||||
|
||||
|
||||
# ========== 具体匹配器实现 ==========
|
||||
class AhoCorasickMatcher(KeywordMatcher):
    """Matcher backed by a pyahocorasick automaton: one pass per text."""

    def __init__(self):
        super().__init__("Aho-Corasick 自动机")
        # Built lazily by _prepare(); None until then.
        self.automaton = None

    def _prepare(self, keywords: Set[str]) -> None:
        """Build the automaton from the keyword set.

        Raises:
            RuntimeError: pyahocorasick is not installed.
        """
        if not HAS_AC:
            raise RuntimeError("pyahocorasick 未安装")

        print("正在构建Aho-Corasick自动机...")
        automaton = ahocorasick.Automaton()
        for kw in keywords:
            # Store the keyword itself as the payload so iter() yields it back.
            automaton.add_word(kw, kw)
        automaton.make_automaton()
        self.automaton = automaton
        print("自动机构建完成")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Collect every keyword the automaton finds anywhere in *text*."""
        return {kw for _, kw in self.automaton.iter(text)}
|
||||
|
||||
|
||||
class SetMatcher(KeywordMatcher):
    """Fallback matcher: plain substring test of every keyword per text."""

    def __init__(self):
        super().__init__("标准集合匹配")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return keywords that occur anywhere in *text* (substring semantics)."""
        hits = set()
        for kw in keywords:
            if kw in text:
                hits.add(kw)
        return hits
|
||||
|
||||
|
||||
class CASRegexMatcher(KeywordMatcher):
    """CAS-number matcher: regex-extract, normalize, then intersect.

    Pipeline per text:
    1. regex-extract every CAS-looking token,
    2. normalize each token to the canonical XXX-XX-X form,
    3. intersect with the (already normalized) keyword set.
    """

    def __init__(self):
        super().__init__("CAS号正则匹配(支持多种格式)")
        # Pre-compile once; extract_cas_numbers accepts compiled patterns too.
        self.pattern = re.compile(CAS_REGEX_PATTERN)

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return the normalized CAS numbers from *text* present in *keywords*."""
        return extract_cas_numbers(text, self.pattern) & keywords
|
||||
|
||||
|
||||
class RegexExactMatcher(KeywordMatcher):
    """Word-boundary regex matcher: keywords only match as whole words."""

    def __init__(self):
        super().__init__("正则表达式精确匹配(词边界)")
        # Compiled alternation pattern; built in _prepare(), None until then.
        self.pattern = None

    def _prepare(self, keywords: Set[str]) -> None:
        """Compile ``\\b(kw1|kw2|...)\\b`` from the keyword set."""
        print("正在构建正则表达式模式...")

        # BUGFIX: with an empty keyword set the old code compiled r'\b()\b',
        # which matches the empty string at every word boundary and reported
        # bogus '' hits. Leave the pattern unset so matching returns no hits.
        if not keywords:
            self.pattern = None
            print(f"正则模式构建完成,共 {len(keywords)} 个关键词")
            return

        # Escape metacharacters so keywords match literally; \b keeps
        # partial-word hits out.
        escaped_keywords = [re.escape(kw) for kw in keywords]
        pattern_str = r'\b(' + '|'.join(escaped_keywords) + r')\b'
        self.pattern = re.compile(pattern_str)

        print(f"正则模式构建完成,共 {len(keywords)} 个关键词")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return the set of keywords found as whole words in *text*."""
        if not self.pattern:
            return set()

        # findall returns the captured group, i.e. the matched keyword itself.
        return set(self.pattern.findall(text))
|
||||
|
||||
|
||||
# ========== 匹配器工厂 ==========
|
||||
def create_matcher(algorithm: str, fuzzy_threshold: int = DEFAULT_FUZZY_THRESHOLD, mode: str = None) -> KeywordMatcher:
    """
    Factory: pick the matcher for the given algorithm and detection mode.

    Args:
        algorithm: Matching algorithm (auto, set, exact).
        fuzzy_threshold: Deprecated; kept only for backward compatibility.
        mode: Detection mode (cas, exact); overrides *algorithm* when set.

    Raises:
        ValueError: Unknown algorithm name.
    """
    mode_key = mode.lower() if mode else None

    # Mode takes precedence over the generic algorithm choice.
    if mode_key == "cas":
        return CASRegexMatcher()
    if mode_key == "exact":
        return RegexExactMatcher()

    algo = algorithm.lower()
    if algo == "exact":
        return RegexExactMatcher()
    if algo == "set":
        return SetMatcher()
    if algo == "auto":
        # Prefer the automaton when available; otherwise fall back.
        if HAS_AC:
            return AhoCorasickMatcher()
        print("警告: 未安装 pyahocorasick,使用标准匹配方法")
        return SetMatcher()

    raise ValueError(f"不支持的算法: {algorithm}")
|
||||
|
||||
|
||||
# ========== 结果处理 ==========
|
||||
def save_results(
|
||||
df: pd.DataFrame,
|
||||
result: MatchResult,
|
||||
output_file: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""保存匹配结果到 Excel"""
|
||||
if result.match_count == 0:
|
||||
print("未找到任何匹配")
|
||||
return None
|
||||
|
||||
result_df = df.iloc[result.matched_indices].copy()
|
||||
result_df.insert(0, "匹配到的关键词", result.matched_keywords)
|
||||
|
||||
result_df.to_excel(output_file, index=False, engine='openpyxl')
|
||||
print(f"结果已保存到: {output_file}")
|
||||
return result_df
|
||||
|
||||
|
||||
def print_statistics(result: MatchResult) -> None:
|
||||
"""打印匹配统计信息"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"匹配完成!")
|
||||
print(f"匹配模式: {result.matcher_name}")
|
||||
print(f"总耗时: {result.elapsed_time:.2f} 秒")
|
||||
print(f"处理速度: {result.speed:.1f} 行/秒")
|
||||
print(f"匹配到 {result.match_count} 行数据 ({result.match_rate:.2f}%)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def preview_results(result_df: pd.DataFrame, num_rows: int = 5) -> None:
    """Print the first *num_rows* matched rows; no-op on an empty frame."""
    if result_df.empty:
        return

    rule = "=" * 80
    print(f"前{num_rows}行匹配结果预览:")
    print(rule)
    # Widen pandas display so wide match rows stay readable.
    for option, value in (
        ('display.max_columns', None),
        ('display.width', None),
        ('display.max_colwidth', 100),
    ):
        pd.set_option(option, value)
    print(result_df.head(num_rows))
    print(rule)
    print(f"\n✓ 总共匹配到 {len(result_df)} 行数据")
|
||||
|
||||
|
||||
# ========== 主流程 ==========
|
||||
def perform_matching(
    df: pd.DataFrame,
    keywords: Set[str],
    text_column: str,
    output_file: str,
    algorithm: str = "auto",
    mode: str = None
) -> Optional[pd.DataFrame]:
    """Run the complete match pipeline: validate, match, report, save.

    Returns the saved result frame, or None when nothing matched.

    Raises:
        ValueError: *text_column* is missing from *df*.
    """
    # Fail fast when the text column is absent, listing what IS available.
    if text_column not in df.columns:
        print(f"可用列名: {df.columns.tolist()}")
        raise ValueError(f"列 '{text_column}' 不存在")

    print(f"文本文件共有 {len(df)} 行数据\n")

    # Pick the strategy and run it.
    matcher = create_matcher(algorithm, mode=mode)
    result = matcher.match(df, keywords, text_column)

    print_statistics(result)

    return save_results(df, result, output_file)
|
||||
|
||||
|
||||
def process_single_mode(
    keywords_df: pd.DataFrame,
    text_df: pd.DataFrame,
    mode: str,
    text_column: str,
    output_file: Path,
    separator: str = SEPARATOR,
    save_to_file: bool = True
) -> Optional[pd.DataFrame]:
    """
    Run one detection mode end to end.

    Returns:
        The matched rows (original index preserved) with a '匹配模式' column
        appended, or None when nothing matched.
    """
    mode_key = mode.lower()
    label = MODE_LABELS.get(mode_key, mode_key)

    banner = "=" * 60
    print(f"\n{banner}")
    print(f">>> 正在执行识别模式: {label}")
    print(f"{banner}")

    keywords = load_keywords_for_mode(keywords_df, mode_key, separator=separator)

    # Show a short sample of the loaded keywords for sanity checking.
    sample = list(keywords)[:10]
    if sample:
        print(f"\n{label} - 关键词样例(前10个):")
        for position, keyword in enumerate(sample, 1):
            print(f"  {position}. {keyword}")
        print()

    # The exact mode forces the word-boundary matcher; everything else is auto.
    algorithm = "exact" if mode_key == "exact" else "auto"

    # When no per-mode file is wanted, write to a scratch path and delete it.
    target_path = str(output_file) if save_to_file else "/tmp/temp_match.xlsx"
    result_df = perform_matching(
        df=text_df,
        keywords=keywords,
        text_column=text_column,
        output_file=target_path,
        algorithm=algorithm,
        mode=mode_key
    )

    if not save_to_file and result_df is not None:
        import os
        if os.path.exists(target_path):
            os.remove(target_path)

    if result_df is not None and save_to_file:
        preview_results(result_df)

    # Tag every matched row with the mode that found it.
    if result_df is not None:
        result_df['匹配模式'] = label

    return result_df
|
||||
|
||||
|
||||
def run_multiple_modes(
    keywords_file: Path,
    text_file: Path,
    output_file: Path,
    text_column: str,
    modes: List[str],
    separator: str = SEPARATOR
) -> None:
    """Run every requested detection mode and merge the hits into one file.

    Raises:
        FileNotFoundError: A required input file is missing.
        ValueError: *modes* is empty or contains an unknown mode.
    """
    # Fail fast on missing inputs.
    if not keywords_file.exists():
        raise FileNotFoundError(f"找不到关键词文件: {keywords_file}")
    if not text_file.exists():
        raise FileNotFoundError(f"找不到文本文件: {text_file}")

    print(f"正在加载关键词文件: {keywords_file}")
    keywords_df = pd.read_excel(keywords_file)
    print(f"可用列: {keywords_df.columns.tolist()}\n")

    print(f"正在加载文本文件: {text_file}")
    text_df = pd.read_excel(text_file)
    print(f"文本列: {text_column}\n")

    # Validate the mode list before doing any work.
    if not modes:
        raise ValueError("modes 不能为空,请至少指定一个模式")
    for requested in modes:
        if requested.lower() not in MODE_KEYWORD_COLUMNS:
            raise ValueError(f"不支持的识别模式: {requested}")

    # Run each mode in turn, keeping only non-empty result frames.
    collected = []
    for requested in modes:
        frame = process_single_mode(
            keywords_df=keywords_df,
            text_df=text_df,
            mode=requested.lower(),
            text_column=text_column,
            output_file=output_file,  # unused while save_to_file=False
            separator=separator,
            save_to_file=False
        )
        if frame is not None and not frame.empty:
            collected.append(frame)

    if not collected:
        print("\n所有模式均未匹配到数据")
        return

    banner = "=" * 60
    print(f"\n{banner}")
    print("正在合并所有模式的匹配结果...")
    print(f"{banner}")

    merged_df = merge_mode_results(collected, text_df)

    merged_df.to_excel(output_file, index=False, engine='openpyxl')
    print(f"\n合并结果已保存到: {output_file}")
    print(f"  总匹配行数: {len(merged_df)} 行")

    # Per-mode contribution summary.
    print(f"\n各模式匹配统计:")
    for frame in collected:
        print(f"  {frame['匹配模式'].iloc[0]:20s}: {len(frame):4d} 行")

    print(f"\n{banner}")
    print("合并结果预览(前5行):")
    print(f"{banner}")
    preview_results(merged_df, num_rows=5)
|
||||
|
||||
|
||||
def merge_mode_results(
    results: List[pd.DataFrame],
    original_df: pd.DataFrame
) -> pd.DataFrame:
    """
    Merge the result frames of several modes into one DataFrame.

    Strategy:
    1. group by the original row index,
    2. for rows matched by more than one mode, union the keywords and
       concatenate the mode labels,
    3. keep all original columns, sorted by original index.
    """
    if not results:
        return pd.DataFrame()

    # index -> {'keywords': str, 'modes': [str, ...]}
    merged = {}
    for frame in results:
        for row_idx, row in frame.iterrows():
            entry = merged.get(row_idx)
            if entry is None:
                merged[row_idx] = {
                    'keywords': row['匹配到的关键词'],
                    'modes': [row['匹配模式']]
                }
                continue

            # Row already matched by another mode: union keywords, add label.
            # NOTE(review): the ' | ' split/join assumes MATCH_RESULT_SEPARATOR
            # is ' | ' — confirm the constant matches.
            combined = (
                set(str(entry['keywords']).split(' | '))
                | set(str(row['匹配到的关键词']).split(' | '))
            )
            entry['keywords'] = ' | '.join(sorted(combined))
            entry['modes'].append(row['匹配模式'])

    indices = list(merged.keys())
    final_df = original_df.loc[indices].copy()

    final_df.insert(0, '匹配到的关键词', [merged[i]['keywords'] for i in indices])
    final_df.insert(1, '匹配模式', [' + '.join(merged[i]['modes']) for i in indices])

    return final_df.sort_index()
|
||||
|
||||
|
||||
# ========== 命令行接口 ==========
|
||||
def parse_args():
    """Build and evaluate the command-line interface.

    Returns the parsed argparse namespace; every option falls back to the
    default documented in its help string.
    """
    parser = argparse.ArgumentParser(
        description='多模式关键词匹配工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 使用默认配置(两种模式)
  python keyword_matcher.py

  # 仅执行 CAS 号识别
  python keyword_matcher.py -m cas

  # 仅执行精确匹配
  python keyword_matcher.py -m exact

  # 指定自定义文件路径
  python keyword_matcher.py -k ../data/input/keywords.xlsx -t ../data/input/text.xlsx
"""
    )

    add = parser.add_argument
    add('-k', '--keywords', type=str,
        help='关键词文件路径 (默认: ../data/input/keywords.xlsx)')
    add('-t', '--text', type=str,
        help='文本文件路径 (默认: ../data/input/clickin_text_img.xlsx)')
    add('-o', '--output', type=str,
        help='输出文件路径 (默认: ../data/output/keyword_matched_results.xlsx)')
    add('-c', '--text-column', type=str, default='文本',
        help='文本列名 (默认: 文本)')
    add('-m', '--modes', nargs='+', choices=['cas', 'exact'],
        default=['cas', 'exact'],
        help='识别模式 (默认: cas exact)')
    add('--separator', type=str, default=SEPARATOR,
        help=f'关键词分隔符 (默认: {SEPARATOR})')

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: resolve paths, report dependency status, run all modes."""
    args = parse_args()

    # Default data layout: <repo>/data/{input,output} next to this script's dir.
    script_dir = Path(__file__).resolve().parent
    data_root = script_dir.parent / "data"

    keywords_file = Path(args.keywords) if args.keywords else data_root / "input" / "keywords.xlsx"
    text_file = Path(args.text) if args.text else data_root / "input" / "clickin_text_img.xlsx"
    output_file = Path(args.output) if args.output else data_root / "output" / "keyword_matched_results.xlsx"

    # Report optional-dependency status up front.
    rule = "=" * 60
    print(rule)
    print("依赖库状态:")
    print(f"  pyahocorasick: {'已安装 ✓' if HAS_AC else '未安装 ✗'}")
    print(rule)
    print()

    if not HAS_AC:
        print("提示: 安装 pyahocorasick 可获得 5x 性能提升: pip install pyahocorasick\n")

    try:
        run_multiple_modes(
            keywords_file=keywords_file,
            text_file=text_file,
            output_file=output_file,
            text_column=args.text_column,
            modes=args.modes,
            separator=args.separator
        )
    except Exception as exc:
        print(f"\n错误: {exc}")
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n" + rule)
        print("✓ 所有模式处理完成!")
        print(rule)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
154
scripts/run.sh
Normal file
154
scripts/run.sh
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/bin/bash
# Batch image analysis management script - minimal edition.
# Thin wrapper around batch_run.sh: start/stop it in the background and
# inspect its log.

BASE_DIR="$HOME/bin/p_anlz"
BATCH_SCRIPT="$BASE_DIR/scripts/batch_run.sh"
LOG_DIR="$BASE_DIR/logs"
MAIN_LOG="$LOG_DIR/batch_run.log"

# Print usage help (the heredoc text is user-facing and stays verbatim).
show_help() {
    cat << EOF
批量图片分析管理工具

用法: $0 <命令>

命令:
  start    后台启动任务
  stop     停止任务
  status   查看运行状态
  log      查看实时日志
  help     显示帮助

示例:
  $0 start    # 启动任务
  $0 log      # 查看日志
  $0 status   # 查看状态

修改参数:
  API_TYPE=openai $0 start
  MAX_WORKERS=3 $0 start
  TIMEOUT=120 $0 start
EOF
}

# Launch the batch job in the background via nohup; refuses to start twice.
start_task() {
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务已在运行中"
        exit 1
    fi

    echo "启动批量分析任务..."

    # Export runtime knobs so the batch script inherits them.
    export API_TYPE="${API_TYPE:-dify}"
    export MAX_WORKERS="${MAX_WORKERS:-1}"
    export TIMEOUT="${TIMEOUT:-90}"

    # Detach from the terminal; the batch script handles its own logging.
    nohup bash "$BATCH_SCRIPT" > /dev/null 2>&1 &

    sleep 2

    # Confirm the process actually survived startup.
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已启动"
        echo ""
        echo "查看日志: $0 log"
        echo "查看状态: $0 status"
    else
        echo "✗ 启动失败"
        exit 1
    fi
}

# Stop the batch job: SIGTERM first, SIGKILL as a fallback.
stop_task() {
    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务未运行"
        exit 0
    fi

    echo "停止任务..."
    pkill -f "batch_run.sh"
    pkill -f "image_batch_recognizer.py"

    sleep 1

    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已停止"
    else
        echo "✗ 停止失败,尝试强制终止..."
        pkill -9 -f "batch_run.sh"
        pkill -9 -f "image_batch_recognizer.py"
    fi
}

# Report whether the job is running and show recent log output.
show_status() {
    echo "任务状态:"
    echo ""

    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 运行中"
        echo ""
        echo "进程信息:"
        ps aux | grep -E "batch_run.sh|image_batch_recognizer.py" | grep -v grep
        echo ""

        # Current progress: tail of the main log.
        if [ -f "$MAIN_LOG" ]; then
            echo "最新日志:"
            tail -10 "$MAIN_LOG"
        fi
    else
        echo "○ 未运行"

        if [ -f "$MAIN_LOG" ]; then
            echo ""
            echo "上次运行结果:"
            tail -15 "$MAIN_LOG" | grep -A 10 "批量分析完成"
        fi
    fi
}

# Follow the main log in real time.
show_log() {
    if [ ! -f "$MAIN_LOG" ]; then
        echo "日志文件不存在: $MAIN_LOG"
        exit 1
    fi

    echo "实时查看日志 (Ctrl+C 退出):"
    echo ""
    tail -f "$MAIN_LOG"
}

# Dispatch on the first argument (defaults to help).
main() {
    case "${1:-help}" in
        start)
            start_task
            ;;
        stop)
            stop_task
            ;;
        status)
            show_status
            ;;
        log)
            show_log
            ;;
        help|--help|-h)
            show_help
            ;;
        *)
            echo "未知命令: $1"
            echo ""
            show_help
            exit 1
            ;;
    esac
}

main "$@"
|
||||
349
scripts/run_batch_background.sh
Normal file
349
scripts/run_batch_background.sh
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/bin/bash
# Wrapper that launches the batch analysis job in the background.
# Supports three launch styles: nohup, screen, and tmux.

set -e

BASE_DIR="$HOME/bin/p_anlz"
SCRIPT_DIR="$BASE_DIR/scripts"
BATCH_SCRIPT="$SCRIPT_DIR/batch_analyze_all_folders.sh"
LOG_DIR="$BASE_DIR/logs"
NOHUP_LOG="$LOG_DIR/nohup.out"

# ANSI color codes for terminal output.
# BUGFIX: RED was referenced by check_script() but never defined, so its
# error message printed uncolored with a literal empty expansion.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'
|
||||
|
||||
# 显示使用说明
|
||||
show_usage() {
|
||||
cat << EOF
|
||||
批量图片分析后台运行工具
|
||||
|
||||
用法:
|
||||
$0 [方式]
|
||||
|
||||
方式选项:
|
||||
nohup 使用 nohup 后台运行(默认,最简单)
|
||||
screen 使用 screen 会话运行(可重新连接)
|
||||
tmux 使用 tmux 会话运行(推荐,功能最强)
|
||||
status 查看当前运行状态
|
||||
logs 实时查看主日志
|
||||
progress 查看处理进度
|
||||
live 实时查看当前 Python 程序输出(推荐)
|
||||
|
||||
示例:
|
||||
$0 nohup # 使用 nohup 启动
|
||||
$0 screen # 使用 screen 启动
|
||||
$0 tmux # 使用 tmux 启动
|
||||
$0 status # 查看状态
|
||||
$0 logs # 查看主日志
|
||||
$0 progress # 查看进度
|
||||
$0 live # 实时查看 Python 输出(含进度条)
|
||||
|
||||
后台运行后如何查看:
|
||||
- nohup 方式: tail -f $NOHUP_LOG
|
||||
- screen 方式: screen -r batch_analyze
|
||||
- tmux 方式: tmux attach -t batch_analyze
|
||||
- 实时输出: $0 live
|
||||
EOF
|
||||
}
|
||||
|
||||
# 检查脚本是否存在
|
||||
check_script() {
|
||||
if [ ! -f "$BATCH_SCRIPT" ]; then
|
||||
echo -e "${RED}错误: 批处理脚本不存在: $BATCH_SCRIPT${NC}"
|
||||
exit 1
|
||||
fi
|
||||
chmod +x "$BATCH_SCRIPT"
|
||||
}
|
||||
|
||||
# 检查是否已经在运行
|
||||
check_running() {
|
||||
if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
|
||||
echo -e "${YELLOW}警告: 批量分析任务已在运行中${NC}"
|
||||
echo "进程信息:"
|
||||
ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
|
||||
echo ""
|
||||
read -p "是否继续启动新任务? (y/N): " confirm
|
||||
if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then
|
||||
echo "已取消"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# 使用 nohup 启动
|
||||
start_nohup() {
|
||||
echo -e "${BLUE}使用 nohup 启动批量分析任务...${NC}"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
nohup bash "$BATCH_SCRIPT" > "$NOHUP_LOG" 2>&1 &
|
||||
local pid=$!
|
||||
|
||||
echo -e "${GREEN}✓ 任务已在后台启动 (PID: $pid)${NC}"
|
||||
echo ""
|
||||
echo "查看日志:"
|
||||
echo " tail -f $NOHUP_LOG"
|
||||
echo ""
|
||||
echo "查看进度:"
|
||||
echo " $0 progress"
|
||||
echo ""
|
||||
echo "停止任务:"
|
||||
echo " kill $pid"
|
||||
}
|
||||
|
||||
# 使用 screen 启动
|
||||
start_screen() {
|
||||
if ! command -v screen &> /dev/null; then
|
||||
echo -e "${YELLOW}screen 未安装,请先安装: sudo apt-get install screen${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${BLUE}使用 screen 启动批量分析任务...${NC}"
|
||||
|
||||
# 检查是否已有同名 session
|
||||
if screen -list | grep -q "batch_analyze"; then
|
||||
echo -e "${YELLOW}警告: screen 会话 'batch_analyze' 已存在${NC}"
|
||||
read -p "是否删除旧会话并创建新会话? (y/N): " confirm
|
||||
if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
|
||||
screen -S batch_analyze -X quit 2>/dev/null || true
|
||||
else
|
||||
echo "已取消"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
screen -dmS batch_analyze bash "$BATCH_SCRIPT"
|
||||
|
||||
echo -e "${GREEN}✓ 任务已在 screen 会话中启动${NC}"
|
||||
echo ""
|
||||
echo "重新连接到会话:"
|
||||
echo " screen -r batch_analyze"
|
||||
echo ""
|
||||
echo "分离会话: Ctrl+A, 然后按 D"
|
||||
echo ""
|
||||
echo "查看所有会话:"
|
||||
echo " screen -ls"
|
||||
echo ""
|
||||
echo "终止会话:"
|
||||
echo " screen -S batch_analyze -X quit"
|
||||
}
|
||||
|
||||
# 使用 tmux 启动
|
||||
start_tmux() {
|
||||
if ! command -v tmux &> /dev/null; then
|
||||
echo -e "${YELLOW}tmux 未安装,请先安装: sudo apt-get install tmux${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${BLUE}使用 tmux 启动批量分析任务...${NC}"
|
||||
|
||||
# 检查是否已有同名 session
|
||||
if tmux has-session -t batch_analyze 2>/dev/null; then
|
||||
echo -e "${YELLOW}警告: tmux 会话 'batch_analyze' 已存在${NC}"
|
||||
read -p "是否删除旧会话并创建新会话? (y/N): " confirm
|
||||
if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
|
||||
tmux kill-session -t batch_analyze 2>/dev/null || true
|
||||
else
|
||||
echo "已取消"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
tmux new-session -d -s batch_analyze "bash $BATCH_SCRIPT"
|
||||
|
||||
echo -e "${GREEN}✓ 任务已在 tmux 会话中启动${NC}"
|
||||
echo ""
|
||||
echo "重新连接到会话:"
|
||||
echo " tmux attach -t batch_analyze"
|
||||
echo ""
|
||||
echo "分离会话: Ctrl+B, 然后按 D"
|
||||
echo ""
|
||||
echo "查看所有会话:"
|
||||
echo " tmux ls"
|
||||
echo ""
|
||||
echo "终止会话:"
|
||||
echo " tmux kill-session -t batch_analyze"
|
||||
}
|
||||
|
||||
# 查看运行状态
|
||||
show_status() {
|
||||
echo -e "${BLUE}批量分析任务状态:${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查进程
|
||||
if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
|
||||
echo -e "${GREEN}✓ 任务正在运行${NC}"
|
||||
echo ""
|
||||
echo "进程信息:"
|
||||
ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
|
||||
else
|
||||
echo -e "${YELLOW}○ 任务未运行${NC}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# 检查 screen 会话
|
||||
if command -v screen &> /dev/null; then
|
||||
if screen -list | grep -q "batch_analyze"; then
|
||||
echo -e "${GREEN}✓ Screen 会话存在${NC}"
|
||||
screen -list | grep "batch_analyze"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# 检查 tmux 会话
|
||||
if command -v tmux &> /dev/null; then
|
||||
if tmux has-session -t batch_analyze 2>/dev/null; then
|
||||
echo -e "${GREEN}✓ Tmux 会话存在${NC}"
|
||||
tmux list-sessions | grep "batch_analyze"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# 查看日志
|
||||
show_logs() {
|
||||
local main_log="$LOG_DIR/batch_analyze.log"
|
||||
|
||||
if [ -f "$main_log" ]; then
|
||||
echo -e "${BLUE}实时查看主日志 (Ctrl+C 退出):${NC}"
|
||||
echo ""
|
||||
tail -f "$main_log"
|
||||
elif [ -f "$NOHUP_LOG" ]; then
|
||||
echo -e "${BLUE}实时查看 nohup 日志 (Ctrl+C 退出):${NC}"
|
||||
echo ""
|
||||
tail -f "$NOHUP_LOG"
|
||||
else
|
||||
echo -e "${YELLOW}未找到日志文件${NC}"
|
||||
echo "日志目录: $LOG_DIR"
|
||||
fi
|
||||
}
|
||||
|
||||
# 查看进度
|
||||
show_progress() {
|
||||
local progress_file="$LOG_DIR/progress.txt"
|
||||
|
||||
if [ ! -f "$progress_file" ]; then
|
||||
echo -e "${YELLOW}进度文件不存在: $progress_file${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${BLUE}批量分析进度:${NC}"
|
||||
echo ""
|
||||
|
||||
# 显示进度文件内容
|
||||
cat "$progress_file"
|
||||
|
||||
echo ""
|
||||
echo -e "${BLUE}统计信息:${NC}"
|
||||
|
||||
# 统计各状态数量
|
||||
local total=$(grep -c "SUCCESS\|FAILED\|SKIPPED" "$progress_file" || echo "0")
|
||||
local success=$(grep -c "SUCCESS" "$progress_file" || echo "0")
|
||||
local failed=$(grep -c "FAILED" "$progress_file" || echo "0")
|
||||
local skipped=$(grep -c "SKIPPED" "$progress_file" || echo "0")
|
||||
|
||||
echo "总计: $total"
|
||||
echo -e "${GREEN}成功: $success${NC}"
|
||||
echo -e "${YELLOW}失败: $failed${NC}"
|
||||
echo -e "${BLUE}跳过: $skipped${NC}"
|
||||
|
||||
# 显示最后更新时间
|
||||
if [ -f "$progress_file" ]; then
|
||||
local last_update=$(stat -c %y "$progress_file" 2>/dev/null || stat -f "%Sm" "$progress_file" 2>/dev/null)
|
||||
echo ""
|
||||
echo "最后更新: $last_update"
|
||||
fi
|
||||
}
|
||||
|
||||
# 实时查看当前 Python 程序输出
|
||||
show_live() {
|
||||
local main_log="$LOG_DIR/batch_analyze.log"
|
||||
|
||||
echo -e "${BLUE}实时查看 Python 程序输出 (Ctrl+C 退出)${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查任务是否在运行
|
||||
if ! pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
|
||||
echo -e "${YELLOW}任务未运行${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 查找当前正在处理的文件夹
|
||||
if [ -f "$main_log" ]; then
|
||||
local current_folder=$(grep -oP "处理文件夹: \K.*" "$main_log" | tail -1)
|
||||
if [ -n "$current_folder" ]; then
|
||||
local folder_log="$LOG_DIR/${current_folder}.log"
|
||||
|
||||
if [ -f "$folder_log" ]; then
|
||||
echo -e "${GREEN}当前处理: $current_folder${NC}"
|
||||
echo -e "${BLUE}日志文件: $folder_log${NC}"
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
# 实时跟踪文件夹日志(显示 Python 程序的完整输出)
|
||||
tail -f "$folder_log"
|
||||
else
|
||||
echo -e "${YELLOW}等待日志文件生成...${NC}"
|
||||
sleep 2
|
||||
show_live
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW}等待任务开始...${NC}"
|
||||
sleep 2
|
||||
show_live
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW}主日志文件不存在${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 主函数
|
||||
main() {
|
||||
local mode="${1:-nohup}"
|
||||
|
||||
case "$mode" in
|
||||
nohup)
|
||||
check_script
|
||||
check_running
|
||||
start_nohup
|
||||
;;
|
||||
screen)
|
||||
check_script
|
||||
check_running
|
||||
start_screen
|
||||
;;
|
||||
tmux)
|
||||
check_script
|
||||
check_running
|
||||
start_tmux
|
||||
;;
|
||||
status)
|
||||
show_status
|
||||
;;
|
||||
logs)
|
||||
show_logs
|
||||
;;
|
||||
progress)
|
||||
show_progress
|
||||
;;
|
||||
live)
|
||||
show_live
|
||||
;;
|
||||
-h|--help|help)
|
||||
show_usage
|
||||
;;
|
||||
*)
|
||||
echo "未知选项: $mode"
|
||||
echo ""
|
||||
show_usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
main "$@"
|
||||
Reference in New Issue
Block a user