init: first upload
This commit is contained in:

.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
.DS_Store
.env
data/*
batch_output/*
AGENTS.md (new file, 28 lines)
@@ -0,0 +1,28 @@
# Repository Guidelines

## Project Structure & Module Organization

The repository centers on four Python entry points under `scripts/`: `quick_start.py`, `match_cas_numbers.py`, `keyword_matcher.py`, and `expand_keywords_with_llm.py`. Source data lives in `data/input/`, generated spreadsheets land in `data/output/`, and supporting evidence files reside in `data/images/`. Keep API credentials in `.env` files that copy the keys documented in `config.env.example`.

## Build, Test, and Development Commands

Set up a sandboxed interpreter before running anything:

```bash
python3 -m venv .venv && source .venv/bin/activate
pip install pandas openpyxl pyahocorasick
```

Core routines:

- `cd scripts && python3 quick_start.py` validates the entire ingest → match → export flow with bundled sample sheets.
- `python3 match_cas_numbers.py` reads `data/input/clickin_text_img.xlsx` and writes normalized CAS matches to `data/output/cas_matched_results_final.xlsx`.
- `python3 keyword_matcher.py` now orchestrates the three detection modes (CAS column, exact text column, and fuzzy tolerant matching) and writes per-mode reports such as `keyword_matched_results_cas.xlsx`; install `pyahocorasick` for the fast exact path and `rapidfuzz` for the fuzzy path.
- `python3 expand_keywords_with_llm.py ../data/input/keywords.xlsx -m` mocks the LLM expansion; remove `-m` only after exporting `OPENAI_API_KEY` or `ANTHROPIC_API_KEY`.

## Coding Style & Naming Conventions

Follow PEP 8: four-space indentation, `snake_case` functions, `UPPER_SNAKE_CASE` constants. Module-level configuration such as column names or separators (`SEPARATOR = "|||"`) should be defined once and imported where needed. Preserve spreadsheet column spelling because pandas filters depend on exact casing. When expanding functionality, keep CLI argument names lowercase with hyphenated long options for consistency.

## Testing Guidelines

Testing is empirical: rerun `quick_start.py` and the specific script you edited, then compare row counts, unique IDs, and timing stats against previous outputs. Use lightweight fixtures copied from `data/input/` to isolate regressions, and treat script warnings or pandas `SettingWithCopyWarning` notices as failures until they are explained.

## Commit & Pull Request Guidelines

Git history is not distributed with this bundle, so default to Conventional Commit subjects (`feat: add cas normalizer`, `fix: guard empty rows`). Each PR should list the commands executed, describe the input data used, and reference any README or `data/` updates. Link tracking tickets, attach screenshots of spreadsheet diffs when UI proof is needed, and keep binary artifacts out of the diff by adding them to `.gitignore` when necessary.

## Security & Configuration Tips

Never commit real API keys or sensitive spreadsheets; point reviewers to sanitized snippets instead. Load secrets with `source .env` or a dedicated `.env` loader rather than embedding them in code. Generated Excel files may contain investigative evidence, so confine them to `data/output/` and scrub PII when sharing externally.
CLAUDE.md (new file, 269 lines)
@@ -0,0 +1,269 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This is a drug risk monitoring and data processing system for detecting controlled substances in text and image data from e-commerce platforms, darknet sources, and social media.

**Core Capabilities:**
1. **CAS Number Matching**: Extract and match chemical CAS numbers from text using regex patterns (supports multiple formats)
2. **Keyword Matching**: High-performance multi-mode keyword matching (fuzzy, CAS)
3. **Keyword Expansion**: LLM-powered expansion of chemical/drug names to include variants, abbreviations, and aliases

## Running Scripts

All scripts must be run from the `scripts/` directory:
```bash
cd scripts/

# Quick start (recommended for testing)
python3 quick_start.py

# CAS number matching
python3 match_cas_numbers.py

# Multi-mode keyword matching (default: both modes)
python3 keyword_matcher.py

# Single-mode matching
python3 keyword_matcher.py -m cas                    # CAS number only
python3 keyword_matcher.py -m fuzzy --threshold 90   # Fuzzy matching only

# Use the larger keyword database
python3 keyword_matcher.py -k ../data/input/keyword_all.xlsx

# Keyword expansion (mock mode, no API)
python3 expand_keywords_with_llm.py -m

# Keyword expansion (with OpenAI API)
export OPENAI_API_KEY="sk-..."
python3 expand_keywords_with_llm.py ../data/input/keywords.xlsx
```
## Dependencies

**Required:**
```bash
pip install pandas openpyxl
```

**Optional (for fuzzy keyword matching):**
```bash
pip install rapidfuzz
```

**Optional (for LLM keyword expansion):**
```bash
pip install openai anthropic
```
## Data Flow Architecture

All scripts use paths relative to the `scripts/` directory:

```
Input:  ../data/input/
    clickin_text_img.xlsx  (2779 rows: text + image paths)
    keywords.xlsx          (22 rows, basic keyword list)
    keyword_all.xlsx       (1659 rows, 1308 unique CAS numbers)

Output: ../data/output/
    keyword_matched_results.xlsx  (multi-mode merged results)
    cas_matched_results_final.xlsx
    test_keywords_expanded_rows.xlsx

Images: ../data/images/  (1955 JPG files, 84 MB)
```

**Processing Pipeline:**
```
Raw data collection -> Text extraction (OCR/LLM) ->
Feature matching (CAS/keywords) -> Data cleaning ->
Risk determination
```
## Key Technical Details

### 1. CAS Number Matching (`match_cas_numbers.py`)
- Supports multiple formats: `123-45-6`, `123 45 6`, `123 - 45 - 6`
- Auto-normalizes to the standard format `XXX-XX-X`
- Uses the regex pattern: `\b\d{2,7}[\s\-]+\d{2}[\s\-]+\d\b`
- Dual-mode: `match_mode = "regex"` for CAS matching, `"keywords"` for keyword matching
### 2. Keyword Matching (`keyword_matcher.py`) - REFACTORED

**Architecture:**
- Strategy pattern with a `KeywordMatcher` base class
- Concrete matchers: `CASRegexMatcher`, `FuzzyMatcher`
- Factory pattern for matcher creation
- Dataclass-based result handling

**Two Detection Modes:**

1. **CAS Number Recognition (CAS号识别)**
   - Uses `CASRegexMatcher` with a comprehensive regex pattern
   - Supports formats: `123-45-6`, `123 45 6`, `12345 6`, `123456`, `123.45.6`, `123_45_6`
   - Auto-normalizes every format to the standard `XXX-XX-X`
   - Regex: `\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b`
   - Extracts CAS numbers from text, normalizes them, then compares against the keyword database
   - Source column: `CAS号`
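The extract-and-normalize step above can be sketched in a few lines (the helper name `extract_cas_numbers` is illustrative; in the script this logic lives inside `CASRegexMatcher`):

```python
import re

# Pattern from the bullet above: a 2-7 digit body, 2-digit middle, and 1-digit
# check digit, with an optional space, hyphen, dot, or underscore in between.
CAS_PATTERN = re.compile(r"\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b")

def extract_cas_numbers(text: str) -> list[str]:
    """Extract CAS-like tokens and normalize each to the standard XXX-XX-X form."""
    return [f"{a}-{b}-{c}" for a, b, c in CAS_PATTERN.findall(text)]
```

Because the pattern also fires on plain digit runs, a normalized hit only counts as a match when it appears in the keyword database's `CAS号` column.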
2. **Fuzzy Matching (模糊匹配)**
   - Uses `FuzzyMatcher`, built on the RapidFuzz library
   - Default threshold: 85 (configurable via `--threshold`)
   - Scoring function: `partial_ratio`
   - Source columns: `中文名`, `英文名`, `CAS号`, `简称`, `可能名称`
   - **Note**: Fuzzy matching covers every case that exact matching would find, which makes a separate exact mode redundant
**Multi-Mode Result Merging:**
- Automatically merges results from multiple modes
- Deduplicates by row index
- Combines matched keywords with a ` | ` separator
- Adds a `匹配模式` column showing which modes matched (e.g., "CAS号识别 + 模糊匹配")
||||
**Command-Line Options:**
|
||||
```bash
|
||||
-k, --keywords # Path to keywords file (default: ../data/input/keywords.xlsx)
|
||||
-t, --text # Path to text file (default: ../data/input/clickin_text_img.xlsx)
|
||||
-o, --output # Output file path (default: ../data/output/keyword_matched_results.xlsx)
|
||||
-c, --text-column # Column containing text to search (default: "文本")
|
||||
-m, --modes # Modes to run: cas, fuzzy (default: both)
|
||||
--threshold # Fuzzy matching threshold 0-100 (default: 85)
|
||||
--separator # Keyword separator in cells (default: "|||")
|
||||
```
|
||||
|
||||
**Performance:**
- With keyword_all.xlsx (1308 CAS numbers):
  - CAS mode: 255 rows matched (9.18%)
  - Fuzzy mode: 513 rows matched (18.46%)
  - Merged (both modes): ~516 unique rows

**Why the `|||` separator:**
- Chemical names contain commas, hyphens, slashes, and semicolons
- A triple pipe avoids conflicts with chemical nomenclature
- Example: `甲基苯丙胺|||冰毒|||Methamphetamine|||MA`
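Splitting alias cells therefore stays trivial; a minimal sketch:

```python
SEPARATOR = "|||"

def split_aliases(cell: str) -> list[str]:
    """Split a keyword cell on the triple pipe, dropping empty fragments."""
    return [part.strip() for part in cell.split(SEPARATOR) if part.strip()]
```

Commas inside chemical names survive intact, where a CSV-style `,` separator would silently break them apart.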
### 3. Keyword Expansion (`expand_keywords_with_llm.py`)
- Expands Chinese names, English names, and abbreviations
- Supports the OpenAI and Anthropic APIs
- Mock mode available for testing without API costs
- Output formats: compact (single row with `|||` separators) or expanded (one name per row)
## Configuration Patterns

Scripts take command-line arguments (keyword_matcher.py) or use in-file configuration blocks:

```python
# ========== Configuration ==========
keywords_file = "../data/input/keywords.xlsx"
text_file = "../data/input/clickin_text_img.xlsx"
keywords_column = "中文名"
text_column = "文本"
separator = "|||"
output_file = "../data/output/results.xlsx"
# ===================================
```
## Excel File Schemas

**Input - clickin_text_img.xlsx:**
- Columns: `文本` (text), image paths, metadata
- 2779 rows of scraped e-commerce/social-media data

**Input - keywords.xlsx:**
- Columns: `中文名`, `英文名`, `CAS号`, `简称`, `备注`, `可能名称`
- `可能名称` holds multiple keywords separated by `|||`
- 22 rows (small test dataset)

**Input - keyword_all.xlsx:**
- Same schema as keywords.xlsx
- 1659 rows with 1308 unique CAS numbers
- Production keyword database

**Output - multi-mode matches (keyword_matched_results.xlsx):**
- Adds columns:
  - `匹配到的关键词` (matched keywords, separated by ` | `)
  - `匹配模式` (matching modes, e.g., "CAS号识别 + 模糊匹配")
- Preserves all original columns
- Deduplicated across all modes

**Output - CAS matches:**
- Adds column: `匹配到的CAS号` (matched CAS numbers)
- Preserves all original columns
- Typical match rate: ~9-11% (255-303 of 2779 rows)
## Common Modifications

**To change input/output paths:**
Use command-line arguments for `keyword_matcher.py`:
```bash
python3 keyword_matcher.py -k /path/to/keywords.xlsx -t /path/to/text.xlsx -o /path/to/output.xlsx
```

Or edit the configuration block in the other scripts' `main()` function.

**To switch between CAS and keyword matching:**
In `match_cas_numbers.py`, change `match_mode = "regex"` to `match_mode = "keywords"`.

In `keyword_matcher.py`, use the `-m` flag:
```bash
python3 keyword_matcher.py -m cas    # CAS only
python3 keyword_matcher.py -m fuzzy  # Fuzzy only
```

**To adjust fuzzy matching sensitivity:**
```bash
python3 keyword_matcher.py -m fuzzy --threshold 90  # Stricter (fewer matches)
python3 keyword_matcher.py -m fuzzy --threshold 70  # More lenient (more matches)
```

**To use a different LLM API:**
```bash
# OpenAI (default)
python3 expand_keywords_with_llm.py input.xlsx

# Anthropic
python3 expand_keywords_with_llm.py input.xlsx -a anthropic
```
## Code Architecture Highlights

### keyword_matcher.py Design Patterns

1. **Strategy pattern**: interchangeable matching algorithms (`KeywordMatcher` subclasses)
2. **Template method**: the common matching workflow lives in the base class's `match()` method
3. **Factory pattern**: `create_matcher()` selects the appropriate matcher
4. **Dependency injection**: the optional dependency (rapidfuzz) is handled gracefully

**Class Hierarchy:**
```
KeywordMatcher (ABC)
├── CASRegexMatcher  # Regex-based CAS number extraction
└── FuzzyMatcher     # RapidFuzz partial_ratio matching
```

**Data Flow:**
```
1. Load keywords  -> load_keywords_for_mode()
2. Create matcher -> create_matcher()
3. Match text     -> matcher.match()
   ├── _prepare() (build automaton, etc.)
   └── For each row:
       ├── _match_single_text()
       └── _format_matches()
4. Save results   -> save_results()
5. If multiple modes -> merge_mode_results()
```
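The patterns and the data flow above can be condensed into a skeleton (stub bodies only; the real matchers return structured results rather than placeholder strings):

```python
from abc import ABC, abstractmethod
from typing import List


class KeywordMatcher(ABC):
    """Strategy base class; match() is the template method."""

    def match(self, texts: List[str]) -> List[str]:
        self._prepare()  # hook: build an automaton, index keywords, etc.
        return [self._format_matches(self._match_single_text(t)) for t in texts]

    def _prepare(self) -> None:
        pass

    @abstractmethod
    def _match_single_text(self, text: str) -> List[str]: ...

    def _format_matches(self, matches: List[str]) -> str:
        return " | ".join(matches)


class CASRegexMatcher(KeywordMatcher):
    def _match_single_text(self, text: str) -> List[str]:
        return ["cas-hit"] if "-" in text else []  # stand-in for the real regex


class FuzzyMatcher(KeywordMatcher):
    def _match_single_text(self, text: str) -> List[str]:
        return ["fuzzy-hit"]  # stand-in for the RapidFuzz scoring


def create_matcher(mode: str) -> KeywordMatcher:
    """Factory: map a mode name to its strategy."""
    return {"cas": CASRegexMatcher, "fuzzy": FuzzyMatcher}[mode]()
```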
## Data Sensitivity

This codebase handles sensitive data related to controlled-substances monitoring. The data includes:
- Chemical compound names (Chinese and English)
- CAS registry numbers
- Image data from suspected illegal-substance trading platforms

All data is for legitimate law-enforcement/research purposes. Do not commit actual data files or API keys to version control.
README.md (new file, 211 lines)
@@ -0,0 +1,211 @@
# Drug-Risk Monitoring Data Processing System

Identifies chemical CAS numbers and keywords in text and images, and performs multimodal risk analysis.

## Quick Start

```bash
# Install dependencies
pip install pandas openpyxl

# Enter the scripts directory
cd scripts/

# Keyword matching
python3 keyword_matcher.py

# Image recognition (mock mode, no API needed)
python3 image_batch_recognizer.py --mock --limit 5
```
## Feature 1: Keyword Matching

Identifies CAS numbers and keywords in text.

**Two matching modes:**
- **CAS number recognition**: regex extraction, normalized to the `XXX-XX-X` format
- **Exact matching**: matches Chinese names, English names, abbreviations, etc.

### Basic Usage

```bash
cd scripts/

# Run both modes (default)
python3 keyword_matcher.py

# CAS number recognition only
python3 keyword_matcher.py -m cas

# Exact matching only
python3 keyword_matcher.py -m exact

# Custom file paths
python3 keyword_matcher.py \
  -k ../data/input/keyword_all.xlsx \
  -t ../data/input/clickin_text_img.xlsx \
  -o ../data/output/results.xlsx
```

### Options

| Option | Description | Default |
|--------|-------------|---------|
| `-k, --keywords` | Keywords file | `../data/input/keywords.xlsx` |
| `-t, --text` | Text file | `../data/input/clickin_text_img.xlsx` |
| `-o, --output` | Output file | `../data/output/keyword_matched_results.xlsx` |
| `-c, --text-column` | Text column name | `文本` |
| `-m, --modes` | Matching modes | `cas exact` |
| `--separator` | Keyword separator | `\|\|\|` |

### Output

Each mode produces its own file:
- `keyword_matched_results_cas.xlsx` - CAS number matches
- `keyword_matched_results_exact.xlsx` - exact matches

Output columns:
- `匹配到的关键词` - list of matched keywords
- `匹配模式` - matching mode used
---

## Feature 2: Batch Image Recognition

Calls an LLM API to recognize text, objects, and risk signals in images.

### Basic Usage

```bash
cd scripts/

# Mock mode (no API needed; for testing)
python3 image_batch_recognizer.py --mock --limit 5

# With the OpenAI API
python3 image_batch_recognizer.py --api-type openai --limit 10

# With the DMX API
python3 image_batch_recognizer.py --api-type dmx --limit 10

# Parallel processing
python3 image_batch_recognizer.py --api-type openai --max-workers 3
```
### Options

| Option | Description | Default |
|--------|-------------|---------|
| `--image-dir` | Image directory | `../data/images` |
| `--output` | Output file | `../data/output/image_recognition_results.xlsx` |
| `--api-type` | API type | read from `.env` |
| `--model` | Model name | read from `.env` |
| `--limit` | Maximum number of images | unlimited |
| `--offset` | Skip the first N images | 0 |
| `--max-workers` | Parallel worker threads | 1 |
| `--mock` | Mock mode | off |
| `--recursive` | Recurse into subdirectories | off |

### API Configuration

Copy the example configuration file and fill it in:

```bash
cp config.env.example .env
```

Example `.env`:
```
OPENAI_API_KEY=sk-...
OPENAI_MODEL=gpt-4o-mini

DMX_API_KEY=sk-dmx-...
DMX_BASE_URL=https://www.dmxapi.cn
DMX_MODEL=gpt-5-mini

LLM_API_TYPE=openai
```
### Output

The output Excel contains the following columns:
- `detected_text` - recognized text
- `detected_objects` - object descriptions
- `sensitive_items` - sensitive elements
- `summary` - risk summary
- `confidence` - confidence level

---
## Directory Layout

```
20251126_s2/
├── scripts/
│   ├── keyword_matcher.py            # keyword matching
│   ├── image_batch_recognizer.py     # image recognition
│   ├── run.sh                        # batch-job management
│   └── run_batch_background.sh       # background runner
├── data/
│   ├── input/                        # input data
│   │   ├── clickin_text_img.xlsx     # text data
│   │   └── keywords.xlsx             # keyword database
│   ├── output/                       # output results
│   └── images/                       # image files
├── .env                              # API configuration
└── config.env.example                # configuration template
```

---

## Batch-Job Management

Manage background jobs with `run.sh`:

```bash
cd scripts/

./run.sh start    # start the job
./run.sh stop     # stop the job
./run.sh status   # show status
./run.sh log      # tail the log
```

Override settings via environment variables:
```bash
API_TYPE=openai MAX_WORKERS=3 ./run.sh start
```

---
## Installing Dependencies

```bash
# Required
pip install pandas openpyxl

# Optional (performance / convenience)
pip install pyahocorasick  # speeds up keyword matching
pip install tqdm           # progress bars
pip install requests       # HTTP requests
```

---

## FAQ

**Q: How do I speed up matching?**
Install `pyahocorasick`; exact mode then automatically uses the Aho-Corasick algorithm.

**Q: Can I test without an API key?**
Run in mock mode with the `--mock` flag.

**Q: Can the output separator be changed?**
Use the `--separator` option; the default `|||` does not clash with chemical nomenclature.

---

## Support

- Python 3.7+
- Per-script help: `python3 script.py -h`
config.env.example (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/bin/bash
# Example API configuration
# Copy this file to .env and fill in your API keys

# OpenAI API key
# Obtain at: https://platform.openai.com/api-keys
export OPENAI_API_KEY="sk-your-openai-api-key-here"
export OPENAI_BASE_URL="https://api.openai.com"
export OPENAI_MODEL="gpt-4o-mini"  # switch according to your access

# Anthropic Claude API key
# Obtain at: https://console.anthropic.com/
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-api-key-here"
export ANTHROPIC_BASE_URL="https://api.anthropic.com"
export ANTHROPIC_MODEL="claude-3-5-sonnet-20241022"

# API selection (one of: openai, anthropic, dmx, mock)
export LLM_API_TYPE="openai"

# DMX API (https://www.dmxapi.cn) - image-analysis upload variant
export DMX_API_KEY="sk-dmx-your-api-key"
export DMX_BASE_URL="https://www.dmxapi.cn"
export DMX_MODEL="gpt-5-mini"

# Dify API configuration
export DIFY_API_KEY="app-your-actual-api-key-here"
export DIFY_BASE_URL="https://your-dify-instance.com"  # replace with your actual Dify instance URL
export DIFY_USER_ID="default-user"  # optional; defaults to "default-user"
export DIFY_MODEL="dify-chatflow"   # optional; recorded only, does not affect calls
scripts/image_batch_recognizer.py (new file, 776 lines)
@@ -0,0 +1,776 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch image recognition script.

Features:
1. Walk the image files in a given directory
2. Upload each image to an LLM API (OpenAI, Anthropic, or a local mock mode)
3. Write an Excel report with each image's text, object descriptions, and risk summary

Example:
    cd scripts
    python3 image_batch_recognizer.py --limit 5 --api-type openai
"""
from __future__ import annotations

import argparse
import base64
import json
import mimetypes
import os
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pandas as pd

try:
    import requests  # type: ignore
    # Silence SSL warnings (for self-signed certificates)
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
except ImportError:  # pragma: no cover - requests is an optional dependency
    requests = None
try:
    from tqdm import tqdm  # type: ignore
except ImportError:
    # Minimal stand-in when tqdm is not installed
    class tqdm:
        def __init__(self, iterable=None, total=None, desc=None, **kwargs):
            self.iterable = iterable
            self.total = total or (len(iterable) if iterable else 0)
            self.desc = desc
            self.n = 0

        def __iter__(self):
            for item in self.iterable:
                yield item
                self.n += 1

        def update(self, n=1):
            self.n += n

        def set_postfix_str(self, s):
            # Simplified version: no-op
            pass

        def close(self):
            pass
SEPARATOR = "|||"
DEFAULT_SLEEP_SECONDS = 0.0
DEFAULT_TIMEOUT = 90
DEFAULT_MAX_TOKENS = 800
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}

# The analysis prompt stays in Chinese on purpose: it instructs the model to
# answer in Chinese and defines the JSON contract for the output columns.
DEFAULT_ANALYSIS_PROMPT = """
你是一名涉毒风控场景下的多模态分析师。请识别我提供的图片内容,并输出 JSON。
JSON 字段要求:
{
  "detected_text": ["按行列出的文字内容,若无文字返回空数组"],
  "detected_objects": ["出现的主要物品或场景要素,中文描述"],
  "sensitive_items": ["可疑化学品、药物、工具等,如无则为空数组"],
  "summary": "2-3 句话概括图片与潜在风险,中文",
  "confidence": "High | Medium | Low"
}
只返回 JSON,不要额外说明。
""".strip()
# ----------- Data structures -----------
@dataclass
class RecognitionResult:
    image_name: str
    image_path: str
    detected_text: List[str] = field(default_factory=list)
    detected_objects: List[str] = field(default_factory=list)
    sensitive_items: List[str] = field(default_factory=list)
    summary: str = ""
    confidence: str = ""
    raw_response: str = ""
    api_type: str = ""
    model_name: str = ""
    latency_seconds: float = 0.0
    analyzed_at: str = ""
    error: Optional[str] = None

    def to_row(self) -> Dict[str, Any]:
        """Convert to a row dict for the Excel output."""
        return {
            "image_name": self.image_name,
            "image_path": self.image_path,
            "detected_text": "\n".join(self.detected_text).strip(),
            "detected_objects": SEPARATOR.join(self.detected_objects),
            "sensitive_items": SEPARATOR.join(self.sensitive_items),
            "summary": self.summary,
            "confidence": self.confidence,
            "latency_seconds": round(self.latency_seconds, 2),
            "api_type": self.api_type,
            "model": self.model_name,
            "analyzed_at": self.analyzed_at,
            "error": self.error or "",
            "raw_response": self.raw_response,
        }

    @property
    def is_success(self) -> bool:
        return self.error is None
# ----------- Helpers -----------
def encode_image(image_path: Path) -> Tuple[str, str]:
    """Read an image as base64 and guess its MIME type."""
    data = image_path.read_bytes()
    encoded = base64.b64encode(data).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(image_path.name)
    return encoded, mime_type or "image/jpeg"
def extract_json_from_response(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Extract JSON from the model's response text.

    Try a direct parse first; on failure, retry on the substring between the
    first "{" and the last "}".
    """
    if not text:
        return None, "空响应"

    text = text.strip()
    try:
        return json.loads(text), None
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            candidate = text[start:end + 1]
            try:
                return json.loads(candidate), None
            except json.JSONDecodeError as exc:
                return None, f"JSON 解析失败: {exc}"
        return None, "响应中未找到 JSON 对象"
def normalize_list(value: Any) -> List[str]:
    """Coerce a field returned by the model into a list of strings."""
    if value is None:
        return []
    if isinstance(value, str):
        return [part.strip() for part in value.split(SEPARATOR) if part.strip()]
    if isinstance(value, (list, tuple, set)):
        return [str(item).strip() for item in value if str(item).strip()]
    return [str(value).strip()]
def http_post_json(url: str, headers: Dict[str, str], payload: Dict[str, Any], timeout: int) -> Tuple[int, str]:
    """Send a JSON POST request (prefer requests, fall back to urllib)."""
    body = json.dumps(payload)
    headers = {**headers, "Content-Type": "application/json"}

    if requests:
        # SSL certificate verification disabled (for self-signed certificates)
        response = requests.post(url, headers=headers, data=body, timeout=timeout, verify=False)
        return response.status_code, response.text

    # Fall back to urllib when requests is unavailable
    import urllib.error
    import urllib.request

    req = urllib.request.Request(url, data=body.encode("utf-8"), headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            response_body = resp.read().decode("utf-8")
            return resp.getcode(), response_body
    except urllib.error.HTTPError as exc:  # pragma: no cover - network error path
        return exc.code, exc.read().decode("utf-8")
def http_post_multipart(url: str, headers: Dict[str, str],
                        data: Dict[str, str], files: Dict[str, Any],
                        timeout: int) -> Tuple[int, str]:
    """Send a multipart/form-data POST request (used for Dify file uploads)."""
    if not requests:
        raise RuntimeError(
            "Dify 文件上传需要 requests 库,请安装: pip install requests"
        )

    # SSL certificate verification disabled (for self-signed certificates)
    response = requests.post(url, headers=headers, data=data,
                             files=files, timeout=timeout, verify=False)
    return response.status_code, response.text
def build_endpoint(base_url: Optional[str], path: str, default_base: str) -> str:
    """Build the full API URL from base_url and path."""
    base = (base_url or default_base).rstrip("/")
    path = path if path.startswith("/") else f"/{path}"
    return f"{base}{path}"
def load_env_file(env_file: Optional[str]) -> Optional[Path]:
    """Load environment variables from a .env file."""
    if not env_file:
        return None

    env_path = Path(env_file).expanduser()
    if not env_path.exists():
        return None

    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("export "):
            line = line[len("export "):].strip()
        if "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        if not key:
            continue
        os.environ[key] = value

    print(f"已从 {env_path} 加载环境变量")
    return env_path
# ----------- Base client -----------
class BaseVisionClient:
    def __init__(self, model: str, api_type: str, timeout: int = DEFAULT_TIMEOUT):
        self.model = model
        self.api_type = api_type
        self.timeout = timeout

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        encoded, mime_type = encode_image(image_path)
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            raw_text = self._send_request(
                image_base64=encoded,
                mime_type=mime_type,
                prompt=prompt,
                image_name=image_path.name,
            )
        except Exception as exc:  # pragma: no cover - network errors are hard to reproduce reliably
            latency = time.time() - start
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=latency,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)

        detected_text = normalize_list(parsed.get("detected_text")) if parsed else []
        detected_objects = normalize_list(parsed.get("detected_objects")) if parsed else []
        sensitive_items = normalize_list(parsed.get("sensitive_items")) if parsed else []
        summary = (parsed or {}).get("summary", "")
        confidence = (parsed or {}).get("confidence", "")

        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=detected_text,
            detected_objects=detected_objects,
            sensitive_items=sensitive_items,
            summary=str(summary),
            confidence=str(confidence),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        raise NotImplementedError
# ----------- OpenAI client -----------
class OpenAIVisionClient(BaseVisionClient):
    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="openai", timeout=timeout)
        if not api_key:
            raise ValueError("OPENAI_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://api.openai.com")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        image_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a multimodal analyst that returns structured JSON responses in Chinese.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                },
            ],
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)

        if status_code >= 400:
            raise RuntimeError(f"OpenAI API 调用失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        content = response_json["choices"][0]["message"]["content"]
        if isinstance(content, list):
            text = "\n".join(part.get("text", "") for part in content if isinstance(part, dict))
        else:
            text = content
        return text.strip()
# ----------- DMX (document-upload variant) client -----------
class DMXVisionClient(BaseVisionClient):
    def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
        super().__init__(model=model, api_type="dmx", timeout=timeout)
        if not api_key:
            raise ValueError("DMX_API_KEY 未设置")
        self.api_key = api_key
        self.endpoint = build_endpoint(base_url, "/v1/chat/completions", "https://www.dmxapi.cn")

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        image_url = f"data:{mime_type};base64,{image_base64}"
        payload = {
            "model": self.model,
            "temperature": 0.1,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                }
            ],
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)
        if status_code >= 400:
            raise RuntimeError(f"DMX API 调用失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        content = response_json["choices"][0]["message"]["content"]
        if isinstance(content, list):
            text = "\n".join(part.get("text", "") for part in content if isinstance(part, dict))
        else:
            text = content
        return text.strip()
# ----------- Anthropic 客户端 -----------
|
||||
class AnthropicVisionClient(BaseVisionClient):
|
||||
def __init__(self, api_key: str, model: str, timeout: int, base_url: Optional[str] = None):
|
||||
super().__init__(model=model, api_type="anthropic", timeout=timeout)
|
||||
if not api_key:
|
||||
raise ValueError("ANTHROPIC_API_KEY 未设置")
|
||||
self.api_key = api_key
|
||||
self.endpoint = build_endpoint(base_url, "/v1/messages", "https://api.anthropic.com")
|
||||
|
||||
def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": DEFAULT_MAX_TOKENS,
|
||||
"system": "You are a multimodal analyst that responds with JSON only.",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": f"{prompt}\n图片文件名: {image_name}"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": image_base64,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
status_code, response_text = http_post_json(self.endpoint, headers, payload, self.timeout)
|
||||
if status_code >= 400:
|
||||
raise RuntimeError(f"Anthropic API 调用失败 ({status_code}): {response_text}")
|
||||
|
||||
response_json = json.loads(response_text)
|
||||
text_parts = [part.get("text", "") for part in response_json.get("content", []) if part.get("type") == "text"]
|
||||
return "\n".join(text_parts).strip()


# ----------- 模拟客户端(无 API 时使用) -----------
class MockVisionClient(BaseVisionClient):
    def __init__(self):
        super().__init__(model="mock", api_type="mock", timeout=1)

    def _send_request(self, image_base64: str, mime_type: str, prompt: str, image_name: str) -> str:
        del image_base64, mime_type, prompt  # 未使用
        fake_json = {
            "detected_text": ["(模拟模式)未调用真实 API"],
            "detected_objects": ["sample-object-from-filename"],
            "sensitive_items": [],
            "summary": f"Mock 模式:演示 {image_name} 的返回格式。",
            "confidence": "Low",
        }
        return json.dumps(fake_json, ensure_ascii=False)


# ----------- Dify Chatflow 客户端 -----------
class DifyVisionClient(BaseVisionClient):
    """Dify Chatflow API client for vision analysis"""

    def __init__(self, api_key: str, model: str, timeout: int,
                 base_url: Optional[str] = None, user_id: Optional[str] = None):
        super().__init__(model=model, api_type="dify", timeout=timeout)
        if not api_key:
            raise ValueError("DIFY_API_KEY 未设置")
        self.api_key = api_key
        self.base_url = (base_url or "https://dify.example.com").rstrip("/")
        self.user_id = user_id or "default-user"

    def _upload_file(self, image_path: Path) -> str:
        """上传图片到 Dify,返回文件ID (UUID)"""
        url = f"{self.base_url}/v1/files/upload"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        data = {"user": self.user_id}

        # 使用 multipart/form-data 上传
        mime_type, _ = mimetypes.guess_type(image_path.name)
        with open(image_path, "rb") as f:
            files = {"file": (image_path.name, f, mime_type or "image/jpeg")}
            status_code, response_text = http_post_multipart(
                url, headers, data, files, self.timeout
            )

        if status_code >= 400:
            raise RuntimeError(f"Dify 文件上传失败 ({status_code}): {response_text}")

        response_json = json.loads(response_text)
        return response_json["id"]  # 返回文件UUID

    def analyze_image(self, image_path: Path, prompt: str) -> RecognitionResult:
        """重写完整流程:上传文件 → 发送工作流请求"""
        start = time.time()
        analyzed_at = datetime.now().isoformat()

        try:
            # Step 1: 上传图片
            file_id = self._upload_file(image_path)

            # Step 2: 发送工作流请求(使用 /v1/workflows/run)
            url = f"{self.base_url}/v1/workflows/run"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # 构建 inputs
            inputs = {
                "prompt01": prompt,  # 段落类型,无字符限制
                "pho01": {  # Workflow 需要的图片参数
                    "type": "image",
                    "transfer_method": "local_file",
                    "upload_file_id": file_id
                }
            }

            payload = {
                "inputs": inputs,
                "response_mode": "blocking",
                "user": self.user_id
            }

            status_code, response_text = http_post_json(
                url, headers, payload, self.timeout
            )

            if status_code >= 400:
                raise RuntimeError(f"Dify 工作流失败 ({status_code}): {response_text}")

            # Step 3: 提取返回的 JSON 内容(Workflow 返回在 data.outputs.text)
            response_json = json.loads(response_text)
            # Workflow API 返回格式: {"data": {"outputs": {"text": "..."}}}
            raw_text = response_json.get("data", {}).get("outputs", {}).get("text", "")

        except Exception as exc:
            latency = time.time() - start
            return RecognitionResult(
                image_name=image_path.name,
                image_path=str(image_path),
                api_type=self.api_type,
                model_name=self.model,
                latency_seconds=latency,
                analyzed_at=analyzed_at,
                raw_response="",
                error=str(exc),
            )

        # Step 4: 解析 JSON 结果(复用现有逻辑)
        latency = time.time() - start
        parsed, parse_error = extract_json_from_response(raw_text)

        detected_text = normalize_list(parsed.get("detected_text")) if parsed else []
        detected_objects = normalize_list(parsed.get("detected_objects")) if parsed else []
        sensitive_items = normalize_list(parsed.get("sensitive_items")) if parsed else []
        summary = (parsed or {}).get("summary", "")
        confidence = (parsed or {}).get("confidence", "")

        return RecognitionResult(
            image_name=image_path.name,
            image_path=str(image_path),
            detected_text=detected_text,
            detected_objects=detected_objects,
            sensitive_items=sensitive_items,
            summary=str(summary),
            confidence=str(confidence),
            raw_response=raw_text,
            api_type=self.api_type,
            model_name=self.model,
            latency_seconds=latency,
            analyzed_at=analyzed_at,
            error=parse_error,
        )


# ----------- 主执行逻辑 -----------
def list_image_files(image_dir: Path, recursive: bool = False) -> List[Path]:
    """列出待处理的图片文件"""
    if recursive:
        candidates = (p for p in image_dir.rglob("*") if p.is_file())
    else:
        candidates = (p for p in image_dir.iterdir() if p.is_file())

    files = [p for p in candidates if p.suffix.lower() in SUPPORTED_EXTENSIONS]
    return sorted(files)


def create_client(args: argparse.Namespace) -> BaseVisionClient:
    """根据用户配置创建合适的客户端"""
    api_type = (args.api_type or os.getenv("LLM_API_TYPE") or "openai").lower()

    if api_type == "mock" or args.mock:
        return MockVisionClient()

    if api_type == "openai":
        model = args.model or os.getenv("OPENAI_MODEL") or "gpt-4o-mini"
        api_key = os.getenv("OPENAI_API_KEY", "")
        base_url = os.getenv("OPENAI_BASE_URL")
        return OpenAIVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "dmx":
        model = args.model or os.getenv("DMX_MODEL") or "gpt-5-mini"
        api_key = os.getenv("DMX_API_KEY", "")
        base_url = os.getenv("DMX_BASE_URL")
        return DMXVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "anthropic":
        model = args.model or os.getenv("ANTHROPIC_MODEL") or "claude-3-5-sonnet-20241022"
        api_key = os.getenv("ANTHROPIC_API_KEY", "")
        base_url = os.getenv("ANTHROPIC_BASE_URL")
        return AnthropicVisionClient(api_key=api_key, model=model, timeout=args.timeout, base_url=base_url)

    if api_type == "dify":
        model = args.model or os.getenv("DIFY_MODEL") or "dify-chatflow"
        api_key = os.getenv("DIFY_API_KEY", "")
        base_url = os.getenv("DIFY_BASE_URL")
        user_id = os.getenv("DIFY_USER_ID") or "default-user"
        return DifyVisionClient(api_key=api_key, model=model,
                                timeout=args.timeout, base_url=base_url,
                                user_id=user_id)

    raise ValueError(f"不支持的 api_type: {api_type}")


def load_prompt(prompt_file: Optional[str]) -> str:
    if prompt_file:
        return Path(prompt_file).read_text(encoding="utf-8").strip()
    env_prompt = os.getenv("VISION_ANALYSIS_PROMPT")
    return (env_prompt or DEFAULT_ANALYSIS_PROMPT).strip()


def run_batch_recognition(args: argparse.Namespace) -> List[RecognitionResult]:
    image_dir = Path(args.image_dir).resolve()
    if not image_dir.exists():
        raise FileNotFoundError(f"图片目录不存在: {image_dir}")

    output_path = Path(args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    prompt = load_prompt(args.prompt_file)
    client = create_client(args)

    images = list_image_files(image_dir, recursive=args.recursive)
    if args.offset:
        images = images[args.offset :]
    if args.limit:
        images = images[: args.limit]

    if not images:
        print("未找到任何图片文件。")
        return []

    max_workers = max(1, args.max_workers)
    debug = args.debug

    # 简化的输出头部
    if not debug:
        print(f"- 开始处理 {len(images)} 张图片 | API: {client.api_type} ({client.model})")
    else:
        print("=" * 60)
        print(f"图片目录: {image_dir}")
        print(f"输出文件: {output_path}")
        print(f"匹配到 {len(images)} 张图片,将使用 {client.api_type} ({client.model}) 执行识别")
        print(f"并发 worker 数: {max_workers}")
        print("=" * 60)

    results: List[RecognitionResult] = []
    total_images = len(images)

    if max_workers == 1:
        # 串行处理 - 使用进度条
        pbar = tqdm(images, desc="处理进度", unit="图", disable=debug)
        for idx, image_path in enumerate(pbar, 1):
            if debug:
                print(f"[{idx}/{total_images}] 识别 {image_path.name} ...")

            result = client.analyze_image(image_path, prompt)
            results.append(result)

            if debug:
                if result.is_success:
                    print(
                        f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                        f"耗时 {result.latency_seconds:.1f}s"
                    )
                else:
                    print(f" ✗ 失败: {result.error}")
            else:
                # 精简输出:更新进度条描述
                status = "✓" if result.is_success else "✗"
                pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")

            if args.sleep > 0:
                time.sleep(args.sleep)
        pbar.close()
    else:
        if args.sleep > 0:
            print("提示: 并发模式下忽略 --sleep 节流参数。")

        result_lookup: Dict[Path, RecognitionResult] = {}

        # 并发处理 - 使用进度条
        pbar = tqdm(total=total_images, desc="处理进度", unit="图", disable=debug)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(client.analyze_image, image_path, prompt): image_path
                for image_path in images
            }

            for completed_idx, future in enumerate(as_completed(future_map), 1):
                image_path = future_map[future]
                result = future.result()
                result_lookup[image_path] = result

                if debug:
                    print(f"[{completed_idx}/{total_images}] 识别 {image_path.name} ...")
                    if result.is_success:
                        print(
                            f" ✓ 成功,文字 {len(result.detected_text)} 行,物品 {len(result.detected_objects)} 项,"
                            f"耗时 {result.latency_seconds:.1f}s"
                        )
                    else:
                        print(f" ✗ 失败: {result.error}")
                else:
                    # 精简输出:更新进度条
                    status = "✓" if result.is_success else "✗"
                    pbar.set_postfix_str(f"{status} {image_path.name[:20]}...")
                pbar.update(1)

        pbar.close()
        results = [result_lookup[path] for path in images]

    save_results(results, output_path)
    return results


def save_results(results: Sequence[RecognitionResult], output_path: Path) -> None:
    if not results:
        print("没有可保存的结果。")
        return

    df = pd.DataFrame([res.to_row() for res in results])
    df.to_excel(output_path, index=False, engine="openpyxl")
    success_count = sum(1 for res in results if res.is_success)
    print(f"\n识别完成,成功 {success_count}/{len(results)},结果已写入 {output_path}")


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="批量图片多模态识别(文字 + 物品)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    base_dir = Path(__file__).resolve().parent.parent
    default_image_dir = base_dir / "data" / "images"
    default_output = base_dir / "data" / "output" / "image_recognition_results.xlsx"
    default_env_file = base_dir / ".env"

    parser.add_argument("--image-dir", type=str, default=str(default_image_dir), help="图片目录")
    parser.add_argument("--output", type=str, default=str(default_output), help="识别结果输出文件(Excel)")
    parser.add_argument("--api-type", choices=["openai", "anthropic", "dmx", "dify", "mock"], help="选择使用的 API 类型")
    parser.add_argument("--model", type=str, help="指定模型名称(默认读取环境变量)")
    parser.add_argument("--prompt-file", type=str, help="自定义提示词文件")
    parser.add_argument("--recursive", action="store_true", help="递归搜索子目录")
    parser.add_argument("--limit", type=int, help="限制最大处理图片数")
    parser.add_argument("--offset", type=int, default=0, help="跳过前 N 张图片")
    parser.add_argument("--sleep", type=float, default=DEFAULT_SLEEP_SECONDS, help="每次请求后的等待秒数")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="HTTP 请求超时")
    parser.add_argument("--mock", action="store_true", help="强制使用模拟模式(无需 API)")
    parser.add_argument("--env-file", type=str, default=str(default_env_file), help="包含 API Key / Base URL / Model 配置的 .env 文件")
    parser.add_argument("--max-workers", type=int, default=1, help="并行识别线程数")
    parser.add_argument("--debug", action="store_true", help="启用详细输出(默认为精简模式)")

    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> None:
    args = parse_args(argv)
    load_env_file(args.env_file)
    try:
        run_batch_recognition(args)
    except Exception as exc:
        print(f"执行失败: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()
778
scripts/keyword_matcher.py
Normal file
@@ -0,0 +1,778 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多模式关键词匹配工具(重构版)
- CAS号识别:专注于 `CAS号` 列,支持多种格式(-, 空格, 无分隔符等)
- 精确匹配:对 中文名/英文名/CAS号/简称/可能名称 等列做词边界精确匹配
"""

import argparse
import re
import sys
import time
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Optional, Tuple

import pandas as pd

# 可选依赖
try:
    import ahocorasick
    HAS_AC = True
except ImportError:
    HAS_AC = False


# ========== 常量定义 ==========
SEPARATOR = "|||"
MATCH_RESULT_SEPARATOR = " | "
PROGRESS_INTERVAL = 1000
DEFAULT_FUZZY_THRESHOLD = 85

# CAS号正则表达式
# 匹配格式:2-7位数字 + 分隔符(可选) + 2位数字 + 分隔符(可选) + 1位数字
# 支持分隔符:- (连字符), 空格, . (点), _ (下划线), 或无分隔符
CAS_REGEX_PATTERN = r'\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b'

MODE_KEYWORD_COLUMNS: Dict[str, List[str]] = {
    "cas": ["CAS号"],
    "exact": ["中文名", "英文名", "CAS号", "简称", "可能名称"],
}

MODE_LABELS = {
    "cas": "CAS号识别",
    "exact": "精确匹配",
}


# ========== 数据类 ==========
@dataclass
class MatchResult:
    """匹配结果数据类"""
    matched_indices: List[int]
    matched_keywords: List[str]
    elapsed_time: float
    total_rows: int
    matcher_name: str

    @property
    def match_count(self) -> int:
        return len(self.matched_indices)

    @property
    def match_rate(self) -> float:
        return (self.match_count / self.total_rows * 100) if self.total_rows > 0 else 0.0

    @property
    def speed(self) -> float:
        return (self.total_rows / self.elapsed_time) if self.elapsed_time > 0 else 0.0


# ========== 工具函数 ==========
def normalize_cas(cas_str: str) -> str:
    """
    将各种格式的CAS号规范化为标准格式 XXX-XX-X

    支持的输入格式:
    - 123-45-6 (标准格式)
    - 123 45 6 (空格分隔)
    - 123.45.6 (点分隔)
    - 123_45_6 (下划线分隔)
    - 12345 6 或 1234 56 (部分分隔)
    - 123456 (无分隔符,仅当总长度正确时)

    返回:标准格式的CAS号,如果无法解析则返回原字符串
    """
    if not cas_str or not isinstance(cas_str, str):
        return str(cas_str)

    # 移除所有非数字字符,只保留数字
    digits_only = re.sub(r'[^\d]', '', cas_str)

    # CAS号至少需要5位数字(最短格式:XX-XX-X)
    if len(digits_only) < 5:
        return cas_str

    # 重新格式化为标准格式:前n-3位-中间2位-最后1位
    # 例如:123456 -> 123-45-6, 12345 -> 12-34-5
    return f"{digits_only[:-3]}-{digits_only[-3:-1]}-{digits_only[-1]}"
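
For reference, the slicing rule above collapses every supported separator style into the same canonical form. A minimal standalone check (the function body is copied verbatim from this file; the sample inputs are illustrative):

```python
import re

def normalize_cas(cas_str: str) -> str:
    # 与上文实现一致:剥离非数字字符后重组为 前n-3位-中间2位-最后1位
    if not cas_str or not isinstance(cas_str, str):
        return str(cas_str)
    digits_only = re.sub(r'[^\d]', '', cas_str)
    if len(digits_only) < 5:
        return cas_str
    return f"{digits_only[:-3]}-{digits_only[-3:-1]}-{digits_only[-1]}"

# 多种写法收敛到同一标准形式
print(normalize_cas("7732-18-5"))   # 7732-18-5
print(normalize_cas("7732 18 5"))   # 7732-18-5
print(normalize_cas("7732185"))     # 7732-18-5
print(normalize_cas("64.17.5"))     # 64-17-5
print(normalize_cas("123"))         # 123(不足 5 位,原样返回)
```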


def extract_cas_numbers(text: str, pattern: str = CAS_REGEX_PATTERN) -> Set[str]:
    """
    从文本中提取所有CAS号并规范化

    参数:
        text: 待搜索的文本
        pattern: CAS号正则表达式

    返回:规范化后的CAS号集合
    """
    if not text:
        return set()

    matches = re.finditer(pattern, str(text))
    cas_numbers = set()

    for match in matches:
        # 提取完整匹配
        raw_cas = match.group(0)
        # 规范化
        normalized = normalize_cas(raw_cas)
        cas_numbers.add(normalized)

    return cas_numbers
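
The extraction stage hinges on `CAS_REGEX_PATTERN` tolerating mixed separators while `\b` keeps it from firing inside longer alphanumeric tokens. A standalone check of the raw regex stage (pattern copied from above; the sample text is made up):

```python
import re

CAS_REGEX_PATTERN = r'\b(\d{2,7})[\s\-._]?(\d{2})[\s\-._]?(\d)\b'

text = "样品含乙醇 64-17-5、水 7732 18 5 以及批号 A123。"
raw_hits = [m.group(0) for m in re.finditer(CAS_REGEX_PATTERN, text)]
print(raw_hits)  # ['64-17-5', '7732 18 5'],"A123" 因缺少词边界不会误匹配
```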


def split_value(value: str, separator: str) -> List[str]:
    """将单元格内容拆分为多个候选关键词"""
    if separator and separator in value:
        parts = value.split(separator)
    else:
        parts = [value]
    return [part.strip() for part in parts if part and part.strip()]


def load_keywords_for_mode(
    df: pd.DataFrame,
    mode: str,
    separator: str = SEPARATOR
) -> Set[str]:
    """根据模式加载关键词集合"""
    mode_lower = mode.lower()
    if mode_lower not in MODE_KEYWORD_COLUMNS:
        raise ValueError(f"不支持的模式: {mode}")

    target_columns = MODE_KEYWORD_COLUMNS[mode_lower]
    available_columns = [col for col in target_columns if col in df.columns]
    missing_columns = [col for col in target_columns if col not in df.columns]

    if not available_columns:
        raise ValueError(
            f"模式 '{mode_lower}' 需要的列 {target_columns} 均不存在,"
            f"当前可用列: {df.columns.tolist()}"
        )

    if missing_columns:
        print(f"警告: 以下列在关键词文件中缺失: {missing_columns}")

    keywords: Set[str] = set()
    for column in available_columns:
        for value in df[column].dropna():
            value_str = str(value).strip()
            if not value_str or value_str in ['#N/A#', 'nan', 'None']:
                continue
            for token in split_value(value_str, separator):
                # CAS模式下规范化CAS号
                if mode_lower == "cas":
                    token = normalize_cas(token)
                keywords.add(token)

    label = MODE_LABELS.get(mode_lower, mode_lower)
    print(f"模式「{label}」共加载 {len(keywords)} 个候选关键词,来源列: {available_columns}")

    if not keywords:
        print(f"警告: 模式 '{mode_lower}' 未加载到任何关键词!")

    return keywords


def show_progress(
    current: int,
    total: int,
    start_time: float,
    matched_count: int,
    interval: int = PROGRESS_INTERVAL
) -> None:
    """显示处理进度"""
    if (current + 1) % interval == 0:
        elapsed = time.time() - start_time
        speed = (current + 1) / elapsed
        print(f"已处理 {current + 1}/{total} 行,速度: {speed:.1f} 行/秒,匹配到 {matched_count} 行")


# ========== 匹配器基类(策略模式) ==========
class KeywordMatcher(ABC):
    """关键词匹配器抽象基类"""

    def __init__(self, name: str):
        self.name = name

    def match(
        self,
        df: pd.DataFrame,
        keywords: Set[str],
        text_column: str
    ) -> MatchResult:
        """执行匹配(模板方法)"""
        print(f"开始匹配(使用{self.name})...")
        self._prepare(keywords)

        matched_indices = []
        matched_keywords_list = []
        start_time = time.time()

        for idx, text in enumerate(df[text_column]):
            if pd.isna(text):
                continue

            text_str = str(text)
            matches = self._match_single_text(text_str, keywords)

            if matches:
                matched_indices.append(idx)
                formatted = self._format_matches(matches)
                matched_keywords_list.append(formatted)

            show_progress(idx, len(df), start_time, len(matched_indices))

        elapsed = time.time() - start_time
        return MatchResult(
            matched_indices=matched_indices,
            matched_keywords=matched_keywords_list,
            elapsed_time=elapsed,
            total_rows=len(df),
            matcher_name=self.name
        )

    def _prepare(self, keywords: Set[str]) -> None:
        """预处理(子类可选实现)"""
        pass

    @abstractmethod
    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """匹配单条文本(子类必须实现)"""
        pass

    def _format_matches(self, matches: Set[str]) -> str:
        """格式化匹配结果(子类可重写)"""
        return MATCH_RESULT_SEPARATOR.join(sorted(matches))
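
The `match` template method fixes the scan loop while subclasses supply only `_match_single_text`. A stripped-down sketch of the same pattern, without pandas or progress reporting so it stays self-contained (the `Mini*` class names are illustrative, not part of this script):

```python
from abc import ABC, abstractmethod
from typing import List, Set

class MiniMatcher(ABC):
    """简化版模板方法:遍历文本由基类负责,匹配策略由子类提供"""

    def match(self, texts: List[str], keywords: Set[str]) -> List[int]:
        hits = []
        for idx, text in enumerate(texts):
            if self._match_single_text(text, keywords):
                hits.append(idx)
        return hits

    @abstractmethod
    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        ...

class MiniSubstringMatcher(MiniMatcher):
    def _match_single_text(self, text, keywords):
        # 与 SetMatcher 相同的子串包含判断
        return {kw for kw in keywords if kw in text}

m = MiniSubstringMatcher()
print(m.match(["含苯的样品", "纯净水"], {"苯"}))  # [0]
```

Swapping in another subclass changes the matching strategy without touching the loop, statistics, or output code — which is the point of routing everything through the base class here.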


# ========== 具体匹配器实现 ==========
class AhoCorasickMatcher(KeywordMatcher):
    """Aho-Corasick 自动机匹配器"""

    def __init__(self):
        super().__init__("Aho-Corasick 自动机")
        self.automaton = None

    def _prepare(self, keywords: Set[str]) -> None:
        """构建自动机"""
        if not HAS_AC:
            raise RuntimeError("pyahocorasick 未安装")

        print("正在构建Aho-Corasick自动机...")
        self.automaton = ahocorasick.Automaton()

        for keyword in keywords:
            self.automaton.add_word(keyword, keyword)

        self.automaton.make_automaton()
        print("自动机构建完成")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """使用自动机匹配"""
        matched = set()
        for end_index, keyword in self.automaton.iter(text):
            matched.add(keyword)
        return matched


class SetMatcher(KeywordMatcher):
    """标准集合匹配器"""

    def __init__(self):
        super().__init__("标准集合匹配")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """逐个关键词做子串包含判断"""
        return {kw for kw in keywords if kw in text}


class CASRegexMatcher(KeywordMatcher):
    """CAS号正则表达式匹配器(支持多种格式)"""

    def __init__(self):
        super().__init__("CAS号正则匹配(支持多种格式)")
        self.pattern = re.compile(CAS_REGEX_PATTERN)

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """
        使用正则表达式提取CAS号,规范化后与关键词库比对

        流程:
        1. 用正则提取文本中所有可能的CAS号
        2. 将提取的CAS号规范化为标准格式
        3. 与关键词库(已规范化)进行比对
        """
        # 从文本中提取并规范化CAS号
        found_cas_numbers = extract_cas_numbers(text, self.pattern)

        # 与关键词库求交集
        matched = found_cas_numbers & keywords

        return matched


class RegexExactMatcher(KeywordMatcher):
    """正则表达式精确匹配器(支持词边界)"""

    def __init__(self):
        super().__init__("正则表达式精确匹配(词边界)")
        self.pattern = None

    def _prepare(self, keywords: Set[str]) -> None:
        """构建正则表达式模式"""
        print("正在构建正则表达式模式...")

        # 转义所有特殊字符,并使用词边界 \b 确保完整词匹配
        escaped_keywords = [re.escape(kw) for kw in keywords]
        if not escaped_keywords:
            # 空关键词集合会构建出 \b()\b,在任意词边界上误报,这里直接置空跳过
            self.pattern = None
            return

        # 构建正则模式:\b(keyword1|keyword2|...)\b
        # 词边界确保不会匹配到部分词
        pattern_str = r'\b(' + '|'.join(escaped_keywords) + r')\b'
        self.pattern = re.compile(pattern_str)

        print(f"正则模式构建完成,共 {len(keywords)} 个关键词")

    def _match_single_text(self, text: str, keywords: Set[str]) -> Set[str]:
        """使用正则表达式精确匹配"""
        if not self.pattern:
            return set()

        # 查找所有匹配项
        matches = self.pattern.findall(text)
        return set(matches)
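
Word boundaries are what separate this matcher from plain substring search. One caveat worth noting: `\b` requires a `\w`/non-`\w` transition, and consecutive Chinese characters all count as `\w` in Python's `re`, so a Chinese keyword embedded between other Chinese characters will not match — relevant when the 中文名 column feeds this matcher. A small standalone demonstration (keywords are illustrative):

```python
import re

keywords = ["acetone", "TNT"]
pattern = re.compile(r'\b(' + '|'.join(re.escape(kw) for kw in keywords) + r')\b')

print(pattern.findall("acetone in stock"))       # ['acetone']
print(pattern.findall("triacetonetriperoxide"))  # []:词边界阻止部分匹配
print(pattern.findall("TNT-equivalent"))         # ['TNT']:连字符是非词字符,仍构成边界
print(re.compile(r'\b苯\b').search("含苯"))       # None:相邻汉字之间没有 \b 边界
```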


# ========== 匹配器工厂 ==========
def create_matcher(algorithm: str, fuzzy_threshold: int = DEFAULT_FUZZY_THRESHOLD, mode: Optional[str] = None) -> KeywordMatcher:
    """
    根据算法类型和模式创建匹配器

    参数:
        algorithm: 匹配算法 (auto, set, exact)
        fuzzy_threshold: 已废弃,保留仅为向后兼容
        mode: 检测模式 (cas, exact),用于选择特定匹配器
    """
    algorithm_lower = algorithm.lower()

    # CAS模式使用CAS正则匹配器
    if mode and mode.lower() == "cas":
        return CASRegexMatcher()

    # exact模式或exact算法使用正则精确匹配器
    if mode and mode.lower() == "exact":
        return RegexExactMatcher()

    if algorithm_lower == "exact":
        return RegexExactMatcher()

    elif algorithm_lower == "set":
        return SetMatcher()

    elif algorithm_lower == "auto":
        if HAS_AC:
            return AhoCorasickMatcher()
        else:
            print("警告: 未安装 pyahocorasick,使用标准匹配方法")
            return SetMatcher()

    else:
        raise ValueError(f"不支持的算法: {algorithm}")


# ========== 结果处理 ==========
def save_results(
    df: pd.DataFrame,
    result: MatchResult,
    output_file: str
) -> Optional[pd.DataFrame]:
    """保存匹配结果到 Excel"""
    if result.match_count == 0:
        print("未找到任何匹配")
        return None

    result_df = df.iloc[result.matched_indices].copy()
    result_df.insert(0, "匹配到的关键词", result.matched_keywords)

    result_df.to_excel(output_file, index=False, engine='openpyxl')
    print(f"结果已保存到: {output_file}")
    return result_df


def print_statistics(result: MatchResult) -> None:
    """打印匹配统计信息"""
    print(f"\n{'='*60}")
    print("匹配完成!")
    print(f"匹配模式: {result.matcher_name}")
    print(f"总耗时: {result.elapsed_time:.2f} 秒")
    print(f"处理速度: {result.speed:.1f} 行/秒")
    print(f"匹配到 {result.match_count} 行数据 ({result.match_rate:.2f}%)")
    print(f"{'='*60}\n")


def preview_results(result_df: pd.DataFrame, num_rows: int = 5) -> None:
    """预览匹配结果"""
    if result_df.empty:
        return

    print(f"前{num_rows}行匹配结果预览:")
    print("=" * 80)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', None)
    pd.set_option('display.max_colwidth', 100)
    print(result_df.head(num_rows))
    print("=" * 80)
    print(f"\n✓ 总共匹配到 {len(result_df)} 行数据")


# ========== 主流程 ==========
def perform_matching(
    df: pd.DataFrame,
    keywords: Set[str],
    text_column: str,
    output_file: str,
    algorithm: str = "auto",
    mode: Optional[str] = None
) -> Optional[pd.DataFrame]:
    """执行完整的匹配流程"""
    # 验证列存在
    if text_column not in df.columns:
        print(f"可用列名: {df.columns.tolist()}")
        raise ValueError(f"列 '{text_column}' 不存在")

    print(f"文本文件共有 {len(df)} 行数据\n")

    # 创建匹配器并执行匹配
    matcher = create_matcher(algorithm, mode=mode)
    result = matcher.match(df, keywords, text_column)

    # 输出统计信息
    print_statistics(result)

    # 保存结果
    result_df = save_results(df, result, output_file)

    return result_df


def process_single_mode(
    keywords_df: pd.DataFrame,
    text_df: pd.DataFrame,
    mode: str,
    text_column: str,
    output_file: Path,
    separator: str = SEPARATOR,
    save_to_file: bool = True
) -> Optional[pd.DataFrame]:
    """
    处理单个检测模式

    返回:匹配结果 DataFrame(包含原始索引)
    """
    mode_lower = mode.lower()
    label = MODE_LABELS.get(mode_lower, mode_lower)

    print(f"\n{'='*60}")
    print(f">>> 正在执行识别模式: {label}")
    print(f"{'='*60}")

    # 加载关键词
    keywords = load_keywords_for_mode(keywords_df, mode_lower, separator=separator)

    # 显示关键词样例
    sample_keywords = list(keywords)[:10]
    if sample_keywords:
        print(f"\n{label} - 关键词样例(前10个):")
        for idx, kw in enumerate(sample_keywords, 1):
            print(f" {idx}. {kw}")
        print()

    # 选择算法
    algorithm = "exact" if mode_lower == "exact" else "auto"

    # 执行匹配(如果不需要保存到文件,传入临时路径)
    temp_output = str(output_file) if save_to_file else "/tmp/temp_match.xlsx"
    result_df = perform_matching(
        df=text_df,
        keywords=keywords,
        text_column=text_column,
        output_file=temp_output,
        algorithm=algorithm,
        mode=mode_lower  # 传递模式参数
    )

    # 如果不保存文件,删除临时文件
    if not save_to_file and result_df is not None:
        import os
        if os.path.exists(temp_output):
            os.remove(temp_output)

    # 预览结果
    if result_df is not None and save_to_file:
        preview_results(result_df)

    # 添加模式标识列
    if result_df is not None:
        result_df['匹配模式'] = label

    return result_df


def run_multiple_modes(
    keywords_file: Path,
    text_file: Path,
    output_file: Path,
    text_column: str,
    modes: List[str],
    separator: str = SEPARATOR
) -> None:
    """运行多个检测模式,合并结果到单一文件"""
    # 验证文件存在
    if not keywords_file.exists():
        raise FileNotFoundError(f"找不到关键词文件: {keywords_file}")
    if not text_file.exists():
        raise FileNotFoundError(f"找不到文本文件: {text_file}")

    # 加载数据
    print(f"正在加载关键词文件: {keywords_file}")
    keywords_df = pd.read_excel(keywords_file)
    print(f"可用列: {keywords_df.columns.tolist()}\n")

    print(f"正在加载文本文件: {text_file}")
    text_df = pd.read_excel(text_file)
    print(f"文本列: {text_column}\n")

    # 验证模式
    if not modes:
        raise ValueError("modes 不能为空,请至少指定一个模式")

    for mode in modes:
        if mode.lower() not in MODE_KEYWORD_COLUMNS:
            raise ValueError(f"不支持的识别模式: {mode}")

    # 收集所有模式的匹配结果
    all_results = []

    for mode in modes:
        mode_lower = mode.lower()

        # 处理该模式(不保存到单独文件)
        result_df = process_single_mode(
            keywords_df=keywords_df,
            text_df=text_df,
            mode=mode_lower,
            text_column=text_column,
            output_file=output_file,  # 这个参数在 save_to_file=False 时不使用
            separator=separator,
            save_to_file=False  # 不保存到单独文件
        )

        if result_df is not None and not result_df.empty:
            all_results.append(result_df)

    # 合并所有结果
    if not all_results:
        print("\n所有模式均未匹配到数据")
        return

    print(f"\n{'='*60}")
    print("正在合并所有模式的匹配结果...")
    print(f"{'='*60}")

    # 合并结果
    merged_df = merge_mode_results(all_results, text_df)

    # 保存合并后的结果
    merged_df.to_excel(output_file, index=False, engine='openpyxl')
    print(f"\n合并结果已保存到: {output_file}")
    print(f" 总匹配行数: {len(merged_df)} 行")

    # 统计每个模式的贡献
    print("\n各模式匹配统计:")
    for mode_result in all_results:
        mode_name = mode_result['匹配模式'].iloc[0]
        count = len(mode_result)
        print(f" {mode_name:20s}: {count:4d} 行")

    # 预览合并结果
    print(f"\n{'='*60}")
    print("合并结果预览(前5行):")
    print(f"{'='*60}")
    preview_results(merged_df, num_rows=5)
|
||||
|
||||
|

def merge_mode_results(
    results: List[pd.DataFrame],
    original_df: pd.DataFrame
) -> pd.DataFrame:
    """
    Merge the match results of multiple modes.

    Strategy:
    1. Merge rows by their original-data row index.
    2. When a row is matched by more than one mode, combine the keywords and mode labels.
    3. Keep every column of the original data.
    """
    if not results:
        return pd.DataFrame()

    # Match info keyed by original row index
    row_matches = {}

    for result_df in results:
        for idx, row in result_df.iterrows():
            if idx not in row_matches:
                # First time this row appears
                row_matches[idx] = {
                    'keywords': row['匹配到的关键词'],
                    'modes': [row['匹配模式']]
                }
            else:
                # Row already matched by another mode: merge keywords and labels
                existing_keywords = row_matches[idx]['keywords']
                new_keywords = row['匹配到的关键词']

                # Merge the keywords (deduplicated)
                all_keywords = set(str(existing_keywords).split(' | ')) | set(str(new_keywords).split(' | '))
                row_matches[idx]['keywords'] = ' | '.join(sorted(all_keywords))

                # Record the additional mode label
                row_matches[idx]['modes'].append(row['匹配模式'])

    # Build the final result
    final_indices = list(row_matches.keys())
    final_df = original_df.loc[final_indices].copy()

    # Insert the merged columns
    final_df.insert(0, '匹配到的关键词', [row_matches[idx]['keywords'] for idx in final_indices])
    final_df.insert(1, '匹配模式', [' + '.join(row_matches[idx]['modes']) for idx in final_indices])

    # Restore the original row order
    final_df = final_df.sort_index()

    return final_df


# ========== Command-line interface ==========
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='多模式关键词匹配工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 使用默认配置(两种模式)
  python keyword_matcher.py

  # 仅执行 CAS 号识别
  python keyword_matcher.py -m cas

  # 仅执行精确匹配
  python keyword_matcher.py -m exact

  # 指定自定义文件路径
  python keyword_matcher.py -k ../data/input/keywords.xlsx -t ../data/input/text.xlsx
"""
    )

    parser.add_argument(
        '-k', '--keywords',
        type=str,
        help='关键词文件路径 (默认: ../data/input/keywords.xlsx)'
    )

    parser.add_argument(
        '-t', '--text',
        type=str,
        help='文本文件路径 (默认: ../data/input/clickin_text_img.xlsx)'
    )

    parser.add_argument(
        '-o', '--output',
        type=str,
        help='输出文件路径 (默认: ../data/output/keyword_matched_results.xlsx)'
    )

    parser.add_argument(
        '-c', '--text-column',
        type=str,
        default='文本',
        help='文本列名 (默认: 文本)'
    )

    parser.add_argument(
        '-m', '--modes',
        nargs='+',
        choices=['cas', 'exact'],
        default=['cas', 'exact'],
        help='识别模式 (默认: cas exact)'
    )

    parser.add_argument(
        '--separator',
        type=str,
        default=SEPARATOR,
        help=f'关键词分隔符 (默认: {SEPARATOR})'
    )

    return parser.parse_args()


def main():
    """Entry point."""
    args = parse_args()

    # Resolve file paths
    base_dir = Path(__file__).resolve().parent

    keywords_file = Path(args.keywords) if args.keywords else (
        base_dir.parent / "data" / "input" / "keywords.xlsx"
    )

    text_file = Path(args.text) if args.text else (
        base_dir.parent / "data" / "input" / "clickin_text_img.xlsx"
    )

    output_file = Path(args.output) if args.output else (
        base_dir.parent / "data" / "output" / "keyword_matched_results.xlsx"
    )

    # Report optional-dependency status
    print("=" * 60)
    print("依赖库状态:")
    print(f"  pyahocorasick: {'已安装 ✓' if HAS_AC else '未安装 ✗'}")
    print("=" * 60)
    print()

    if not HAS_AC:
        print("提示: 安装 pyahocorasick 可获得 5x 性能提升: pip install pyahocorasick\n")

    try:
        run_multiple_modes(
            keywords_file=keywords_file,
            text_file=text_file,
            output_file=output_file,
            text_column=args.text_column,
            modes=args.modes,
            separator=args.separator
        )

        print("\n" + "=" * 60)
        print("✓ 所有模式处理完成!")
        print("=" * 60)

    except Exception as e:
        print(f"\n错误: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
154
scripts/run.sh
Normal file
@@ -0,0 +1,154 @@
#!/bin/bash
# Batch image-analysis management script - minimal edition

BASE_DIR="$HOME/bin/p_anlz"
BATCH_SCRIPT="$BASE_DIR/scripts/batch_run.sh"
LOG_DIR="$BASE_DIR/logs"
MAIN_LOG="$LOG_DIR/batch_run.log"

# Show help
show_help() {
    cat << EOF
批量图片分析管理工具

用法: $0 <命令>

命令:
  start    后台启动任务
  stop     停止任务
  status   查看运行状态
  log      查看实时日志
  help     显示帮助

示例:
  $0 start     # 启动任务
  $0 log       # 查看日志
  $0 status    # 查看状态

修改参数:
  API_TYPE=openai $0 start
  MAX_WORKERS=3 $0 start
  TIMEOUT=120 $0 start
EOF
}

# Start the task
start_task() {
    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务已在运行中"
        exit 1
    fi

    echo "启动批量分析任务..."

    # Export environment variables (with defaults)
    export API_TYPE="${API_TYPE:-dify}"
    export MAX_WORKERS="${MAX_WORKERS:-1}"
    export TIMEOUT="${TIMEOUT:-90}"

    # Run in the background
    nohup bash "$BATCH_SCRIPT" > /dev/null 2>&1 &

    sleep 2

    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已启动"
        echo ""
        echo "查看日志: $0 log"
        echo "查看状态: $0 status"
    else
        echo "✗ 启动失败"
        exit 1
    fi
}

# Stop the task
stop_task() {
    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "任务未运行"
        exit 0
    fi

    echo "停止任务..."
    pkill -f "batch_run.sh"
    pkill -f "image_batch_recognizer.py"

    sleep 1

    if ! pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 任务已停止"
    else
        echo "✗ 停止失败,尝试强制终止..."
        pkill -9 -f "batch_run.sh"
        pkill -9 -f "image_batch_recognizer.py"
    fi
}

# Show status
show_status() {
    echo "任务状态:"
    echo ""

    if pgrep -f "batch_run.sh" > /dev/null; then
        echo "✓ 运行中"
        echo ""
        echo "进程信息:"
        ps aux | grep -E "batch_run.sh|image_batch_recognizer.py" | grep -v grep
        echo ""

        # Show the latest progress
        if [ -f "$MAIN_LOG" ]; then
            echo "最新日志:"
            tail -10 "$MAIN_LOG"
        fi
    else
        echo "○ 未运行"

        if [ -f "$MAIN_LOG" ]; then
            echo ""
            echo "上次运行结果:"
            tail -15 "$MAIN_LOG" | grep -A 10 "批量分析完成"
        fi
    fi
}

# Follow the log
show_log() {
    if [ ! -f "$MAIN_LOG" ]; then
        echo "日志文件不存在: $MAIN_LOG"
        exit 1
    fi

    echo "实时查看日志 (Ctrl+C 退出):"
    echo ""
    tail -f "$MAIN_LOG"
}

# Main dispatch
main() {
    case "${1:-help}" in
        start)
            start_task
            ;;
        stop)
            stop_task
            ;;
        status)
            show_status
            ;;
        log)
            show_log
            ;;
        help|--help|-h)
            show_help
            ;;
        *)
            echo "未知命令: $1"
            echo ""
            show_help
            exit 1
            ;;
    esac
}

main "$@"
349
scripts/run_batch_background.sh
Normal file
@@ -0,0 +1,349 @@
#!/bin/bash
# Wrapper script that starts the batch-analysis task in the background.
# Supports three launch methods: nohup, screen, and tmux.

set -e

BASE_DIR="$HOME/bin/p_anlz"
SCRIPT_DIR="$BASE_DIR/scripts"
BATCH_SCRIPT="$SCRIPT_DIR/batch_analyze_all_folders.sh"
LOG_DIR="$BASE_DIR/logs"
NOHUP_LOG="$LOG_DIR/nohup.out"

# Colored output (RED was referenced below but previously undefined)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'

# Usage
show_usage() {
    cat << EOF
批量图片分析后台运行工具

用法:
  $0 [方式]

方式选项:
  nohup     使用 nohup 后台运行(默认,最简单)
  screen    使用 screen 会话运行(可重新连接)
  tmux      使用 tmux 会话运行(推荐,功能最强)
  status    查看当前运行状态
  logs      实时查看主日志
  progress  查看处理进度
  live      实时查看当前 Python 程序输出(推荐)

示例:
  $0 nohup      # 使用 nohup 启动
  $0 screen     # 使用 screen 启动
  $0 tmux       # 使用 tmux 启动
  $0 status     # 查看状态
  $0 logs       # 查看主日志
  $0 progress   # 查看进度
  $0 live       # 实时查看 Python 输出(含进度条)

后台运行后如何查看:
  - nohup 方式: tail -f $NOHUP_LOG
  - screen 方式: screen -r batch_analyze
  - tmux 方式: tmux attach -t batch_analyze
  - 实时输出: $0 live
EOF
}

# Verify the batch script exists
check_script() {
    if [ ! -f "$BATCH_SCRIPT" ]; then
        echo -e "${RED}错误: 批处理脚本不存在: $BATCH_SCRIPT${NC}"
        exit 1
    fi
    chmod +x "$BATCH_SCRIPT"
}

# Warn when a task is already running
check_running() {
    if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${YELLOW}警告: 批量分析任务已在运行中${NC}"
        echo "进程信息:"
        ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
        echo ""
        read -p "是否继续启动新任务? (y/N): " confirm
        if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then
            echo "已取消"
            exit 0
        fi
    fi
}

# Launch with nohup
start_nohup() {
    echo -e "${BLUE}使用 nohup 启动批量分析任务...${NC}"
    mkdir -p "$LOG_DIR"

    nohup bash "$BATCH_SCRIPT" > "$NOHUP_LOG" 2>&1 &
    local pid=$!

    echo -e "${GREEN}✓ 任务已在后台启动 (PID: $pid)${NC}"
    echo ""
    echo "查看日志:"
    echo "  tail -f $NOHUP_LOG"
    echo ""
    echo "查看进度:"
    echo "  $0 progress"
    echo ""
    echo "停止任务:"
    echo "  kill $pid"
}

# Launch inside a screen session
start_screen() {
    if ! command -v screen &> /dev/null; then
        echo -e "${YELLOW}screen 未安装,请先安装: sudo apt-get install screen${NC}"
        exit 1
    fi

    echo -e "${BLUE}使用 screen 启动批量分析任务...${NC}"

    # Check for an existing session with the same name
    if screen -list | grep -q "batch_analyze"; then
        echo -e "${YELLOW}警告: screen 会话 'batch_analyze' 已存在${NC}"
        read -p "是否删除旧会话并创建新会话? (y/N): " confirm
        if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
            screen -S batch_analyze -X quit 2>/dev/null || true
        else
            echo "已取消"
            exit 0
        fi
    fi

    screen -dmS batch_analyze bash "$BATCH_SCRIPT"

    echo -e "${GREEN}✓ 任务已在 screen 会话中启动${NC}"
    echo ""
    echo "重新连接到会话:"
    echo "  screen -r batch_analyze"
    echo ""
    echo "分离会话: Ctrl+A, 然后按 D"
    echo ""
    echo "查看所有会话:"
    echo "  screen -ls"
    echo ""
    echo "终止会话:"
    echo "  screen -S batch_analyze -X quit"
}

# Launch inside a tmux session
start_tmux() {
    if ! command -v tmux &> /dev/null; then
        echo -e "${YELLOW}tmux 未安装,请先安装: sudo apt-get install tmux${NC}"
        exit 1
    fi

    echo -e "${BLUE}使用 tmux 启动批量分析任务...${NC}"

    # Check for an existing session with the same name
    if tmux has-session -t batch_analyze 2>/dev/null; then
        echo -e "${YELLOW}警告: tmux 会话 'batch_analyze' 已存在${NC}"
        read -p "是否删除旧会话并创建新会话? (y/N): " confirm
        if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
            tmux kill-session -t batch_analyze 2>/dev/null || true
        else
            echo "已取消"
            exit 0
        fi
    fi

    tmux new-session -d -s batch_analyze "bash $BATCH_SCRIPT"

    echo -e "${GREEN}✓ 任务已在 tmux 会话中启动${NC}"
    echo ""
    echo "重新连接到会话:"
    echo "  tmux attach -t batch_analyze"
    echo ""
    echo "分离会话: Ctrl+B, 然后按 D"
    echo ""
    echo "查看所有会话:"
    echo "  tmux ls"
    echo ""
    echo "终止会话:"
    echo "  tmux kill-session -t batch_analyze"
}

# Show run status
show_status() {
    echo -e "${BLUE}批量分析任务状态:${NC}"
    echo ""

    # Check the process
    if pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${GREEN}✓ 任务正在运行${NC}"
        echo ""
        echo "进程信息:"
        ps aux | grep "batch_analyze_all_folders.sh" | grep -v grep
    else
        echo -e "${YELLOW}○ 任务未运行${NC}"
    fi

    echo ""

    # Check for a screen session
    if command -v screen &> /dev/null; then
        if screen -list | grep -q "batch_analyze"; then
            echo -e "${GREEN}✓ Screen 会话存在${NC}"
            screen -list | grep "batch_analyze"
        fi
    fi

    echo ""

    # Check for a tmux session
    if command -v tmux &> /dev/null; then
        if tmux has-session -t batch_analyze 2>/dev/null; then
            echo -e "${GREEN}✓ Tmux 会话存在${NC}"
            tmux list-sessions | grep "batch_analyze"
        fi
    fi
}

# Follow the logs
show_logs() {
    local main_log="$LOG_DIR/batch_analyze.log"

    if [ -f "$main_log" ]; then
        echo -e "${BLUE}实时查看主日志 (Ctrl+C 退出):${NC}"
        echo ""
        tail -f "$main_log"
    elif [ -f "$NOHUP_LOG" ]; then
        echo -e "${BLUE}实时查看 nohup 日志 (Ctrl+C 退出):${NC}"
        echo ""
        tail -f "$NOHUP_LOG"
    else
        echo -e "${YELLOW}未找到日志文件${NC}"
        echo "日志目录: $LOG_DIR"
    fi
}

# Show progress
show_progress() {
    local progress_file="$LOG_DIR/progress.txt"

    if [ ! -f "$progress_file" ]; then
        echo -e "${YELLOW}进度文件不存在: $progress_file${NC}"
        exit 1
    fi

    echo -e "${BLUE}批量分析进度:${NC}"
    echo ""

    # Dump the progress file
    cat "$progress_file"

    echo ""
    echo -e "${BLUE}统计信息:${NC}"

    # Count each status. grep -c already prints 0 when nothing matches,
    # so use `|| true` (not `|| echo "0"`), which would emit a second "0"
    # inside the command substitution.
    local total=$(grep -c "SUCCESS\|FAILED\|SKIPPED" "$progress_file" || true)
    local success=$(grep -c "SUCCESS" "$progress_file" || true)
    local failed=$(grep -c "FAILED" "$progress_file" || true)
    local skipped=$(grep -c "SKIPPED" "$progress_file" || true)

    echo "总计: $total"
    echo -e "${GREEN}成功: $success${NC}"
    echo -e "${YELLOW}失败: $failed${NC}"
    echo -e "${BLUE}跳过: $skipped${NC}"

    # Show the last update time (GNU stat, with a BSD/macOS fallback)
    local last_update=$(stat -c %y "$progress_file" 2>/dev/null || stat -f "%Sm" "$progress_file" 2>/dev/null)
    echo ""
    echo "最后更新: $last_update"
}

# Follow the live output of the current Python run
show_live() {
    local main_log="$LOG_DIR/batch_analyze.log"

    echo -e "${BLUE}实时查看 Python 程序输出 (Ctrl+C 退出)${NC}"
    echo ""

    # Make sure the task is running
    if ! pgrep -f "batch_analyze_all_folders.sh" > /dev/null; then
        echo -e "${YELLOW}任务未运行${NC}"
        exit 1
    fi

    # Find the folder currently being processed
    if [ -f "$main_log" ]; then
        local current_folder=$(grep -oP "处理文件夹: \K.*" "$main_log" | tail -1)
        if [ -n "$current_folder" ]; then
            local folder_log="$LOG_DIR/${current_folder}.log"

            if [ -f "$folder_log" ]; then
                echo -e "${GREEN}当前处理: $current_folder${NC}"
                echo -e "${BLUE}日志文件: $folder_log${NC}"
                echo ""
                echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

                # Follow the per-folder log (full Python output, progress bars included)
                tail -f "$folder_log"
            else
                echo -e "${YELLOW}等待日志文件生成...${NC}"
                sleep 2
                show_live
            fi
        else
            echo -e "${YELLOW}等待任务开始...${NC}"
            sleep 2
            show_live
        fi
    else
        echo -e "${YELLOW}主日志文件不存在${NC}"
        exit 1
    fi
}

# Main dispatch
main() {
    local mode="${1:-nohup}"

    case "$mode" in
        nohup)
            check_script
            check_running
            start_nohup
            ;;
        screen)
            check_script
            check_running
            start_screen
            ;;
        tmux)
            check_script
            check_running
            start_tmux
            ;;
        status)
            show_status
            ;;
        logs)
            show_logs
            ;;
        progress)
            show_progress
            ;;
        live)
            show_live
            ;;
        -h|--help|help)
            show_usage
            ;;
        *)
            echo "未知选项: $mode"
            echo ""
            show_usage
            exit 1
            ;;
    esac
}

main "$@"