init: first upload

This commit is contained in:
2026-01-04 09:07:25 +08:00
commit 29f6e25f70
9 changed files with 2598 additions and 0 deletions

View File

@@ -0,0 +1,776 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量图片识别脚本
功能:
1. 遍历指定目录下的图片文件
2. 将图片上传至大模型 API(OpenAI、Anthropic)或本地模拟模式
3. 输出包含图片中文本、物品描述、风险概述的 Excel 报表
使用示例:
cd scripts
python3 image_batch_recognizer.py --limit 5 --api-type openai
"""
from __future__ import annotations
import argparse
import base64
import json
import mimetypes
import os
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Sequence, Tuple
import pandas as pd
try:
import requests # type: ignore
# 禁用 SSL 警告(用于自签名证书)
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
except ImportError: # pragma: no cover - requests 是可选依赖
requests = None
try:
    from tqdm import tqdm  # type: ignore
except ImportError:
    # Minimal drop-in replacement so the script still runs without tqdm.
    class tqdm:
        def __init__(self, iterable=None, total=None, desc=None, **kwargs):
            self.iterable = iterable
            self.total = total or (len(iterable) if iterable else 0)
            self.desc = desc
            self.n = 0
        def __iter__(self):
            for element in self.iterable:
                yield element
                self.n += 1
        def update(self, n=1):
            self.n += n
        def set_postfix_str(self, s):
            # Fallback bar renders nothing, so postfix text is ignored.
            pass
        def close(self):
            pass
SEPARATOR = "|||"  # joins list fields into a single Excel cell
DEFAULT_SLEEP_SECONDS = 0.0  # per-request throttle, serial mode only
DEFAULT_TIMEOUT = 90  # HTTP request timeout, seconds
DEFAULT_MAX_TOKENS = 800  # completion budget for OpenAI/Anthropic requests
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}  # image types scanned
# Default analysis prompt sent to the model (runtime string — not localised).
DEFAULT_ANALYSIS_PROMPT = """
你是一名涉毒风控场景下的多模态分析师。请识别我提供的图片内容,并输出 JSON。
JSON 字段要求:
{
"detected_text": ["按行列出的文字内容,若无文字返回空数组"],
"detected_objects": ["出现的主要物品或场景要素,中文描述"],
"sensitive_items": ["可疑化学品、药物、工具等,如无则为空数组"],
"summary": "2-3 句话概括图片与潜在风险,中文",
"confidence": "High | Medium | Low"
}
只返回 JSON不要额外说明。
""".strip()
# ----------- 数据结构 -----------
@dataclass
class RecognitionResult:
    """Outcome of analysing a single image, plus call metadata."""
    image_name: str
    image_path: str
    detected_text: List[str] = field(default_factory=list)
    detected_objects: List[str] = field(default_factory=list)
    sensitive_items: List[str] = field(default_factory=list)
    summary: str = ""
    confidence: str = ""
    raw_response: str = ""
    api_type: str = ""
    model_name: str = ""
    latency_seconds: float = 0.0
    analyzed_at: str = ""
    error: Optional[str] = None

    def to_row(self) -> Dict[str, Any]:
        """Flatten this result into one Excel row (lists joined to text)."""
        row: Dict[str, Any] = {}
        row["image_name"] = self.image_name
        row["image_path"] = self.image_path
        row["detected_text"] = "\n".join(self.detected_text).strip()
        row["detected_objects"] = SEPARATOR.join(self.detected_objects)
        row["sensitive_items"] = SEPARATOR.join(self.sensitive_items)
        row["summary"] = self.summary
        row["confidence"] = self.confidence
        row["latency_seconds"] = round(self.latency_seconds, 2)
        row["api_type"] = self.api_type
        row["model"] = self.model_name
        row["analyzed_at"] = self.analyzed_at
        row["error"] = self.error or ""
        row["raw_response"] = self.raw_response
        return row

    @property
    def is_success(self) -> bool:
        """True when neither a transport nor a parse error was recorded."""
        return self.error is None
# ----------- 工具函数 -----------
def encode_image(image_path: Path) -> Tuple[str, str]:
    """Read an image as base64 text and guess its MIME type.

    Falls back to ``image/jpeg`` when the type cannot be guessed.
    """
    raw_bytes = image_path.read_bytes()
    b64_text = base64.b64encode(raw_bytes).decode("utf-8")
    guessed_mime, _ = mimetypes.guess_type(image_path.name)
    return b64_text, guessed_mime or "image/jpeg"
def extract_json_from_response(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Pull a JSON object out of a model reply.

    Tries to parse the stripped text directly; on failure retries on the
    substring between the first '{' and the last '}'.  Returns a
    ``(parsed, error)`` pair where exactly one side is None.
    """
    if not text:
        return None, "空响应"
    stripped = text.strip()
    try:
        return json.loads(stripped), None
    except json.JSONDecodeError:
        pass
    first = stripped.find("{")
    last = stripped.rfind("}")
    if first == -1 or last <= first:
        return None, "响应中未找到 JSON 对象"
    try:
        return json.loads(stripped[first:last + 1]), None
    except json.JSONDecodeError as exc:
        return None, f"JSON 解析失败: {exc}"
def normalize_list(value: Any) -> List[str]:
    """Coerce a model-returned field into a clean list of stripped strings."""
    if value is None:
        return []
    if isinstance(value, str):
        # A string may pack several values separated by SEPARATOR.
        tokens = value.split(SEPARATOR)
        return [tok.strip() for tok in tokens if tok.strip()]
    if isinstance(value, (list, tuple, set)):
        return [str(entry).strip() for entry in value if str(entry).strip()]
    return [str(value).strip()]
def http_post_json(url: str, headers: Dict[str, str], payload: Dict[str, Any], timeout: int) -> Tuple[int, str]:
    """POST a JSON payload and return ``(status_code, body_text)``.

    Uses the optional ``requests`` dependency when available; otherwise
    falls back to ``urllib``.  In the ``requests`` path SSL certificate
    verification is disabled (self-signed certificates in the target env).
    """
    body = json.dumps(payload)
    headers = {**headers, "Content-Type": "application/json"}
    if requests:
        # verify=False: skip SSL certificate validation (self-signed certs).
        response = requests.post(url, headers=headers, data=body, timeout=timeout, verify=False)
        return response.status_code, response.text
    # requests unavailable -> use the standard library instead.
    import urllib.error
    import urllib.request
    req = urllib.request.Request(url, data=body.encode("utf-8"), headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            response_body = resp.read().decode("utf-8")
            return resp.getcode(), response_body
    except urllib.error.HTTPError as exc:  # pragma: no cover - network error path
        # An HTTP error still carries a status and body worth returning.
        return exc.code, exc.read().decode("utf-8")
def http_post_multipart(url: str, headers: Dict[str, str],
                        data: Dict[str, str], files: Dict[str, Any],
                        timeout: int) -> Tuple[int, str]:
    """POST multipart/form-data (used by the Dify file-upload endpoint).

    Raises:
        RuntimeError: when the optional ``requests`` dependency is missing.
    """
    if requests is None:
        raise RuntimeError(
            "Dify 文件上传需要 requests 库,请安装: pip install requests"
        )
    # verify=False: the target server may use a self-signed certificate.
    resp = requests.post(url, headers=headers, data=data,
                         files=files, timeout=timeout, verify=False)
    return resp.status_code, resp.text
def build_endpoint(base_url: Optional[str], path: str, default_base: str) -> str:
    """Join a base URL and a path into a full endpoint, normalising slashes."""
    root = (base_url or default_base).rstrip("/")
    suffix = path if path.startswith("/") else f"/{path}"
    return root + suffix
def load_env_file(env_file: Optional[str]) -> Optional[Path]:
    """Load KEY=VALUE pairs from a .env file into ``os.environ``.

    Skips blank lines and ``#`` comments, tolerates a leading ``export ``
    and strips surrounding single/double quotes from values.

    Returns:
        The resolved path on success, or None when no file was given or
        the file does not exist.
    """
    if not env_file:
        return None
    env_path = Path(env_file).expanduser()
    if not env_path.exists():
        return None
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        if entry.startswith("export "):
            entry = entry[len("export "):].strip()
        if "=" not in entry:
            continue
        name, _, raw_value = entry.partition("=")
        name = name.strip()
        cleaned = raw_value.strip().strip('"').strip("'")
        if not name:
            continue
        os.environ[name] = cleaned
    print(f"已从 {env_path} 加载环境变量")
    return env_path
# ----------- 基础客户端 -----------
class BaseVisionClient:
    """Template for vision clients.

    ``analyze_image`` encodes the image, delegates the HTTP call to the
    subclass hook ``_send_request`` and parses the reply into a
    ``RecognitionResult``.  It never raises: failures are reported via the
    result's ``error`` field.
    """

    def __init__(self, model: str, api_type: str, timeout: int = DEFAULT_TIMEOUT):
        self.model = model
        self.api_type = api_type
        self.timeout = timeout

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Run the full analysis pipeline for one image."""
        encoded, mime_type = encode_image(image_path)
        started = time.time()
        analyzed_at = datetime.now().isoformat()
        try:
            raw_text = self._send_request(
                image_base64=encoded,
                mime_type=mime_type,
                prompt=prompt,
                image_name=image_path.name,
            )
        except Exception as exc:  # pragma: no cover - network failures
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=time.time() - started,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )
        latency = time.time() - started
        parsed, parse_error = extract_json_from_response(raw_text)
        data = parsed or {}
        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=normalize_list(data.get("detected_text")),
            detected_objects=normalize_list(data.get("detected_objects")),
            sensitive_items=normalize_list(data.get("sensitive_items")),
            summary=str(data.get("summary", "")),
            confidence=str(data.get("confidence", "")),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Provider-specific API call; subclasses return the raw reply text."""
        raise NotImplementedError
# ----------- OpenAI 客户端 -----------
class OpenAIVisionClient(BaseVisionClient):
    """Vision client for the OpenAI chat-completions API."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="openai", timeout=timeout)
        if not api_key:
            raise ValueError("OPENAI_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://api.openai.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Send one chat-completion request with an inline data-URL image."""
        data_url = f"data:{mime_type};base64,{image_base64}"
        user_content = [
            {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
        payload = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a multimodal analyst that returns structured JSON responses in Chinese.",
                },
                {"role": "user", "content": user_content},
            ],
        }
        auth_headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, auth_headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"OpenAI API 调用失败 ({status_code}): {response_text}")
        message_content = json.loads(response_text)["choices"][0]["message"]["content"]
        if isinstance(message_content, list):
            # Some gateways return content as a list of typed parts.
            pieces = [part.get("text", "") for part in message_content if isinstance(part, dict)]
            return "\n".join(pieces).strip()
        return message_content.strip()
# ----------- DMX (文档上传方式) 客户端 -----------
class DMXVisionClient(BaseVisionClient):
    """Vision client for the DMX gateway (OpenAI-compatible chat endpoint)."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="dmx", timeout=timeout)
        if not api_key:
            raise ValueError("DMX_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://www.dmxapi.cn")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Send one chat request with an inline data-URL image."""
        data_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0.1,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
        }
        auth_headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, auth_headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"DMX API 调用失败 ({status_code}): {response_text}")
        content = json.loads(response_text)["choices"][0]["message"]["content"]
        if isinstance(content, list):
            # List-of-parts response shape used by some gateways.
            pieces = [part.get("text", "") for part in content if isinstance(part, dict)]
            return "\n".join(pieces).strip()
        return content.strip()
# ----------- Anthropic 客户端 -----------
class AnthropicVisionClient(BaseVisionClient):
    """Vision client for the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="anthropic", timeout=timeout)
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/messages", "https://api.anthropic.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Send one Messages-API request with a base64 image block."""
        text_part = {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"}
        image_part = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": mime_type,
                "data": image_base64,
            },
        }
        payload = {
            "model": self.model,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "system": "You are a multimodal analyst that responds with JSON only.",
            "messages": [{"role": "user", "content": [text_part, image_part]}],
        }
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
        }
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"Anthropic API 调用失败 ({status_code}): {response_text}")
        blocks = json.loads(response_text).get("content", [])
        texts = [blk.get("text", "") for blk in blocks if blk.get("type") == "text"]
        return "\n".join(texts).strip()
# ----------- 模拟客户端(无 API 时使用) -----------
class MockVisionClient(BaseVisionClient):
    """Offline stand-in that fabricates a valid JSON reply without any API."""

    def __init__(self):
        super().__init__(model="mock", api_type="mock", timeout=1)

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        """Return a canned JSON payload mentioning the file name."""
        del image_base64, mime_type, prompt  # unused in mock mode
        canned = {
            "detected_text": ["(模拟模式)未调用真实 API"],
            "detected_objects": ["sample-object-from-filename"],
            "sensitive_items": [],
            "summary": f"Mock 模式:演示 {image_name} 的返回格式。",
            "confidence": "Low",
        }
        return json.dumps(canned, ensure_ascii=False)
# ----------- Dify Chatflow 客户端 -----------
class DifyVisionClient(BaseVisionClient):
    """Dify Workflow API client for vision analysis.

    Unlike the other clients, Dify needs a two-step flow: first upload the
    image to obtain a file UUID, then run the workflow referencing it.
    """
    def __init__(self, api_key: str, model: str, timeout: int,
                 base_url: Optional[str] = None, user_id: Optional[str] = None):
        super().__init__(model=model, api_type="dify", timeout=timeout)
        if not api_key:
            raise ValueError("DIFY_API_KEY 未设置")
        self.api_key = api_key
        self.base_url = (base_url or "https://dify.example.com").rstrip("/")
        self.user_id = user_id or "default-user"
    def _upload_file(self, image_path: Path) -> str:
        """Upload the image to Dify and return its file id (UUID)."""
        url = f"{self.base_url}/v1/files/upload"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        data = {"user": self.user_id}
        # Upload via multipart/form-data.
        mime_type, _ = mimetypes.guess_type(image_path.name)
        with open(image_path, "rb") as f:
            files = {"file": (image_path.name, f, mime_type or "image/jpeg")}
            status_code, response_text = http_post_multipart(
                url, headers, data, files, self.timeout
            )
        if status_code >= 400:
            raise RuntimeError(f"Dify 文件上传失败 ({status_code}): {response_text}")
        response_json = json.loads(response_text)
        return response_json["id"]  # the uploaded file's UUID
    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """Override of the base pipeline: upload file -> run workflow -> parse."""
        start = time.time()
        analyzed_at = datetime.now().isoformat()
        try:
            # Step 1: upload the image.
            file_id = self._upload_file(image_path)
            # Step 2: trigger the workflow (/v1/workflows/run).
            url = f"{self.base_url}/v1/workflows/run"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            # Build the workflow inputs.
            inputs = {
                "prompt01": prompt,  # paragraph-type input, no length limit
                "pho01": {  # image parameter expected by the workflow
                    "type": "image",
                    "transfer_method": "local_file",
                    "upload_file_id": file_id
                }
            }
            payload = {
                "inputs": inputs,
                "response_mode": "blocking",
                "user": self.user_id
            }
            status_code, response_text = http_post_json(
                url, headers, payload, self.timeout
            )
            if status_code >= 400:
                raise RuntimeError(f"Dify 工作流失败 ({status_code}): {response_text}")
            # Step 3: extract the reply (the workflow returns data.outputs.text).
            response_json = json.loads(response_text)
            # Workflow API response shape: {"data": {"outputs": {"text": "..."}}}
            raw_text = response_json.get("data", {}).get("outputs", {}).get("text", "")
        except Exception as exc:
            latency = time.time() - start
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=latency,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )
        # Step 4: parse the JSON result (same logic as the base client).
        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)
        detected_text = normalize_list(parsed.get("detected_text")) if parsed else []
        detected_objects = normalize_list(parsed.get("detected_objects")) if parsed else []
        sensitive_items = normalize_list(parsed.get("sensitive_items")) if parsed else []
        summary = (parsed or {}).get("summary", "")
        confidence = (parsed or {}).get("confidence", "")
        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=detected_text,
            detected_objects=detected_objects,
            sensitive_items=sensitive_items,
            summary=str(summary),
            confidence=str(confidence),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )
# ----------- 主执行逻辑 -----------
def list_image_files(image_dir: Path, recursive: bool = False) -> List[Path]:
    """Collect supported image files under *image_dir*, sorted by path."""
    walker = image_dir.rglob("*") if recursive else image_dir.iterdir()
    return sorted(
        entry for entry in walker
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_EXTENSIONS
    )
def create_client(args: argparse.Namespace) -> BaseVisionClient:
    """Build the vision client selected by CLI args / environment.

    API-type precedence: ``--api-type`` > ``LLM_API_TYPE`` env > ``"openai"``.
    Each branch resolves its own model / API key / base URL from the matching
    environment variables.  (The original pre-computed an OpenAI model before
    branching; that assignment was dead code and has been removed.)

    Raises:
        ValueError: unknown ``api_type``, or the chosen provider's key is
            missing (raised by the client constructor).
    """
    api_type = (args.api_type or os.getenv("LLM_API_TYPE") or "openai").lower()
    if api_type == "mock" or args.mock:
        return MockVisionClient()
    if api_type == "openai":
        model = args.model or os.getenv("OPENAI_MODEL") or "gpt-4o-mini"
        api_key = os.getenv("OPENAI_API_KEY", "")
        base_url = os.getenv("OPENAI_BASE_URL")
        return OpenAIVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)
    if api_type == "dmx":
        model = args.model or os.getenv("DMX_MODEL") or "gpt-5-mini"
        api_key = os.getenv("DMX_API_KEY", "")
        base_url = os.getenv("DMX_BASE_URL")
        return DMXVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)
    if api_type == "anthropic":
        model = args.model or os.getenv("ANTHROPIC_MODEL") or "claude-3-5-sonnet-20241022"
        api_key = os.getenv("ANTHROPIC_API_KEY", "")
        base_url = os.getenv("ANTHROPIC_BASE_URL")
        return AnthropicVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)
    if api_type == "dify":
        model = args.model or os.getenv("DIFY_MODEL") or "dify-chatflow"
        api_key = os.getenv("DIFY_API_KEY", "")
        base_url = os.getenv("DIFY_BASE_URL")
        user_id = os.getenv("DIFY_USER_ID") or "default-user"
        return DifyVisionClient(api_key=api_key, model=model,
                                timeout=args.timeout, base_url=base_url,
                                user_id=user_id)
    raise ValueError(f"不支持的 api_type: {api_type}")
def load_prompt(prompt_file: Optional[str]) -> str:
    """Resolve the analysis prompt: file > VISION_ANALYSIS_PROMPT env > default."""
    if prompt_file:
        return Path(prompt_file).read_text(encoding="utf-8").strip()
    return (os.getenv("VISION_ANALYSIS_PROMPT") or DEFAULT_ANALYSIS_PROMPT).strip()
def run_batch_recognition(args: argparse.Namespace) -> List[RecognitionResult]:
    """Recognise all selected images and save an Excel report.

    Honours --offset/--limit slicing, serial vs. threaded execution
    (--max-workers) and the --sleep throttle (serial mode only).

    Returns:
        Per-image results in input order (empty list when no images match).

    Raises:
        FileNotFoundError: the image directory does not exist.
    """
    image_dir = Path(args.image_dir).resolve()
    if not image_dir.exists():
        raise FileNotFoundError(f"图片目录不存在: {image_dir}")
    output_path = Path(args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    prompt = load_prompt(args.prompt_file)
    client = create_client(args)
    images = list_image_files(image_dir, recursive=args.recursive)
    if args.offset:
        images = images[args.offset :]
    if args.limit:
        images = images[: args.limit]
    if not images:
        print("未找到任何图片文件。")
        return []
    max_workers = max(1, args.max_workers)
    debug = args.debug
    # Compact banner by default; full details only in --debug mode.
    if not debug:
        print(f"- 开始处理 {len(images)} 张图片 | API: {client.api_type} ({client.model})")
    else:
        print("=" * 60)
        print(f"图片目录: {image_dir}")
        print(f"输出文件: {output_path}")
        print(f"匹配到 {len(images)} 张图片,将使用 {client.api_type} ({client.model}) 执行识别")
        print(f"并发 worker 数: {max_workers}")
        print("=" * 60)
    results: List[RecognitionResult] = []
    total_images = len(images)
    if max_workers == 1:
        # Serial path - progress bar (hidden when debug prints are on).
        pbar = tqdm(images, desc="处理进度", unit="", disable=debug)
        for idx, image_path in enumerate(pbar, 1):
            if debug:
                print(f"[{idx}/{total_images}] 识别 {image_path.name} ...")
            result = client.analyze_image(image_path, prompt)
            results.append(result)
            if debug:
                if result.is_success:
                    print(
                        f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                        f"耗时 {result.latency_seconds:.1f}s"
                    )
                else:
                    print(f" ✗ 失败: {result.error}")
            else:
                # Compact mode: reflect the latest file on the progress bar.
                status = "" if result.is_success else ""
                pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")
            if args.sleep > 0:
                time.sleep(args.sleep)
        pbar.close()
    else:
        if args.sleep > 0:
            print("提示: 并发模式下忽略 --sleep 节流参数。")
        result_lookup: Dict[Path, RecognitionResult] = {}
        # Concurrent path - progress bar driven by completed futures.
        pbar = tqdm(total=total_images, desc="处理进度", unit="", disable=debug)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(client.analyze_image, image_path, prompt): image_path
                for image_path in images
            }
            for completed_idx, future in enumerate(as_completed(future_map), 1):
                image_path = future_map[future]
                result = future.result()
                result_lookup[image_path] = result
                if debug:
                    print(f"[{completed_idx}/{total_images}] 识别 {image_path.name} ...")
                    if result.is_success:
                        print(
                            f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                            f"耗时 {result.latency_seconds:.1f}s"
                        )
                    else:
                        print(f" ✗ 失败: {result.error}")
                else:
                    # Compact mode: update the bar with the latest file.
                    status = "" if result.is_success else ""
                    pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")
                pbar.update(1)
        pbar.close()
        # Restore input order (as_completed yields in completion order).
        results = [result_lookup[path] for path in images]
    save_results(results, output_path)
    return results
def save_results(results: Sequence[RecognitionResult], output_path: Path) -> None:
    """Write all result rows to an Excel workbook and print a summary."""
    if not results:
        print("没有可保存的结果。")
        return
    rows = [res.to_row() for res in results]
    pd.DataFrame(rows).to_excel(output_path, index=False, engine="openpyxl")
    success_count = sum(1 for res in results if res.is_success)
    print(f"\n识别完成,成功 {success_count}/{len(results)},结果已写入 {output_path}")
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Define and parse the command-line interface for the batch recognizer."""
    parser = argparse.ArgumentParser(
        description="批量图片多模态识别(文字 + 物品)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Paths default to <repo>/data relative to this script's parent directory.
    project_root = Path(__file__).resolve().parent.parent
    parser.add_argument("--image-dir", type=str, default=str(project_root / "data" / "images"), help="图片目录")
    parser.add_argument("--output", type=str, default=str(project_root / "data" / "output" / "image_recognition_results.xlsx"), help="识别结果输出文件Excel")
    parser.add_argument("--api-type", choices=["openai", "anthropic", "dmx", "dify", "mock"], help="选择使用的 API 类型")
    parser.add_argument("--model", type=str, help="指定模型名称(默认读取环境变量)")
    parser.add_argument("--prompt-file", type=str, help="自定义提示词文件")
    parser.add_argument("--recursive", action="store_true", help="递归搜索子目录")
    parser.add_argument("--limit", type=int, help="限制最大处理图片数")
    parser.add_argument("--offset", type=int, default=0, help="跳过前 N 张图片")
    parser.add_argument("--sleep", type=float, default=DEFAULT_SLEEP_SECONDS, help="每次请求后的等待秒数")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="HTTP 请求超时")
    parser.add_argument("--mock", action="store_true", help="强制使用模拟模式(无需 API")
    parser.add_argument("--env-file", type=str, default=str(project_root / ".env"), help="包含 API Key / Base URL / Model 配置的 .env 文件")
    parser.add_argument("--max-workers", type=int, default=1, help="并行识别线程数")
    parser.add_argument("--debug", action="store_true", help="启用详细输出(默认为精简模式)")
    return parser.parse_args(argv)
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse args, load .env config, run the batch."""
    options = parse_args(argv)
    load_env_file(options.env_file)
    try:
        run_batch_recognition(options)
    except Exception as exc:  # top-level boundary: report and signal failure
        print(f"执行失败: {exc}")
        sys.exit(1)
# Allow running the module directly as a CLI script.
if __name__ == "__main__":
    main()

778
scripts/keyword_matcher.py Normal file
View File

@@ -0,0 +1,778 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多模式关键词匹配工具(重构版)
- CAS号识别专注于 `CAS号` 列,支持多种格式(-, 空格, 无分隔符等)
- 模糊识别:对所有候选文本(含CAS)进行容错匹配
"""
import argparse
import re
import sys
import time
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Optional, Tuple
import pandas as pd
# 可选依赖
try:
import ahocorasick
HAS_AC = True
except ImportError:
HAS_AC = False
# ========== Constants ==========
SEPARATOR = "|||"  # splits multi-value keyword cells
MATCH_RESULT_SEPARATOR = " | "  # joins keyword hits in the output column
PROGRESS_INTERVAL = 1000  # rows between progress prints
DEFAULT_FUZZY_THRESHOLD = 85  # legacy fuzzy threshold, kept for CLI compat
# CAS number regex:
# matches 2-7 digits + optional separator + 2 digits + optional separator + 1 digit
# accepted separators: '-' (hyphen), space, '.' (dot), '_' (underscore), or none
CAS_REGEX_PATTERN = r'\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b'
# Keyword-sheet columns consulted per detection mode.
MODE_KEYWORD_COLUMNS: Dict[str, List[str]] = {
    "cas": ["CAS号"],
    "exact": ["中文名", "英文名", "CAS号", "简称", "可能名称"],
}
# Human-readable labels for each detection mode.
MODE_LABELS = {
    "cas": "CAS号识别",
    "exact": "精确匹配",
}
# ========== 数据类 ==========
@dataclass
class MatchResult:
    """Aggregated outcome of one matching run."""
    matched_indices: List[int]   # row positions of matched rows
    matched_keywords: List[str]  # formatted keyword hits, parallel to indices
    elapsed_time: float          # wall-clock seconds spent matching
    total_rows: int              # number of rows scanned
    matcher_name: str            # human-readable matcher label

    @property
    def match_count(self) -> int:
        """Number of matched rows."""
        return len(self.matched_indices)

    @property
    def match_rate(self) -> float:
        """Matched rows as a percentage of all rows (0.0 for empty input)."""
        if self.total_rows > 0:
            return self.match_count / self.total_rows * 100
        return 0.0

    @property
    def speed(self) -> float:
        """Rows processed per second (0.0 when no time elapsed)."""
        if self.elapsed_time > 0:
            return self.total_rows / self.elapsed_time
        return 0.0
# ========== 工具函数 ==========
def normalize_cas(cas_str: str) -> str:
    """Normalise a CAS number into the canonical ``X...X-XX-X`` form.

    Accepted input formats (separators may be '-', space, '.', '_' or absent):
        123-45-6, 123 45 6, 123.45.6, 123_45_6, 12345 6, 123456

    The digits are regrouped from the right: the last digit is the check
    digit, the two before it are the middle group, everything else is the
    leading group.  (The original comments claimed e.g. ``123456 -> 1234-5-6``;
    the code actually — and correctly — produces ``123-45-6``.)

    Returns the input unchanged (stringified for non-strings) when it is not
    a non-empty string or contains fewer than 5 digits, the shortest CAS
    layout (XX-XX-X).  Note: digit count above the minimum is NOT validated.
    """
    if not cas_str or not isinstance(cas_str, str):
        # Non-string / empty input is passed through, stringified.
        return str(cas_str)
    # Collapse every separator style by keeping only the digits.
    digits_only = re.sub(r'[^\d]', '', cas_str)
    if len(digits_only) < 5:
        return cas_str
    # Regroup from the right: leading block, two middle digits, check digit.
    # e.g. "123456" -> "123-45-6", "12345" -> "12-34-5"
    return f"{digits_only[:-3]}-{digits_only[-3:-1]}-{digits_only[-1]}"
def extract_cas_numbers(text: str, pattern: str = CAS_REGEX_PATTERN) -> Set[str]:
    """Extract every CAS-number-looking token from *text*, normalised.

    Args:
        text: text to scan; falsy input yields an empty set.
        pattern: regex (string or pre-compiled) locating CAS candidates.

    Returns:
        Set of CAS numbers in the canonical ``XXX-XX-X`` form.
    """
    if not text:
        return set()
    return {
        normalize_cas(candidate.group(0))
        for candidate in re.finditer(pattern, str(text))
    }
def split_value(value: str, separator: str) -> List[str]:
    """Split a cell into stripped, non-empty candidate keywords."""
    if separator and separator in value:
        pieces = value.split(separator)
    else:
        pieces = [value]
    return [piece.strip() for piece in pieces if piece and piece.strip()]
def load_keywords_for_mode(
    df: pd.DataFrame,
    mode: str,
    separator: str = SEPARATOR
) -> Set[str]:
    """Collect the keyword set for one detection mode from the keyword sheet.

    Looks up the mode's configured columns, splits multi-value cells on
    *separator*, drops placeholder values, and (for CAS mode) normalises
    every token to the canonical CAS format.

    Raises:
        ValueError: unknown mode, or none of the required columns exist.
    """
    mode_lower = mode.lower()
    if mode_lower not in MODE_KEYWORD_COLUMNS:
        raise ValueError(f"不支持的模式: {mode}")
    target_columns = MODE_KEYWORD_COLUMNS[mode_lower]
    available_columns = [col for col in target_columns if col in df.columns]
    missing_columns = [col for col in target_columns if col not in df.columns]
    if not available_columns:
        raise ValueError(
            f"模式 '{mode_lower}' 需要的列 {target_columns} 均不存在,"
            f"当前可用列: {df.columns.tolist()}"
        )
    if missing_columns:
        print(f"警告: 以下列在关键词文件中缺失: {missing_columns}")
    placeholders = {'#N/A#', 'nan', 'None'}
    keywords: Set[str] = set()
    for column in available_columns:
        for cell in df[column].dropna():
            cell_text = str(cell).strip()
            if not cell_text or cell_text in placeholders:
                continue
            for token in split_value(cell_text, separator):
                if mode_lower == "cas":
                    # CAS mode: store keywords in canonical CAS form.
                    token = normalize_cas(token)
                keywords.add(token)
    label = MODE_LABELS.get(mode_lower, mode_lower)
    print(f"模式「{label}」共加载 {len(keywords)} 个候选关键词,来源列: {available_columns}")
    if not keywords:
        print(f"警告: 模式 '{mode_lower}' 未加载到任何关键词!")
    return keywords
def show_progress(
    current: int,
    total: int,
    start_time: float,
    matched_count: int,
    interval: int = PROGRESS_INTERVAL
) -> None:
    """Print a throughput line once every *interval* processed rows."""
    done = current + 1
    if done % interval != 0:
        return
    elapsed = time.time() - start_time
    speed = done / elapsed
    print(f"已处理 {done}/{total} 行,速度: {speed:.1f} 行/秒,匹配到 {matched_count}")
# ========== 匹配器基类(策略模式) ==========
class KeywordMatcher(ABC):
"""关键词匹配器抽象基类"""
def __init__(self, name: str):
self.name = name
def match(
self,
df: pd.DataFrame,
keywords: Set[str],
text_column: str
) -> MatchResult:
"""执行匹配(模板方法)"""
print(f"开始匹配(使用{self.name}...")
self._prepare(keywords)
matched_indices = []
matched_keywords_list = []
start_time = time.time()
for idx, text in enumerate(df[text_column]):
if pd.isna(text):
continue
text_str = str(text)
matches = self._match_single_text(text_str, keywords)
if matches:
matched_indices.append(idx)
formatted = self._format_matches(matches)
matched_keywords_list.append(formatted)
show_progress(idx, len(df), start_time, len(matched_indices))
elapsed = time.time() - start_time
return MatchResult(
matched_indices=matched_indices,
matched_keywords=matched_keywords_list,
elapsed_time=elapsed,
total_rows=len(df),
matcher_name=self.name
)
def _prepare(self, keywords: Set[str]) -> None:
"""预处理(子类可选实现)"""
pass
@abstractmethod
def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
"""匹配单条文本(子类必须实现)"""
pass
def _format_matches(self, matches: Set[str]) -> str:
"""格式化匹配结果(子类可重写)"""
return MATCH_RESULT_SEPARATOR.join(sorted(matches))
# ========== 具体匹配器实现 ==========
class AhoCorasickMatcher(KeywordMatcher):
    """Multi-pattern substring matcher backed by a pyahocorasick automaton."""

    def __init__(self):
        super().__init__("Aho-Corasick 自动机")
        self.automaton = None  # built lazily in _prepare

    def _prepare(self, keywords: Set[str]) -> None:
        """Build the automaton once, before scanning.

        Raises:
            RuntimeError: pyahocorasick is not installed.
        """
        if not HAS_AC:
            raise RuntimeError("pyahocorasick 未安装")
        print("正在构建Aho-Corasick自动机...")
        automaton = ahocorasick.Automaton()
        for kw in keywords:
            automaton.add_word(kw, kw)
        automaton.make_automaton()
        self.automaton = automaton
        print("自动机构建完成")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Collect every keyword the automaton reports inside *text*."""
        return {keyword for _, keyword in self.automaton.iter(text)}
class SetMatcher(KeywordMatcher):
    """Fallback matcher using plain substring containment checks."""

    def __init__(self):
        super().__init__("标准集合匹配")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return every keyword that occurs as a substring of *text*."""
        hits = set()
        for keyword in keywords:
            if keyword in text:
                hits.add(keyword)
        return hits
class CASRegexMatcher(KeywordMatcher):
    """CAS-number matcher tolerant of multiple separator formats.

    Extracts CAS-looking tokens with a regex, normalises them, then keeps
    only those present in the (already normalised) keyword set.
    """

    def __init__(self):
        super().__init__("CAS号正则匹配支持多种格式")
        self.pattern = re.compile(CAS_REGEX_PATTERN)

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Intersect the normalised CAS candidates in *text* with the keyword set."""
        return extract_cas_numbers(text, self.pattern) & keywords
class RegexExactMatcher(KeywordMatcher):
    """Whole-word matcher: one alternation regex wrapped in ``\\b`` boundaries."""

    def __init__(self):
        super().__init__("正则表达式精确匹配(词边界)")
        self.pattern = None  # compiled in _prepare; None means "no keywords"

    def _prepare(self, keywords: Set[str]) -> None:
        """Compile ``\\b(kw1|kw2|...)\\b`` from the keyword set.

        Fixes over the original:
        - An empty keyword set previously compiled ``\\b()\\b``, which
          matches empty strings at every word boundary; now it leaves
          ``self.pattern`` as None so nothing matches.
        - Alternatives are escaped and sorted longest-first, so overlapping
          keywords resolve deterministically to the longest match instead
          of depending on arbitrary set iteration order.
        """
        print("正在构建正则表达式模式...")
        if keywords:
            escaped_keywords = sorted(
                (re.escape(kw) for kw in keywords),
                key=len,
                reverse=True,
            )
            pattern_str = r'\b(' + '|'.join(escaped_keywords) + r')\b'
            self.pattern = re.compile(pattern_str)
        else:
            self.pattern = None
        print(f"正则模式构建完成,共 {len(keywords)} 个关键词")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """Return the set of whole-word keyword hits in *text*."""
        if not self.pattern:
            return set()
        return set(self.pattern.findall(text))
# ========== 匹配器工厂 ==========
def create_matcher(algorithm: str, fuzzy_threshold: int = DEFAULT_FUZZY_THRESHOLD, mode: str = None) -> KeywordMatcher:
    """Factory for the matching strategy.

    Args:
        algorithm: "auto" (Aho-Corasick when available, else set matching),
            "set", or "exact".
        fuzzy_threshold: deprecated, kept only for backward compatibility.
        mode: detection mode ("cas"/"exact"); when given it overrides the
            algorithm choice with the mode-specific matcher.

    Raises:
        ValueError: unknown algorithm.
    """
    mode_lower = mode.lower() if mode else None
    if mode_lower == "cas":
        return CASRegexMatcher()
    if mode_lower == "exact":
        return RegexExactMatcher()
    algorithm_lower = algorithm.lower()
    if algorithm_lower == "exact":
        return RegexExactMatcher()
    if algorithm_lower == "set":
        return SetMatcher()
    if algorithm_lower == "auto":
        if HAS_AC:
            return AhoCorasickMatcher()
        print("警告: 未安装 pyahocorasick使用标准匹配方法")
        return SetMatcher()
    raise ValueError(f"不支持的算法: {algorithm}")
# ========== 结果处理 ==========
def save_results(
    df: pd.DataFrame,
    result: MatchResult,
    output_file: str
) -> Optional[pd.DataFrame]:
    """Persist matched rows to Excel, prefixed with the hit-keyword column.

    Returns:
        The saved DataFrame, or None when there were no matches.
    """
    if result.match_count == 0:
        print("未找到任何匹配")
        return None
    matched_df = df.iloc[result.matched_indices].copy()
    matched_df.insert(0, "匹配到的关键词", result.matched_keywords)
    matched_df.to_excel(output_file, index=False, engine='openpyxl')
    print(f"结果已保存到: {output_file}")
    return matched_df
def print_statistics(result: MatchResult) -> None:
    """Print a banner summarising one matching run."""
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"匹配完成!")
    print(f"匹配模式: {result.matcher_name}")
    print(f"总耗时: {result.elapsed_time:.2f}")
    print(f"处理速度: {result.speed:.1f} 行/秒")
    print(f"匹配到 {result.match_count} 行数据 ({result.match_rate:.2f}%)")
    print(f"{banner}\n")
def preview_results(result_df: pd.DataFrame, num_rows: int = 5) -> None:
    """Print the first *num_rows* matched rows with widened pandas display."""
    if result_df.empty:
        return
    print(f"{num_rows}行匹配结果预览:")
    print("=" * 80)
    # Widen the pandas display so long cells are not truncated.
    for option, value in (
        ('display.max_columns', None),
        ('display.width', None),
        ('display.max_colwidth', 100),
    ):
        pd.set_option(option, value)
    print(result_df.head(num_rows))
    print("=" * 80)
    print(f"\n✓ 总共匹配到 {len(result_df)} 行数据")
# ========== 主流程 ==========
def perform_matching(
    df: pd.DataFrame,
    keywords: Set[str],
    text_column: str,
    output_file: str,
    algorithm: str = "auto",
    mode: Optional[str] = None
) -> Optional[pd.DataFrame]:
    """Run the full matching pipeline on one DataFrame.

    Validates the text column, builds a matcher via ``create_matcher``,
    runs it, prints statistics and saves the hits to an Excel file.

    Args:
        df: Source data to scan.
        keywords: Keyword set to search for.
        text_column: Name of the column in *df* holding the text.
        output_file: Path of the Excel file the hits are written to.
        algorithm: Matching algorithm ("auto", "set" or "exact").
        mode: Detection mode ("cas" or "exact"); forwarded to the matcher
            factory so it can pick a mode-specific matcher.

    Returns:
        The matched rows as a DataFrame, or None when nothing matched.

    Raises:
        ValueError: If *text_column* is not a column of *df*.
    """
    # Fail fast, listing the available columns, when the column is wrong.
    if text_column not in df.columns:
        print(f"可用列名: {df.columns.tolist()}")
        raise ValueError(f"'{text_column}' 不存在")
    print(f"文本文件共有 {len(df)} 行数据\n")
    # Build the matcher for the requested algorithm/mode and run it.
    matcher = create_matcher(algorithm, mode=mode)
    result = matcher.match(df, keywords, text_column)
    # Report statistics before persisting anything.
    print_statistics(result)
    # save_results returns None when there were no hits.
    result_df = save_results(df, result, output_file)
    return result_df
def process_single_mode(
    keywords_df: pd.DataFrame,
    text_df: pd.DataFrame,
    mode: str,
    text_column: str,
    output_file: Path,
    separator: str = SEPARATOR,
    save_to_file: bool = True
) -> Optional[pd.DataFrame]:
    """Run one detection mode end to end.

    Args:
        keywords_df: Keyword sheet loaded from Excel.
        text_df: Text sheet to scan.
        mode: Detection mode name ("cas" or "exact").
        text_column: Column of *text_df* holding the text.
        output_file: Target Excel path (used only when *save_to_file*).
        separator: Separator used inside keyword cells.
        save_to_file: When False, the intermediate Excel output goes to a
            private temporary file that is removed afterwards.

    Returns:
        The matched rows with an extra '匹配模式' column, or None when the
        mode produced no matches.
    """
    import os
    import tempfile

    mode_lower = mode.lower()
    label = MODE_LABELS.get(mode_lower, mode_lower)
    print(f"\n{'='*60}")
    print(f">>> 正在执行识别模式: {label}")
    print(f"{'='*60}")
    # Load this mode's keywords and show a small sample as a sanity check.
    keywords = load_keywords_for_mode(keywords_df, mode_lower, separator=separator)
    sample_keywords = list(keywords)[:10]
    if sample_keywords:
        print(f"\n{label} - 关键词样例前10个:")
        for idx, kw in enumerate(sample_keywords, 1):
            print(f" {idx}. {kw}")
        print()
    # exact mode forces the exact algorithm; everything else auto-selects.
    algorithm = "exact" if mode_lower == "exact" else "auto"
    if save_to_file:
        result_df = perform_matching(
            df=text_df,
            keywords=keywords,
            text_column=text_column,
            output_file=str(output_file),
            algorithm=algorithm,
            mode=mode_lower
        )
    else:
        # Use a unique temp file instead of a fixed /tmp path so that
        # concurrent runs cannot clobber each other's intermediate output.
        fd, temp_output = tempfile.mkstemp(suffix=".xlsx")
        os.close(fd)
        try:
            result_df = perform_matching(
                df=text_df,
                keywords=keywords,
                text_column=text_column,
                output_file=temp_output,
                algorithm=algorithm,
                mode=mode_lower
            )
        finally:
            # Always clean up, even if matching raised.
            if os.path.exists(temp_output):
                os.remove(temp_output)
    # Only show a preview when this mode writes its own output file.
    if result_df is not None and save_to_file:
        preview_results(result_df)
    # Tag every matched row with the human-readable mode label.
    if result_df is not None:
        result_df['匹配模式'] = label
    return result_df
def run_multiple_modes(
    keywords_file: Path,
    text_file: Path,
    output_file: Path,
    text_column: str,
    modes: List[str],
    separator: str = SEPARATOR
) -> None:
    """Run every requested detection mode and merge all hits into one file.

    Args:
        keywords_file: Excel workbook with one keyword column per mode.
        text_file: Excel workbook containing the text rows to scan.
        output_file: Path of the merged result workbook.
        text_column: Column of the text workbook holding the text.
        modes: Detection mode names to run (validated against
            MODE_KEYWORD_COLUMNS).
        separator: Separator used inside keyword cells.

    Raises:
        FileNotFoundError: When either input workbook is missing.
        ValueError: When *modes* is empty or contains an unknown mode.
    """
    # Validate inputs before doing any heavy loading.
    if not keywords_file.exists():
        raise FileNotFoundError(f"找不到关键词文件: {keywords_file}")
    if not text_file.exists():
        raise FileNotFoundError(f"找不到文本文件: {text_file}")
    # Load both workbooks once; every mode reuses the same DataFrames.
    print(f"正在加载关键词文件: {keywords_file}")
    keywords_df = pd.read_excel(keywords_file)
    print(f"可用列: {keywords_df.columns.tolist()}\n")
    print(f"正在加载文本文件: {text_file}")
    text_df = pd.read_excel(text_file)
    print(f"文本列: {text_column}\n")
    # Validate modes up front so a typo fails fast.
    if not modes:
        raise ValueError("modes 不能为空,请至少指定一个模式")
    for mode in modes:
        if mode.lower() not in MODE_KEYWORD_COLUMNS:
            raise ValueError(f"不支持的识别模式: {mode}")
    # Run each mode without per-mode output files; collect the hit frames.
    all_results = []
    multiple_modes = len(modes) > 1  # NOTE(review): currently unused
    for mode in modes:
        mode_lower = mode.lower()
        result_df = process_single_mode(
            keywords_df=keywords_df,
            text_df=text_df,
            mode=mode_lower,
            text_column=text_column,
            output_file=output_file,  # ignored when save_to_file=False
            separator=separator,
            save_to_file=False  # merged output is written below instead
        )
        if result_df is not None and not result_df.empty:
            all_results.append(result_df)
    # Nothing matched in any mode: report and bail out without writing.
    if not all_results:
        print("\n所有模式均未匹配到数据")
        return
    print(f"\n{'='*60}")
    print("正在合并所有模式的匹配结果...")
    print(f"{'='*60}")
    # De-duplicate rows matched by several modes and keep original columns.
    merged_df = merge_mode_results(all_results, text_df)
    merged_df.to_excel(output_file, index=False, engine='openpyxl')
    print(f"\n合并结果已保存到: {output_file}")
    print(f" 总匹配行数: {len(merged_df)}")
    # Per-mode contribution (counts rows before de-duplication).
    print(f"\n各模式匹配统计:")
    for mode_result in all_results:
        mode_name = mode_result['匹配模式'].iloc[0]
        count = len(mode_result)
        print(f" {mode_name:20s}: {count:4d}")
    # Show a short preview of the merged output.
    print(f"\n{'='*60}")
    print("合并结果预览前5行:")
    print(f"{'='*60}")
    preview_results(merged_df, num_rows=5)
def merge_mode_results(
    results: List[pd.DataFrame],
    original_df: pd.DataFrame
) -> pd.DataFrame:
    """Combine per-mode hit frames into one de-duplicated frame.

    A row matched by several modes keeps a single output row whose keyword
    cell is the sorted union of all matched keywords and whose mode cell
    joins the mode labels with ' + '. All columns of *original_df* are
    preserved, and rows come out ordered by their original index.
    """
    if not results:
        return pd.DataFrame()
    # Accumulate keyword/mode info keyed by the original row index.
    merged = {}
    for frame in results:
        for row_idx, row in frame.iterrows():
            entry = merged.get(row_idx)
            if entry is None:
                # First mode hitting this row: keep its keyword cell verbatim.
                merged[row_idx] = {
                    'keywords': row['匹配到的关键词'],
                    'modes': [row['匹配模式']],
                }
            else:
                # Another mode hit the same row: union the keyword sets.
                union = set(str(entry['keywords']).split(' | '))
                union |= set(str(row['匹配到的关键词']).split(' | '))
                entry['keywords'] = ' | '.join(sorted(union))
                entry['modes'].append(row['匹配模式'])
    # Materialise the merged rows from the original data.
    ordered = list(merged)
    combined = original_df.loc[ordered].copy()
    combined.insert(0, '匹配到的关键词', [merged[i]['keywords'] for i in ordered])
    combined.insert(1, '匹配模式', [' + '.join(merged[i]['modes']) for i in ordered])
    return combined.sort_index()
# ========== 命令行接口 ==========
def parse_args():
    """Parse command-line arguments for the keyword-matching CLI.

    Returns:
        argparse.Namespace with attributes: keywords, text, output,
        text_column, modes, separator. Path arguments default to None
        (resolved relative to the script directory in ``main``).
    """
    parser = argparse.ArgumentParser(
        description='多模式关键词匹配工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 使用默认配置(两种模式)
  python keyword_matcher.py
  # 仅执行 CAS 号识别
  python keyword_matcher.py -m cas
  # 仅执行精确匹配
  python keyword_matcher.py -m exact
  # 指定自定义文件路径
  python keyword_matcher.py -k ../data/input/keywords.xlsx -t ../data/input/text.xlsx
"""
    )
    parser.add_argument(
        '-k', '--keywords',
        type=str,
        help='关键词文件路径 (默认: ../data/input/keywords.xlsx)'
    )
    parser.add_argument(
        '-t', '--text',
        type=str,
        help='文本文件路径 (默认: ../data/input/clickin_text_img.xlsx)'
    )
    parser.add_argument(
        '-o', '--output',
        type=str,
        help='输出文件路径 (默认: ../data/output/keyword_matched_results.xlsx)'
    )
    parser.add_argument(
        '-c', '--text-column',
        type=str,
        default='文本',
        help='文本列名 (默认: 文本)'
    )
    # Multiple modes may be given; both run by default.
    parser.add_argument(
        '-m', '--modes',
        nargs='+',
        choices=['cas', 'exact'],
        default=['cas', 'exact'],
        help='识别模式 (默认: cas exact)'
    )
    parser.add_argument(
        '--separator',
        type=str,
        default=SEPARATOR,
        help=f'关键词分隔符 (默认: {SEPARATOR})'
    )
    return parser.parse_args()
def main():
    """CLI entry point: resolve paths, report dependencies, run all modes.

    Exits with status 1 (after printing the traceback) when the pipeline
    raises any exception.
    """
    args = parse_args()
    # Default data paths live next to the script: ../data/{input,output}.
    base_dir = Path(__file__).resolve().parent
    keywords_file = Path(args.keywords) if args.keywords else (
        base_dir.parent / "data" / "input" / "keywords.xlsx"
    )
    text_file = Path(args.text) if args.text else (
        base_dir.parent / "data" / "input" / "clickin_text_img.xlsx"
    )
    output_file = Path(args.output) if args.output else (
        base_dir.parent / "data" / "output" / "keyword_matched_results.xlsx"
    )
    # Report optional-dependency status so slow runs are explainable.
    print("=" * 60)
    print("依赖库状态:")
    print(f"  pyahocorasick: {'已安装 ✓' if HAS_AC else '未安装 ✗'}")
    print("=" * 60)
    print()
    if not HAS_AC:
        print("提示: 安装 pyahocorasick 可获得 5x 性能提升: pip install pyahocorasick\n")
    try:
        run_multiple_modes(
            keywords_file=keywords_file,
            text_file=text_file,
            output_file=output_file,
            text_column=args.text_column,
            modes=args.modes,
            separator=args.separator
        )
        print("\n" + "=" * 60)
        print("✓ 所有模式处理完成!")
        print("=" * 60)
    except Exception as e:
        # Surface the full traceback, then signal failure to the shell.
        print(f"\n错误: {e}")
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
main()

154
scripts/run.sh Normal file
View File

@@ -0,0 +1,154 @@
#!/bin/bash
# 批量图片分析管理脚本 - 极简版
BASE_DIR="$HOME/bin/p_anlz"
BATCH_SCRIPT="$BASE_DIR/scripts/batch_run.sh"
LOG_DIR="$BASE_DIR/logs"
MAIN_LOG="$LOG_DIR/batch_run.log"
# Print CLI usage (commands and env-var overrides) to stdout.
show_help() {
    cat << EOF
批量图片分析管理工具
用法: $0 <命令>
命令:
  start    后台启动任务
  stop     停止任务
  status   查看运行状态
  log      查看实时日志
  help     显示帮助
示例:
  $0 start    # 启动任务
  $0 log      # 查看日志
  $0 status   # 查看状态
修改参数:
  API_TYPE=openai $0 start
  MAX_WORKERS=3 $0 start
  TIMEOUT=120 $0 start
EOF
}
# Start the batch job in the background; refuses to start a second copy.
start_task() {
    # pgrep -f matches against the full command line of the wrapper script.
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务已在运行中"
        exit 1
    fi
    echo "启动批量分析任务..."
    # Export tunables so batch_run.sh inherits them; callers may override
    # any of them on the command line (see show_help).
    export API_TYPE="${API_TYPE:-dify}"
    export MAX_WORKERS="${MAX_WORKERS:-1}"
    export TIMEOUT="${TIMEOUT:-90}"
    # Detach from the terminal; batch_run.sh is expected to write its own logs.
    nohup bash "$BATCH_SCRIPT" > /dev/null 2>&1 &
    sleep 2
    # Re-check after a short grace period to confirm the process survived.
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已启动"
        echo ""
        echo "查看日志: $0 log"
        echo "查看状态: $0 status"
    else
        echo "✗ 启动失败"
        exit 1
    fi
}
# Stop the batch job and its Python worker; escalate to SIGKILL if needed.
stop_task() {
    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务未运行"
        exit 0
    fi
    echo "停止任务..."
    # Terminate both the wrapper and the Python recognizer it spawned.
    pkill -f "batch_run.sh"
    pkill -f "image_batch_recognizer.py"
    sleep 1
    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已停止"
    else
        # Graceful stop failed; force-kill both processes.
        echo "✗ 停止失败,尝试强制终止..."
        pkill -9 -f "batch_run.sh"
        pkill -9 -f "image_batch_recognizer.py"
    fi
}
# Report whether the job is running and show recent log lines as context.
show_status() {
    echo "任务状态:"
    echo ""
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 运行中"
        echo ""
        echo "进程信息:"
        ps aux | grep -E "batch_run.sh|image_batch_recognizer.py" | grep -v grep
        echo ""
        # Show the most recent progress lines from the main log.
        if [ -f "$MAIN_LOG" ]; then
            echo "最新日志:"
            tail -10 "$MAIN_LOG"
        fi
    else
        echo "○ 未运行"
        # Not running: show the tail of the last run's summary, if any.
        if [ -f "$MAIN_LOG" ]; then
            echo ""
            echo "上次运行结果:"
            tail -15 "$MAIN_LOG" | grep -A 10 "批量分析完成"
        fi
    fi
}
# Follow the main log file in real time (blocks until Ctrl+C).
show_log() {
    if [ ! -f "$MAIN_LOG" ]; then
        echo "日志文件不存在: $MAIN_LOG"
        exit 1
    fi
    echo "实时查看日志 (Ctrl+C 退出):"
    echo ""
    tail -f "$MAIN_LOG"
}
# Dispatch on the sub-command given as $1 (defaults to "help").
main() {
    case "${1:-help}" in
        start)
            start_task
            ;;
        stop)
            stop_task
            ;;
        status)
            show_status
            ;;
        log)
            show_log
            ;;
        help|--help|-h)
            show_help
            ;;
        *)
            # Unknown command: print usage and fail.
            echo "未知命令: $1"
            echo ""
            show_help
            exit 1
            ;;
    esac
}
main "$@"

View File

@@ -0,0 +1,349 @@
#!/bin/bash
# 后台启动批量分析任务的包装脚本
# 支持 nohup、screen、tmux 三种方式
set -e
BASE_DIR="$HOME/bin/p_anlz"
SCRIPT_DIR="$BASE_DIR/scripts"
BATCH_SCRIPT="$SCRIPT_DIR/batch_analyze_all_folders.sh"
LOG_DIR="$BASE_DIR/logs"
NOHUP_LOG="$LOG_DIR/nohup.out"
# 颜色输出
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'
# Print usage for the background-runner wrapper (modes + inspection commands).
show_usage() {
    cat << EOF
批量图片分析后台运行工具
用法:
  $0 [方式]
方式选项:
  nohup     使用 nohup 后台运行(默认,最简单)
  screen    使用 screen 会话运行(可重新连接)
  tmux      使用 tmux 会话运行(推荐,功能最强)
  status    查看当前运行状态
  logs      实时查看主日志
  progress  查看处理进度
  live      实时查看当前 Python 程序输出(推荐)
示例:
  $0 nohup     # 使用 nohup 启动
  $0 screen    # 使用 screen 启动
  $0 tmux      # 使用 tmux 启动
  $0 status    # 查看状态
  $0 logs      # 查看主日志
  $0 progress  # 查看进度
  $0 live      # 实时查看 Python 输出(含进度条)
后台运行后如何查看:
  - nohup 方式: tail -f $NOHUP_LOG
  - screen 方式: screen -r batch_analyze
  - tmux 方式: tmux attach -t batch_analyze
  - 实时输出: $0 live
EOF
}
# Verify the batch script exists and is executable before launching it.
check_script() {
    # Fix: $RED was referenced but never defined at file scope (only
    # GREEN/YELLOW/BLUE/NC are), so the message printed a blank color
    # code. Define it locally so the error is actually red.
    local RED='\033[0;31m'
    if [ ! -f "$BATCH_SCRIPT" ]; then
        echo -e "${RED}错误: 批处理脚本不存在: $BATCH_SCRIPT${NC}"
        exit 1
    fi
    chmod +x "$BATCH_SCRIPT"
}
# Warn when a batch job already runs; ask before starting another one.
check_running() {
    if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${YELLOW}警告: 批量分析任务已在运行中${NC}"
        echo "进程信息:"
        ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
        echo ""
        # Interactive confirmation; anything but y/Y aborts.
        read -p "是否继续启动新任务? (y/N): " confirm
        if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then
            echo "已取消"
            exit 0
        fi
    fi
}
# Launch the batch script via nohup; all output is captured in $NOHUP_LOG.
start_nohup() {
    echo -e "${BLUE}使用 nohup 启动批量分析任务...${NC}"
    mkdir -p "$LOG_DIR"
    nohup bash "$BATCH_SCRIPT" > "$NOHUP_LOG" 2>&1 &
    # $! is the background job's PID; shown so the user can kill it later.
    local pid=$!
    echo -e "${GREEN}✓ 任务已在后台启动 (PID: $pid)${NC}"
    echo ""
    echo "查看日志:"
    echo " tail -f $NOHUP_LOG"
    echo ""
    echo "查看进度:"
    echo " $0 progress"
    echo ""
    echo "停止任务:"
    echo " kill $pid"
}
# Launch the batch script inside a detached screen session "batch_analyze".
start_screen() {
    if ! command -v screen &> /dev/null; then
        echo -e "${YELLOW}screen 未安装,请先安装: sudo apt-get install screen${NC}"
        exit 1
    fi
    echo -e "${BLUE}使用 screen 启动批量分析任务...${NC}"
    # A stale session with the same name must be removed first.
    if screen -list | grep -q "batch_analyze"; then
        echo -e "${YELLOW}警告: screen 会话 'batch_analyze' 已存在${NC}"
        read -p "是否删除旧会话并创建新会话? (y/N): " confirm
        if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
            screen -S batch_analyze -X quit 2>/dev/null || true
        else
            echo "已取消"
            exit 0
        fi
    fi
    # -dmS: start detached with the given session name.
    screen -dmS batch_analyze bash "$BATCH_SCRIPT"
    echo -e "${GREEN}✓ 任务已在 screen 会话中启动${NC}"
    echo ""
    echo "重新连接到会话:"
    echo " screen -r batch_analyze"
    echo ""
    echo "分离会话: Ctrl+A, 然后按 D"
    echo ""
    echo "查看所有会话:"
    echo " screen -ls"
    echo ""
    echo "终止会话:"
    echo " screen -S batch_analyze -X quit"
}
# Launch the batch script inside a detached tmux session "batch_analyze".
start_tmux() {
    if ! command -v tmux &> /dev/null; then
        echo -e "${YELLOW}tmux 未安装,请先安装: sudo apt-get install tmux${NC}"
        exit 1
    fi
    echo -e "${BLUE}使用 tmux 启动批量分析任务...${NC}"
    # A stale session with the same name must be removed first.
    if tmux has-session -t batch_analyze 2>/dev/null; then
        echo -e "${YELLOW}警告: tmux 会话 'batch_analyze' 已存在${NC}"
        read -p "是否删除旧会话并创建新会话? (y/N): " confirm
        if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
            tmux kill-session -t batch_analyze 2>/dev/null || true
        else
            echo "已取消"
            exit 0
        fi
    fi
    # new-session -d: create detached; the session runs the batch script.
    tmux new-session -d -s batch_analyze "bash $BATCH_SCRIPT"
    echo -e "${GREEN}✓ 任务已在 tmux 会话中启动${NC}"
    echo ""
    echo "重新连接到会话:"
    echo " tmux attach -t batch_analyze"
    echo ""
    echo "分离会话: Ctrl+B, 然后按 D"
    echo ""
    echo "查看所有会话:"
    echo " tmux ls"
    echo ""
    echo "终止会话:"
    echo " tmux kill-session -t batch_analyze"
}
# Show process state plus any screen/tmux session named "batch_analyze".
show_status() {
    echo -e "${BLUE}批量分析任务状态:${NC}"
    echo ""
    # Process check via full-command-line match.
    if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${GREEN}✓ 任务正在运行${NC}"
        echo ""
        echo "进程信息:"
        ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
    else
        echo -e "${YELLOW}○ 任务未运行${NC}"
    fi
    echo ""
    # screen session check (only when screen is installed).
    if command -v screen &> /dev/null; then
        if screen -list | grep -q "batch_analyze"; then
            echo -e "${GREEN}✓ Screen 会话存在${NC}"
            screen -list | grep "batch_analyze"
        fi
    fi
    echo ""
    # tmux session check (only when tmux is installed).
    if command -v tmux &> /dev/null; then
        if tmux has-session -t batch_analyze 2>/dev/null; then
            echo -e "${GREEN}✓ Tmux 会话存在${NC}"
            tmux list-sessions | grep "batch_analyze"
        fi
    fi
}
# Tail the main log, falling back to the nohup capture file.
show_logs() {
    local main_log="$LOG_DIR/batch_analyze.log"
    if [ -f "$main_log" ]; then
        echo -e "${BLUE}实时查看主日志 (Ctrl+C 退出):${NC}"
        echo ""
        tail -f "$main_log"
    elif [ -f "$NOHUP_LOG" ]; then
        echo -e "${BLUE}实时查看 nohup 日志 (Ctrl+C 退出):${NC}"
        echo ""
        tail -f "$NOHUP_LOG"
    else
        echo -e "${YELLOW}未找到日志文件${NC}"
        echo "日志目录: $LOG_DIR"
    fi
}
# Print the progress file plus SUCCESS/FAILED/SKIPPED counts and mtime.
show_progress() {
    local progress_file="$LOG_DIR/progress.txt"
    if [ ! -f "$progress_file" ]; then
        echo -e "${YELLOW}进度文件不存在: $progress_file${NC}"
        exit 1
    fi
    echo -e "${BLUE}批量分析进度:${NC}"
    echo ""
    cat "$progress_file"
    echo ""
    echo -e "${BLUE}统计信息:${NC}"
    # Fix: grep -c already prints "0" when nothing matches (while exiting
    # non-zero), so the old `|| echo "0"` produced a two-line value like
    # "0\n0". `|| true` keeps `set -e` satisfied without duplicating output.
    local total=$(grep -c "SUCCESS\|FAILED\|SKIPPED" "$progress_file" || true)
    local success=$(grep -c "SUCCESS" "$progress_file" || true)
    local failed=$(grep -c "FAILED" "$progress_file" || true)
    local skipped=$(grep -c "SKIPPED" "$progress_file" || true)
    echo "总计: $total"
    echo -e "${GREEN}成功: $success${NC}"
    echo -e "${YELLOW}失败: $failed${NC}"
    echo -e "${BLUE}跳过: $skipped${NC}"
    # stat -c is GNU/Linux; -f "%Sm" is the BSD/macOS fallback.
    local last_update=$(stat -c %y "$progress_file" 2>/dev/null || stat -f "%Sm" "$progress_file" 2>/dev/null)
    echo ""
    echo "最后更新: $last_update"
}
# Follow the per-folder log of the folder currently being processed.
# NOTE(review): retries via plain recursion (show_live calls itself after
# sleep); bounded in practice, but a `while` loop would be more robust.
show_live() {
    local main_log="$LOG_DIR/batch_analyze.log"
    echo -e "${BLUE}实时查看 Python 程序输出 (Ctrl+C 退出)${NC}"
    echo ""
    # Nothing to follow when the batch job is not running.
    if ! pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${YELLOW}任务未运行${NC}"
        exit 1
    fi
    if [ -f "$main_log" ]; then
        # grep -oP requires GNU grep (PCRE); \K keeps only the folder name
        # that follows "处理文件夹: " on the most recent matching line.
        local current_folder=$(grep -oP "处理文件夹: \K.*" "$main_log" | tail -1)
        if [ -n "$current_folder" ]; then
            local folder_log="$LOG_DIR/${current_folder}.log"
            if [ -f "$folder_log" ]; then
                echo -e "${GREEN}当前处理: $current_folder${NC}"
                echo -e "${BLUE}日志文件: $folder_log${NC}"
                echo ""
                echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
                # Stream the Python program's full output for this folder.
                tail -f "$folder_log"
            else
                # Folder known but its log not created yet: retry shortly.
                echo -e "${YELLOW}等待日志文件生成...${NC}"
                sleep 2
                show_live
            fi
        else
            # Job started but no folder logged yet: retry shortly.
            echo -e "${YELLOW}等待任务开始...${NC}"
            sleep 2
            show_live
        fi
    else
        echo -e "${YELLOW}主日志文件不存在${NC}"
        exit 1
    fi
}
# Dispatch: launch modes (nohup/screen/tmux) or inspection commands.
# Defaults to "nohup" when no argument is given.
main() {
    local mode="${1:-nohup}"
    case "$mode" in
        nohup)
            check_script
            check_running
            start_nohup
            ;;
        screen)
            check_script
            check_running
            start_screen
            ;;
        tmux)
            check_script
            check_running
            start_tmux
            ;;
        status)
            show_status
            ;;
        logs)
            show_logs
            ;;
        progress)
            show_progress
            ;;
        live)
            show_live
            ;;
        -h|--help|help)
            show_usage
            ;;
        *)
            # Unknown option: print usage and fail.
            echo "未知选项: $mode"
            echo ""
            show_usage
            exit 1
            ;;
    esac
}
main "$@"