fix: update keywords_match

This commit is contained in:
2026-01-18 18:25:36 +08:00
parent 29f6e25f70
commit 4ed90734df
7 changed files with 1406 additions and 269 deletions

107
scripts/batch_keyword_match.sh Executable file
View File

@@ -0,0 +1,107 @@
#!/bin/bash
# Batch keyword-matching driver.
# Runs keyword_matcher.py on every .xlsx file found in data/pho_analysis_merged/
# and writes <name>_matched.xlsx into data/output/.
set -e

# Resolve paths relative to this script so it can be launched from anywhere.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"

# Directory configuration
INPUT_DIR="$PROJECT_DIR/data/pho_analysis_merged"
OUTPUT_DIR="$PROJECT_DIR/data/output"
KEYWORDS_FILE="$PROJECT_DIR/data/keywords/keywords_all.xlsx"

# Colored output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

echo "=============================================="
echo "批量关键词匹配"
echo "=============================================="
echo "输入目录: $INPUT_DIR"
echo "输出目录: $OUTPUT_DIR"
echo "关键词文件: $KEYWORDS_FILE"
echo ""

# The input directory is mandatory.
if [ ! -d "$INPUT_DIR" ]; then
    echo -e "${RED}错误: 输入目录不存在: $INPUT_DIR${NC}"
    exit 1
fi

# The keywords file is optional: when it is missing, -k is simply not passed
# and keyword_matcher.py falls back to its own default.
if [ ! -f "$KEYWORDS_FILE" ]; then
    echo -e "${YELLOW}警告: 关键词文件不存在: $KEYWORDS_FILE${NC}"
    echo "将使用默认关键词文件"
    KEYWORDS_FILE=""
fi

# Create the output directory
mkdir -p "$OUTPUT_DIR"

# Counters
total=0
success=0
failed=0

# Collect all xlsx files
files=("$INPUT_DIR"/*.xlsx)

# Without nullglob an unmatched pattern stays literal, so test the first entry.
if [ ! -e "${files[0]}" ]; then
    echo -e "${YELLOW}没有找到 xlsx 文件${NC}"
    exit 0
fi

# Count regular files.
# NOTE: counters use var=$((var + 1)) instead of ((var++)). The latter returns
# exit status 1 when the value before the increment is 0, which makes the
# script abort under `set -e` on the very first file.
for f in "${files[@]}"; do
    if [ -f "$f" ]; then
        total=$((total + 1))
    fi
done

echo "找到 $total 个文件待处理"
echo "----------------------------------------------"

# Process each file
current=0
for input_file in "${files[@]}"; do
    if [ ! -f "$input_file" ]; then
        continue
    fi
    current=$((current + 1))

    # File name without the extension
    filename=$(basename "$input_file" .xlsx)
    output_file="$OUTPUT_DIR/${filename}_matched.xlsx"
    echo -e "\n[$current/$total] 处理: $filename"

    # Build the command as an array so paths containing spaces, quotes or `$`
    # are passed through verbatim (no eval, no re-parsing by the shell).
    cmd=(python3 "$SCRIPT_DIR/keyword_matcher.py" -t "$input_file" -o "$output_file")
    if [ -n "$KEYWORDS_FILE" ]; then
        cmd+=(-k "$KEYWORDS_FILE")
    fi

    # Run the matcher; a failure is counted and the loop continues.
    if "${cmd[@]}"; then
        echo -e "${GREEN} ✓ 完成: ${filename}_matched.xlsx${NC}"
        success=$((success + 1))
    else
        echo -e "${RED} ✗ 失败: $filename${NC}"
        failed=$((failed + 1))
    fi
done

# Summary
echo ""
echo "=============================================="
echo "处理完成"
echo "=============================================="
echo -e "总计: $total | ${GREEN}成功: $success${NC} | ${RED}失败: $failed${NC}"
echo "输出目录: $OUTPUT_DIR"

315
scripts/collect_xlsx.py Normal file
View File

@@ -0,0 +1,315 @@
#!/usr/bin/env python3
"""
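Collect and merge xlsx files.

Features:
1. Collect results.xlsx (image-analysis results) from each sub-folder of data/batch_output.
2. Merge each one with its raw data file in data/data_all ({name}_text_img.xlsx).
3. Join the two data sources on the image file name.
4. Save the merged workbook to the output directory.

Usage:
    python3 collect_xlsx.py                     # merge and write to the default output directory
    python3 collect_xlsx.py -o ../data/merged   # choose the output directory
    python3 collect_xlsx.py --no-merge          # copy only, do not merge
    python3 collect_xlsx.py -n                  # dry run (preview only)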
"""
import argparse
from pathlib import Path
from typing import Optional, Tuple, List
import pandas as pd
def extract_image_name(path: str) -> str:
    """Return the file-name component of an image path.

    Both Windows (backslash) and Unix (forward slash) separators are
    recognised. Missing values (per ``pd.isna``) and blank strings yield an
    empty string.
    """
    if pd.isna(path):
        return ""
    cleaned = str(path).strip()
    if not cleaned:
        return ""
    # Unify the separators, then keep whatever follows the last one.
    _, _, tail = cleaned.replace("\\", "/").rpartition("/")
    return tail
def merge_xlsx_files(
    results_file: Path,
    original_file: Path,
    results_image_col: str = "image_name",
    original_image_cols: Optional[List[str]] = None,
    original_text_col: str = "文本"
) -> Tuple[pd.DataFrame, dict]:
    """Left-join the raw data onto the analysis results by image file name.

    Every row of the results workbook is kept; columns of the original
    workbook that the results do not already have are appended and filled in
    for rows whose image name is found in the original data.

    Args:
        results_file: Analysis results workbook (batch_output/.../results.xlsx).
        original_file: Raw data workbook (data_all/..._text_img.xlsx).
        results_image_col: Column of the results workbook holding the image name.
        original_image_cols: Candidate image-path columns of the original
            workbook, in priority order; the first one present is used.
        original_text_col: Name of the text column in the original workbook.
            Accepted for interface compatibility; not used by the merge itself.

    Returns:
        ``(merged_df, stats)`` where ``stats`` holds row counts, the image
        column used and the list of columns that were added.

    Raises:
        ValueError: If the results workbook lacks ``results_image_col`` or the
            original workbook has none of ``original_image_cols``.
    """
    if original_image_cols is None:
        original_image_cols = ["图片_新", "图片", "图片链接"]
    else:
        # Accept any sequence (e.g. a tuple) without breaking the concatenation below.
        original_image_cols = list(original_image_cols)

    # Load both workbooks
    results_df = pd.read_excel(results_file)
    original_df = pd.read_excel(original_file)

    stats = {
        "results_rows": len(results_df),
        "original_rows": len(original_df),
        "merged_rows": 0,
        "matched_count": 0,
        "unmatched_results": 0,
        "original_columns_added": [],
        "image_col_used": None
    }

    # Fail with a clear message instead of a bare KeyError inside the loop.
    if results_image_col not in results_df.columns:
        raise ValueError(f"结果文件中未找到图片名列: {results_image_col}")

    # First candidate image column that actually exists
    image_col = next(
        (col for col in original_image_cols if col in original_df.columns),
        None,
    )
    if image_col is None:
        raise ValueError(f"原始文件中未找到图片列,尝试过: {original_image_cols}")
    stats["image_col_used"] = image_col

    # Join key: image file name extracted from the original path column
    original_df["_image_name"] = original_df[image_col].apply(extract_image_name)

    # Rows without an image path have an empty key; they must never be matched.
    keyed = original_df[original_df["_image_name"] != ""]
    # The raw data may list an image several times; keep the first occurrence.
    original_dedup = keyed.drop_duplicates(subset=["_image_name"], keep="first")

    # Columns to bring over (skip path columns, the temp key and existing columns)
    exclude_cols = set(original_image_cols + ["_image_name"])
    original_cols_to_add = [col for col in original_df.columns
                            if col not in exclude_cols
                            and col not in results_df.columns]
    stats["original_columns_added"] = original_cols_to_add

    # image name -> {column: value}
    original_map = original_dedup.set_index("_image_name")[original_cols_to_add].to_dict("index")

    merged_df = results_df.copy()
    for col in original_cols_to_add:
        merged_df[col] = None

    # Normalise the results-side key the same way as the original side so a
    # stored path or stray whitespace still joins on the bare file name.
    lookup_keys = merged_df[results_image_col].apply(extract_image_name)

    matched_count = 0
    for idx, key in lookup_keys.items():
        record = original_map.get(key)
        if record is None:
            continue
        for col in original_cols_to_add:
            merged_df.at[idx, col] = record.get(col)
        matched_count += 1

    stats["merged_rows"] = len(merged_df)
    stats["matched_count"] = matched_count
    stats["unmatched_results"] = len(merged_df) - matched_count
    return merged_df, stats
def collect_and_merge_xlsx(
    source_dir: str,
    data_all_dir: str,
    output_dir: str,
    merge: bool = True,
    dry_run: bool = False
) -> List[dict]:
    """Walk the batch output folders and write one workbook per folder.

    For every sub-folder of ``source_dir`` that contains ``results.xlsx`` the
    function either merges it with ``{folder}_text_img.xlsx`` from
    ``data_all_dir`` or, when merging is disabled, the raw file is missing, or
    the merge raises, copies ``results.xlsx`` unchanged.

    Args:
        source_dir: Path of the batch_output directory.
        data_all_dir: Path of the data_all directory.
        output_dir: Destination directory (created unless ``dry_run``).
        merge: Whether to merge in the raw data.
        dry_run: Only print what would be done; nothing is written.

    Returns:
        One info dict per processed folder; empty list if ``source_dir`` is missing.
    """
    import shutil

    src_root = Path(source_dir)
    raw_root = Path(data_all_dir)
    dest_root = Path(output_dir)

    if not src_root.exists():
        print(f"错误: 源目录不存在: {source_dir}")
        return []

    if not dry_run:
        dest_root.mkdir(parents=True, exist_ok=True)

    collected = []
    for subdir in sorted(src_root.iterdir()):
        if not subdir.is_dir():
            continue
        name = subdir.name
        analysis_xlsx = subdir / "results.xlsx"
        if not analysis_xlsx.exists():
            continue

        target_xlsx = dest_root / f"{name}.xlsx"
        raw_xlsx = raw_root / f"{name}_text_img.xlsx"
        raw_available = raw_xlsx.exists()
        will_merge = merge and raw_available

        entry = {
            "folder": name,
            "results_file": str(analysis_xlsx),
            "original_file": str(raw_xlsx) if raw_available else None,
            "output_file": str(target_xlsx),
            "merged": False,
            "stats": {}
        }

        if dry_run:
            # Preview only: describe the action and move on.
            if will_merge:
                print(f"[预览] 合并: {name}/results.xlsx + {name}_text_img.xlsx -> {name}.xlsx")
            else:
                print(f"[预览] 复制: {name}/results.xlsx -> {name}.xlsx")
            collected.append(entry)
            continue

        if will_merge:
            try:
                merged_df, stats = merge_xlsx_files(analysis_xlsx, raw_xlsx)
                merged_df.to_excel(target_xlsx, index=False, engine="openpyxl")
                entry["merged"] = True
                entry["stats"] = stats
                print(f"已合并: {name}")
                print(f" - 分析结果: {stats['results_rows']}")
                print(f" - 原始数据: {stats['original_rows']}")
                print(f" - 匹配成功: {stats['matched_count']}")
                print(f" - 添加列: {stats['original_columns_added']}")
            except Exception as e:
                # Best effort: a failed merge still delivers the plain results.
                print(f"合并失败 {name}: {e}")
                shutil.copy2(analysis_xlsx, target_xlsx)
                print(f" 已回退到复制模式")
        else:
            shutil.copy2(analysis_xlsx, target_xlsx)
            if merge and not raw_available:
                print(f"已复制: {name} (原始数据不存在: {name}_text_img.xlsx)")
            else:
                print(f"已复制: {name}")

        collected.append(entry)

    return collected
def main():
    """Command-line entry point: parse options, run the collection, print a summary.

    Relative directory arguments are resolved against the directory containing
    this script, not the current working directory.
    """
    parser = argparse.ArgumentParser(
        description="收集并合并 batch_output 和 data_all 中的 xlsx 文件",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python3 collect_xlsx.py # 默认合并并输出
python3 collect_xlsx.py -o ../data/merged # 指定输出目录
python3 collect_xlsx.py --no-merge # 不合并,只复制
python3 collect_xlsx.py -n # 预览模式
"""
    )
    parser.add_argument(
        "-s", "--source",
        default="../data/batch_output",
        help="batch_output 目录路径 (默认: ../data/batch_output)"
    )
    parser.add_argument(
        "-d", "--data-all",
        default="../data/data_all",
        help="data_all 目录路径 (默认: ../data/data_all)"
    )
    parser.add_argument(
        "-o", "--output",
        default="../data/collected_xlsx",
        help="输出目录路径 (默认: ../data/collected_xlsx)"
    )
    parser.add_argument(
        "--no-merge",
        action="store_true",
        help="不合并原始数据,只复制分析结果"
    )
    parser.add_argument(
        "-n", "--dry-run",
        action="store_true",
        help="预览模式,只打印不执行"
    )
    args = parser.parse_args()

    # Convert to absolute paths (relative to the script directory)
    script_dir = Path(__file__).parent
    source_dir = (script_dir / args.source).resolve()
    data_all_dir = (script_dir / args.data_all).resolve()
    output_dir = (script_dir / args.output).resolve()

    print("=" * 60)
    print("收集并合并 xlsx 文件")
    print("=" * 60)
    print(f"分析结果目录: {source_dir}")
    print(f"原始数据目录: {data_all_dir}")
    print(f"输出目录: {output_dir}")
    # Both branches used to be empty strings, so the mode was never shown.
    print(f"合并模式: {'否' if args.no_merge else '是'}")
    print("-" * 60)

    results = collect_and_merge_xlsx(
        str(source_dir),
        str(data_all_dir),
        str(output_dir),
        merge=not args.no_merge,
        dry_run=args.dry_run
    )

    print("-" * 60)
    merged_count = sum(1 for r in results if r.get("merged"))
    print(f"共处理 {len(results)} 个文件")
    if not args.no_merge:
        print(f" - 合并成功: {merged_count}")
        print(f" - 仅复制: {len(results) - merged_count}")


if __name__ == "__main__":
    main()

View File

@@ -47,6 +47,18 @@ MODE_LABELS = {
"exact": "精确匹配",
}
# Common text column names, in priority order
COMMON_TEXT_COLUMNS = [
    "detected_text",  # new format (image-analysis results)
    "文本",  # old format / original text after merging
    "text",
    "content",
    "summary",
]
# Default column combination for multi-column matching
DEFAULT_TEXT_COLUMNS = ["detected_text", "文本"]
# ========== 数据类 ==========
@dataclass
@@ -136,6 +148,71 @@ def split_value(value: str, separator: str) -> List[str]:
return [part.strip() for part in parts if part and part.strip()]
def detect_text_columns(
    df: pd.DataFrame,
    specified_columns: Optional[List[str]] = None
) -> List[str]:
    """Resolve which text columns of ``df`` should be searched.

    Args:
        df: The data frame to inspect.
        specified_columns: Column names requested by the user, if any.

    Returns:
        The requested columns that exist; otherwise the members of
        ``DEFAULT_TEXT_COLUMNS`` that exist; otherwise a one-element list with
        the first ``COMMON_TEXT_COLUMNS`` entry that exists.

    Raises:
        ValueError: If no suitable column can be found.
    """
    if specified_columns:
        present, absent = [], []
        for name in specified_columns:
            (present if name in df.columns else absent).append(name)
        if absent:
            print(f"警告: 以下指定的列不存在: {absent}")
        if present:
            return present
        print("警告: 所有指定的列都不存在,尝试自动检测...")

    # Auto-detect: the default multi-column combination comes first.
    preferred = [name for name in DEFAULT_TEXT_COLUMNS if name in df.columns]
    if preferred:
        print(f"自动检测到文本列: {preferred}")
        return preferred

    # Fallback: the first common column name that is present.
    fallback = next((name for name in COMMON_TEXT_COLUMNS if name in df.columns), None)
    if fallback is not None:
        print(f"自动检测到文本列: ['{fallback}']")
        return [fallback]

    raise ValueError(
        f"无法自动检测文本列。可用列: {df.columns.tolist()}\n"
        f"请使用 -c 参数指定文本列名"
    )
def combine_text_columns(row: pd.Series, text_columns: List[str]) -> str:
    """Join the non-empty text cells of one row.

    Args:
        row: One row of a DataFrame.
        text_columns: Names of the columns to combine, in order.

    Returns:
        The stripped cell texts separated by newlines; missing columns, NA
        values and blank cells are skipped.
    """
    cells = (row.get(name) for name in text_columns)
    stripped = (str(cell).strip() for cell in cells if pd.notna(cell))
    return "\n".join(piece for piece in stripped if piece)
def load_keywords_for_mode(
df: pd.DataFrame,
mode: str,
@@ -205,22 +282,32 @@ class KeywordMatcher(ABC):
self,
df: pd.DataFrame,
keywords: Set[str],
text_column: str
text_columns: List[str]
) -> MatchResult:
"""执行匹配(模板方法)"""
"""执行匹配(模板方法)
参数:
df: 数据框
keywords: 关键词集合
text_columns: 文本列名列表(支持多列)
"""
print(f"开始匹配(使用{self.name}...")
print(f"搜索列: {text_columns}")
self._prepare(keywords)
matched_indices = []
matched_keywords_list = []
start_time = time.time()
for idx, text in enumerate(df[text_column]):
if pd.isna(text):
for idx in range(len(df)):
row = df.iloc[idx]
# 合并多列文本
combined_text = combine_text_columns(row, text_columns)
if not combined_text:
continue
text_str = str(text)
matches = self._match_single_text(text_str, keywords)
matches = self._match_single_text(combined_text, keywords)
if matches:
matched_indices.append(idx)
@@ -435,22 +522,36 @@ def preview_results(result_df: pd.DataFrame, num_rows: int = 5) -> None:
def perform_matching(
df: pd.DataFrame,
keywords: Set[str],
text_column: str,
text_columns: List[str],
output_file: str,
algorithm: str = "auto",
mode: str = None
) -> Optional[pd.DataFrame]:
"""执行完整的匹配流程"""
"""执行完整的匹配流程
参数:
df: 数据框
keywords: 关键词集合
text_columns: 文本列名列表(支持多列)
output_file: 输出文件路径
algorithm: 匹配算法
mode: 匹配模式
"""
# 验证列存在
if text_column not in df.columns:
missing_cols = [col for col in text_columns if col not in df.columns]
if missing_cols:
print(f"警告: 以下列不存在: {missing_cols}")
text_columns = [col for col in text_columns if col in df.columns]
if not text_columns:
print(f"可用列名: {df.columns.tolist()}")
raise ValueError(f"'{text_column}' 不存在")
raise ValueError("没有可用的文本列")
print(f"文本文件共有 {len(df)} 行数据\n")
# 创建匹配器并执行匹配
matcher = create_matcher(algorithm, mode=mode)
result = matcher.match(df, keywords, text_column)
result = matcher.match(df, keywords, text_columns)
# 输出统计信息
print_statistics(result)
@@ -465,7 +566,7 @@ def process_single_mode(
keywords_df: pd.DataFrame,
text_df: pd.DataFrame,
mode: str,
text_column: str,
text_columns: List[str],
output_file: Path,
separator: str = SEPARATOR,
save_to_file: bool = True
@@ -473,6 +574,9 @@ def process_single_mode(
"""
处理单个检测模式
参数:
text_columns: 文本列名列表(支持多列)
返回:匹配结果 DataFrame包含原始索引
"""
mode_lower = mode.lower()
@@ -501,7 +605,7 @@ def process_single_mode(
result_df = perform_matching(
df=text_df,
keywords=keywords,
text_column=text_column,
text_columns=text_columns,
output_file=temp_output,
algorithm=algorithm,
mode=mode_lower # 传递模式参数
@@ -528,11 +632,15 @@ def run_multiple_modes(
keywords_file: Path,
text_file: Path,
output_file: Path,
text_column: str,
text_columns: Optional[List[str]],
modes: List[str],
separator: str = SEPARATOR
) -> None:
"""运行多个检测模式,合并结果到单一文件"""
"""运行多个检测模式,合并结果到单一文件
参数:
text_columns: 文本列名列表支持多列None 表示自动检测
"""
# 验证文件存在
if not keywords_file.exists():
raise FileNotFoundError(f"找不到关键词文件: {keywords_file}")
@@ -546,7 +654,10 @@ def run_multiple_modes(
print(f"正在加载文本文件: {text_file}")
text_df = pd.read_excel(text_file)
print(f"文本列: {text_column}\n")
# 自动检测或验证文本列
actual_text_columns = detect_text_columns(text_df, text_columns)
print(f"使用文本列: {actual_text_columns}\n")
# 验证模式
if not modes:
@@ -568,7 +679,7 @@ def run_multiple_modes(
keywords_df=keywords_df,
text_df=text_df,
mode=mode_lower,
text_column=text_column,
text_columns=actual_text_columns,
output_file=output_file, # 这个参数在 save_to_file=False 时不使用
separator=separator,
save_to_file=False # 不保存到单独文件
@@ -668,7 +779,7 @@ def parse_args():
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 使用默认配置(两种模式
# 使用默认配置(自动检测 detected_text 和 文本 列
python keyword_matcher.py
# 仅执行 CAS 号识别
@@ -677,6 +788,12 @@ def parse_args():
# 仅执行精确匹配
python keyword_matcher.py -m exact
# 指定单个文本列
python keyword_matcher.py -c detected_text
# 指定多个文本列
python keyword_matcher.py -c detected_text 文本 summary
# 指定自定义文件路径
python keyword_matcher.py -k ../data/input/keywords.xlsx -t ../data/input/text.xlsx
"""
@@ -701,10 +818,11 @@ def parse_args():
)
parser.add_argument(
'-c', '--text-column',
'-c', '--text-columns',
nargs='+',
type=str,
default='文本',
help='文本列名 (默认: 文本)'
default=None,
help='文本列名,支持多列 (默认: 自动检测 detected_text 和 文本)'
)
parser.add_argument(
@@ -759,7 +877,7 @@ def main():
keywords_file=keywords_file,
text_file=text_file,
output_file=output_file,
text_column=args.text_column,
text_columns=args.text_columns,
modes=args.modes,
separator=args.separator
)

View File

@@ -0,0 +1,517 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
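Verify high-confidence records that keyword matching missed.

Purpose: compare the keyword_matcher output with the original Excel file,
find the high-confidence rows that were not matched, and send them to an LLM
for a second-pass verification.

Usage:
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5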
"""
import argparse
import json
import os
import sys
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import pandas as pd
# 可选依赖
try:
import openai
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
try:
import requests
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
HAS_REQUESTS = True
except ImportError:
HAS_REQUESTS = False
# ========== 常量与配置 ==========
# Confidence labels that are selected for verification by default
CONFIDENCE_LEVELS = ["High", "Medium"]
# Default pause between consecutive LLM requests, in seconds
REQUEST_DELAY = 0.5

# Environment variable mapping: api_type -> (key_env, url_env, model_env, default_model)
ENV_MAPPING = {
    "openai": ("OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL", "gpt-4o-mini"),
    "dmx": ("DMX_API_KEY", "DMX_BASE_URL", "DMX_MODEL", "gpt-4o-mini"),
    "dify": ("DIFY_API_KEY", "DIFY_BASE_URL", "DIFY_MODEL", "dify-chatflow"),
    "ollama": (None, "OLLAMA_BASE_URL", "OLLAMA_MODEL", "qwen2.5:7b"),
}

# System prompt. The example line must itself satisfy the contract it
# describes: risk_level was an empty string, which contradicts the 高/中/低
# vocabulary listed above it and invites empty answers from the model.
SYSTEM_PROMPT = """你是一位化学品风险识别专家。请分析文本内容,判断是否涉及管制化学品、毒品前体或非法药物交易。
请以 JSON 格式回答,包含以下字段:
- is_risky: 布尔值,是否涉及风险
- substances: 数组涉及的具体物质名称或CAS号
- risk_level: 字符串,风险等级(高/中/低)
- reason: 字符串,判定理由(简要)
示例输出:
{"is_risky": true, "substances": ["甲基苯丙胺", "CAS 537-46-2"], "risk_level": "高", "reason": "文本中明确提到毒品名称和交易信息"}
"""

# User prompt; filled with the image-analysis response and the original text.
USER_PROMPT_TEMPLATE = """请分析以下内容是否涉及管制化学品或毒品:
【图片分析结果】
{raw_response}
【原始文本】
{original_text}
请以 JSON 格式输出分析结果。"""
# ========== 数据类 ==========
@dataclass
class VerifyConfig:
    """Settings describing which LLM backend to call and how to reach it."""
    api_type: str = "openai"  # backend identifier, e.g. "openai"
    api_key: str = ""  # credential for the backend; empty when none is configured
    base_url: Optional[str] = None  # endpoint override; None when not configured
    model: str = "gpt-4o-mini"  # model name requested from the backend
    user_id: str = "default-user"  # user identifier forwarded to backends that take one
@dataclass
class VerificationResult:
    """Outcome of one LLM verification, ready to be flattened into columns."""
    is_risky: Optional[bool] = None  # None means the verdict is unknown / failed
    substances: List[str] = field(default_factory=list)
    risk_level: str = ""
    reason: str = ""
    raw_response: str = ""

    def to_columns(self) -> dict:
        """Return the result as a flat ``llm_*`` column mapping.

        ``substances`` comes from model output and is not guaranteed to be a
        list of strings: a bare string is used as-is (not joined character by
        character), ``None`` becomes an empty string, and non-string items are
        converted with ``str`` before joining with " | ".
        """
        raw = self.substances
        if not raw:
            joined = ""
        elif isinstance(raw, str):
            joined = raw
        else:
            try:
                items = list(raw)
            except TypeError:
                # A single non-iterable value (e.g. a number).
                items = [raw]
            joined = " | ".join(str(item) for item in items if item is not None)
        return {
            "llm_is_risky": self.is_risky,
            "llm_substances": joined,
            "llm_risk_level": self.risk_level,
            "llm_reason": self.reason,
            "llm_raw_response": self.raw_response,
        }
# ========== 工具函数 ==========
def load_env_file(env_path: str) -> None:
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Blank lines and ``#`` comment lines are ignored, a leading ``export `` is
    dropped, and surrounding quotes are stripped from the value. Existing
    variables are overwritten. A missing file is silently skipped.

    Args:
        env_path: Path of the .env file.
    """
    env_file = Path(env_path)
    if not env_file.exists():
        return
    print(f"加载环境配置: {env_file}")
    # utf-8-sig reads plain UTF-8 too, and keeps a BOM out of the first key.
    with open(env_file, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("export "):
                line = line[len("export "):]
            key, sep, value = line.partition("=")
            key = key.strip()
            # Skip lines without '=' and lines such as "=value": an empty
            # variable name makes the os.environ assignment raise.
            if not sep or not key:
                continue
            os.environ[key] = value.strip().strip('"').strip("'")
def get_config() -> VerifyConfig:
    """Build the verification config from environment variables.

    ``VERIFY_*`` variables take precedence; otherwise the variables listed for
    the selected API type in ``ENV_MAPPING`` are used, then built-in defaults.
    """
    env = os.getenv
    api_type = (env("VERIFY_API_TYPE") or env("LLM_API_TYPE") or "openai").lower()
    key_env, url_env, model_env, default_model = ENV_MAPPING.get(
        api_type, (None, None, None, "gpt-4o-mini")
    )

    api_key = env("VERIFY_API_KEY")
    if not api_key:
        api_key = (env(key_env) if key_env else "") or ""

    base_url = env("VERIFY_BASE_URL")
    if not base_url:
        base_url = env(url_env) if url_env else None

    model = env("VERIFY_MODEL")
    if not model and model_env:
        model = env(model_env)
    if not model:
        model = default_model

    user_id = env("VERIFY_USER_ID") or env("DIFY_USER_ID") or "default-user"

    return VerifyConfig(
        api_type=api_type,
        api_key=api_key,
        base_url=base_url,
        model=model,
        user_id=user_id,
    )
def parse_json_response(content: str) -> dict:
    """Extract the JSON object from an LLM reply.

    Markdown code fences are removed first; then the text between the first
    ``{`` and the last ``}`` is parsed.

    Args:
        content: Raw reply text.

    Returns:
        The parsed object, or a placeholder dict (``is_risky`` is None, reason
        "JSON 解析失败") when no JSON object can be parsed.
    """
    fallback = {"is_risky": None, "substances": [], "risk_level": "未知", "reason": "JSON 解析失败"}
    if not content:
        return fallback

    text = content
    # Strip a markdown code fence, preferring the explicit json fence.
    for fence in ("```json", "```"):
        pos = text.find(fence)
        if pos == -1:
            continue
        start = pos + len(fence)
        end = text.find("```", start)
        # A missing closing fence used to yield end == -1, and the slice
        # text[start:-1] silently dropped the final "}".
        text = (text[start:end] if end != -1 else text[start:]).strip()
        break

    start = text.find("{")
    end = text.rfind("}") + 1
    if start < 0 or end <= start:
        return fallback
    try:
        parsed = json.loads(text[start:end])
    except json.JSONDecodeError:
        return fallback
    # Callers use .get(), so anything but an object is a parse failure.
    return parsed if isinstance(parsed, dict) else fallback
def build_prompt(row: pd.Series, max_len: int = 3000) -> str:
    """Build the user prompt for one record.

    Reads the ``raw_response`` and ``文本`` fields of ``row``; each is cut to
    ``max_len`` characters (with a truncation marker) before being inserted
    into ``USER_PROMPT_TEMPLATE``.

    Args:
        row: The record to describe.
        max_len: Maximum number of characters kept per field.

    Returns:
        The formatted prompt text.
    """
    def _cell_text(column: str) -> str:
        value = row.get(column, "")
        # Empty Excel cells arrive as NaN, which is truthy: without this check
        # the literal text "nan" would be sent to the model.
        if not isinstance(value, str) and pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value or "")

    def _clip(text: str) -> str:
        return text[:max_len] + "...(截断)" if len(text) > max_len else text

    raw = _clip(_cell_text("raw_response"))
    text = _clip(_cell_text("文本"))
    return USER_PROMPT_TEMPLATE.format(raw_response=raw, original_text=text)
# ========== 验证器类 ==========
class LLMVerifier(ABC):
    """Abstract base class for LLM verifiers."""

    @abstractmethod
    def verify(self, row: pd.Series) -> VerificationResult:
        """Verify one record and return the result; subclasses must implement this."""
        pass
class OpenAIVerifier(LLMVerifier):
    """Verifier for OpenAI-compatible chat APIs (OpenAI, DMX, Ollama)."""

    def __init__(self, config: VerifyConfig):
        """Create the API client.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is given for a non-Ollama backend.
        """
        if not HAS_OPENAI:
            raise ImportError("请安装 openai: pip install openai")
        if config.api_type != "ollama" and not config.api_key:
            raise ValueError("未提供 API Key")

        base_url = config.base_url
        if config.api_type == "ollama":
            # Append the OpenAI-compatible "/v1" suffix exactly once, whether
            # or not the configured URL has a trailing slash or the suffix.
            root = (config.base_url or "http://localhost:11434").rstrip("/")
            base_url = root if root.endswith("/v1") else root + "/v1"

        self.client = openai.OpenAI(
            # Ollama needs no key; a placeholder is passed instead.
            api_key=config.api_key or "ollama",
            base_url=base_url,
        )
        self.model = config.model

    def verify(self, row: pd.Series) -> VerificationResult:
        """Send one record to the chat API and parse the JSON verdict.

        Never raises: any failure is reported as a result whose
        ``risk_level`` is "错误" so that a batch run can continue.
        """
        prompt = build_prompt(row)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=500,
            )
            choice = response.choices[0]
            content = choice.message.content or ""
            # Anything but a normal stop (e.g. a length cut-off) is not trusted.
            if choice.finish_reason != "stop":
                return VerificationResult(
                    risk_level="错误",
                    reason=f"响应不完整 (finish_reason={choice.finish_reason})",
                    raw_response=content,
                )
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                # The model may answer null; keep the field a list.
                substances=parsed.get("substances") or [],
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            # Deliberate best effort: record the error for this row.
            return VerificationResult(risk_level="错误", reason=f"API 调用失败: {e}", raw_response=str(e))
class DifyVerifier(LLMVerifier):
    """Verifier that calls the Dify chat-messages API."""

    def __init__(self, config: VerifyConfig):
        """Store the connection settings.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            ValueError: If the API key or the base URL is missing.
        """
        if not HAS_REQUESTS:
            raise ImportError("请安装 requests: pip install requests")
        if not config.api_key:
            raise ValueError("未提供 Dify API Key")

        base_url = (config.base_url or "").rstrip("/")
        # verify() appends "/v1/..." itself; tolerate a base URL that already
        # ends with "/v1" instead of requesting ".../v1/v1/chat-messages".
        if base_url.endswith("/v1"):
            base_url = base_url[: -len("/v1")]
        # Fail once at start-up rather than once per row with an invalid URL.
        if not base_url:
            raise ValueError("未提供 Dify Base URL")

        self.base_url = base_url
        self.api_key = config.api_key
        self.user_id = config.user_id

    def verify(self, row: pd.Series) -> VerificationResult:
        """Send one record to Dify in blocking mode and parse the JSON verdict.

        Never raises: any failure is reported as a result whose
        ``risk_level`` is "错误" so that a batch run can continue.
        """
        prompt = f"{SYSTEM_PROMPT}\n\n{build_prompt(row)}"
        try:
            resp = requests.post(
                f"{self.base_url}/v1/chat-messages",
                headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
                json={"inputs": {}, "query": prompt, "response_mode": "blocking", "user": self.user_id},
                timeout=120,
                # SECURITY NOTE(review): TLS certificate verification is
                # disabled, so the API key and record text can be intercepted.
                # Kept as-is because existing deployments may rely on it
                # (e.g. self-signed certificates) - confirm and enable it.
                verify=False,
            )
            resp.raise_for_status()
            content = resp.json().get("answer") or ""
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                # The model may answer null; keep the field a list.
                substances=parsed.get("substances") or [],
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            # Deliberate best effort: record the error for this row.
            return VerificationResult(risk_level="错误", reason=f"Dify 调用失败: {e}", raw_response=str(e))
class MockVerifier(LLMVerifier):
    """Offline stand-in for an LLM (for testing): flags rows by keyword lookup."""

    RISK_KEYWORDS = [
        "毒品", "非法", "管制", "药物", "化学品", "CAS", "阿片", "芬太尼",
        "冰毒", "大麻", "可卡因", "海洛因", "摇头丸", "麻黄碱",
        "fentanyl", "methamphetamine", "cocaine", "heroin", "mdma",
        "ketamine", "lsd", "precursor", "controlled",
    ]

    def verify(self, row: pd.Series) -> VerificationResult:
        """Mark the row risky when any keyword occurs (case-insensitively)
        in its ``raw_response`` or ``文本`` field."""
        all_text = f"{row.get('raw_response', '')} {row.get('文本', '')}".lower()
        found = [kw for kw in self.RISK_KEYWORDS if kw.lower() in all_text]
        is_risky = len(found) > 0
        return VerificationResult(
            is_risky=is_risky,
            substances=found[:5],
            # Both branches used to be empty strings, leaving the level blank;
            # use the same 高/低 vocabulary the real prompt asks for.
            risk_level="高" if is_risky else "低",
            reason=f"Mock模式 - 发现关键词: {found[:3]}" if is_risky else "Mock模式 - 未发现风险关键词",
            raw_response="(mock)",
        )
def create_verifier(config: VerifyConfig) -> LLMVerifier:
    """Instantiate the verifier that matches ``config.api_type``.

    Raises:
        ValueError: If the API type is not one of mock, dify, openai, dmx, ollama.
    """
    kind = config.api_type
    if kind == "mock":
        return MockVerifier()
    if kind == "dify":
        return DifyVerifier(config)
    if kind in ("openai", "dmx", "ollama"):
        return OpenAIVerifier(config)
    raise ValueError(f"不支持的 API 类型: {config.api_type}")
# ========== 数据处理 ==========
def load_excel(file_path: Path) -> pd.DataFrame:
    """Read an Excel workbook into a DataFrame.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if file_path.exists():
        return pd.read_excel(file_path)
    raise FileNotFoundError(f"文件不存在: {file_path}")
def find_unmatched(
    original_df: pd.DataFrame,
    matched_df: pd.DataFrame,
    confidence_col: str = "confidence",
    confidence_levels: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Return the high-confidence rows that keyword matching did not hit.

    A row counts as high-confidence when its ``confidence_col`` value equals
    one of ``confidence_levels`` (case-insensitively), and as matched when its
    index label also occurs in ``matched_df.index``.

    Args:
        original_df: The full original data.
        matched_df: The keyword matcher output.
        confidence_col: Name of the confidence column in ``original_df``.
        confidence_levels: Levels to select; defaults to ``CONFIDENCE_LEVELS``.

    Returns:
        A copy of the selected rows in their original order, or an empty
        DataFrame when the column is missing or nothing is left to verify.
    """
    levels = confidence_levels or CONFIDENCE_LEVELS
    if confidence_col not in original_df.columns:
        print(f"警告: 原始文件中不存在 '{confidence_col}'")
        print(f"可用列: {original_df.columns.tolist()}")
        return pd.DataFrame()

    # Index labels of the high-confidence rows
    conf_lower = original_df[confidence_col].astype(str).str.lower()
    levels_lower = {level.lower() for level in levels}
    high_conf_idx = set(original_df.index[conf_lower.isin(levels_lower)])

    # NOTE(review): the comparison is by index label. If matched_df was loaded
    # with a fresh 0..n-1 index, these labels are positions in the matched
    # file, not original row numbers - verify against the matcher's output.
    matched_idx = set(matched_df.index)
    unmatched_idx = high_conf_idx - matched_idx

    # Statistics
    print(f"\n{'='*50}")
    print("数据比对统计")
    print(f"{'='*50}")
    print(f"原始数据总行数: {len(original_df)}")
    print(f"高置信度 ({'/'.join(levels)}) 行数: {len(high_conf_idx)}")
    print(f"关键词匹配到的行数: {len(matched_idx)}")
    print(f"高置信度中已匹配: {len(high_conf_idx & matched_idx)}")
    print(f"高置信度中未匹配 (需验证): {len(unmatched_idx)}")
    print(f"{'='*50}\n")

    if not unmatched_idx:
        return pd.DataFrame()
    # Select with a mask instead of .loc[list(set)]: a set has no defined
    # order, so rows (and whatever a later limit keeps) came out arbitrarily.
    keep = original_df.index.isin(list(unmatched_idx))
    return original_df[keep].copy()
def verify_batch(df: pd.DataFrame, verifier: LLMVerifier, delay: float = REQUEST_DELAY, limit: int = 0) -> pd.DataFrame:
    """Run the verifier over every row and append its ``llm_*`` columns.

    Args:
        df: Records to verify.
        verifier: Object whose ``verify(row)`` returns a VerificationResult.
        delay: Seconds to sleep between two requests (0 disables the pause).
        limit: Verify only the first ``limit`` rows; 0 means all rows.

    Returns:
        A copy of the (possibly truncated) input with the result columns added.
    """
    if limit > 0:
        df = df.head(limit)
    total = len(df)
    print(f"开始 LLM 验证,共 {total} 条记录...")
    print("-" * 50)

    results = []
    start_time = time.time()
    for i, (_, row) in enumerate(df.iterrows()):
        # Report on the first row, every tenth row and the last row.
        if (i + 1) % 10 == 0 or i == 0 or i == total - 1:
            elapsed = time.time() - start_time
            speed = (i + 1) / elapsed if elapsed > 0 else 0
            print(f"进度: {i + 1}/{total} ({(i+1)/total*100:.1f}%) - 速度: {speed:.1f} 条/秒")
        results.append(verifier.verify(row).to_columns())
        # No pause after the final request.
        if delay > 0 and i < total - 1:
            time.sleep(delay)

    verified_df = df.copy()
    # Results are produced in row order, so give them the input's own index.
    # This also works for empty input (the old set_index("original_index")
    # raised KeyError) and for inputs with duplicate index labels.
    results_df = pd.DataFrame(results, index=df.index)
    for col in results_df.columns:
        verified_df[col] = results_df[col]
    return verified_df
# ========== 结果输出 ==========
def save_results(df: pd.DataFrame, output_file: Path, risky_only: bool = False) -> None:
    """Write the verification results to an Excel file.

    Args:
        df: Verified records.
        output_file: Destination workbook path.
        risky_only: Keep only rows whose ``llm_is_risky`` equals True
            (ignored when that column is absent).
    """
    keep_risky = risky_only and "llm_is_risky" in df.columns
    if keep_risky:
        risky_mask = df["llm_is_risky"] == True
        df = df[risky_mask]
    df.to_excel(output_file, index=False, engine="openpyxl")
    print(f"\n已保存 {len(df)} 条记录到: {output_file}")
def print_summary(df: pd.DataFrame) -> None:
    """Print counts and percentages of the verification verdicts.

    Rows whose ``llm_is_risky`` is neither True nor False are reported as
    failed/unknown. Without an ``llm_is_risky`` column only the row count is
    printed.
    """
    print(f"\n{'='*50}")
    print("验证结果摘要")
    print(f"{'='*50}")
    total = len(df)
    if "llm_is_risky" not in df.columns:
        print(f"总记录数: {total}")
        return

    def _percent(count) -> float:
        # An empty frame would otherwise divide by zero.
        return count / total * 100 if total else 0.0

    risky = (df["llm_is_risky"] == True).sum()
    not_risky = (df["llm_is_risky"] == False).sum()
    unknown = total - risky - not_risky
    print(f"总验证数: {total}")
    print(f" ├─ LLM 判定有风险: {risky} ({_percent(risky):.1f}%)")
    print(f" ├─ LLM 判定无风险: {not_risky} ({_percent(not_risky):.1f}%)")
    if unknown > 0:
        print(f" └─ 判定失败/未知: {unknown}")
    if "llm_risk_level" in df.columns:
        print(f"\n风险等级分布:")
        for level, count in df["llm_risk_level"].value_counts().items():
            print(f" - {level}: {count}")
    print(f"{'='*50}")
# ========== CLI ==========
def parse_args():
    """Define the command-line interface and parse ``sys.argv``."""
    usage_examples = """
示例:
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --api dmx --model gpt-4o-mini
"""
    cli = argparse.ArgumentParser(
        description="验证高置信度未匹配记录",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    add = cli.add_argument

    # Input and output files
    add("-o", "--original", required=True, help="原始 Excel 文件路径")
    add("-m", "--matched", required=True, help="keyword_matcher 匹配结果文件路径")
    add("-r", "--result", help="输出结果文件路径 (默认: 原始文件名_llm_verified.xlsx)")
    add("--env-file", help="环境变量文件路径 (默认: ../.env)")

    # LLM backend selection and overrides
    add("--api", choices=["openai", "dmx", "dify", "ollama"], help="LLM API 类型")
    add("--model", help="LLM 模型名称")
    add("--base-url", help="API base URL")
    add("--api-key", help="API Key")
    add("--mock", action="store_true", help="使用 mock 模式(不调用 API")

    # Record selection and pacing
    add("--confidence", nargs="+", default=["High", "Medium"], help="需要验证的置信度级别")
    add("--confidence-col", default="confidence", help="置信度列名")
    add("--delay", type=float, default=REQUEST_DELAY, help="API 请求间隔秒数")
    add("--limit", type=int, default=0, help="限制验证条数 (0=全部)")
    add("--risky-only", action="store_true", help="只保存有风险的记录")

    return cli.parse_args()
def main():
    """Command-line entry point.

    Loads the .env file, resolves the LLM configuration, finds the
    high-confidence rows that keyword matching missed, verifies them and
    saves the result. Exits with status 1 when an input file is missing or
    the verifier cannot be created.
    """
    args = parse_args()

    # Load .env (defaults to the file next to the scripts' parent directory)
    base_dir = Path(__file__).resolve().parent
    env_file = args.env_file or str(base_dir.parent / ".env")
    load_env_file(env_file)

    # --api has to be applied BEFORE the configuration is resolved: setting
    # config.api_type afterwards kept the key, URL and model that belong to
    # whatever backend the environment selected (e.g. OPENAI_API_KEY was sent
    # to Dify). get_config() reads VERIFY_API_TYPE first.
    if args.api and not args.mock:
        os.environ["VERIFY_API_TYPE"] = args.api
    config = get_config()

    # Command-line overrides
    if args.mock:
        config.api_type = "mock"
    if args.model:
        config.model = args.model
    if args.base_url:
        config.base_url = args.base_url
    if args.api_key:
        config.api_key = args.api_key

    # File paths
    original_file = Path(args.original)
    matched_file = Path(args.matched)
    result_file = Path(args.result) if args.result else original_file.parent / f"{original_file.stem}_llm_verified.xlsx"

    print("=" * 60)
    print("高置信度未匹配记录验证")
    print("=" * 60)
    print(f"原始文件: {original_file}")
    print(f"匹配结果: {matched_file}")
    print(f"输出文件: {result_file}")
    print(f"置信度级别: {args.confidence}")
    print(f"API 类型: {config.api_type}")
    print(f"模型: {config.model}")
    if config.base_url:
        print(f"Base URL: {config.base_url}")

    # Load the data; a missing file is a user error, not a crash.
    print("\n正在加载数据...")
    try:
        original_df = load_excel(original_file)
        matched_df = load_excel(matched_file)
    except FileNotFoundError as e:
        print(f"\n错误: {e}")
        sys.exit(1)

    # Find the high-confidence rows that were not matched
    unmatched_df = find_unmatched(original_df, matched_df, args.confidence_col, args.confidence)
    if unmatched_df.empty:
        print("\n所有高置信度行都已被关键词匹配,无需验证。")
        return

    # Create the verifier
    try:
        verifier = create_verifier(config)
    except (ImportError, ValueError) as e:
        print(f"\n错误: {e}")
        sys.exit(1)

    # Run the verification
    verified_df = verify_batch(unmatched_df, verifier, delay=args.delay, limit=args.limit)

    # Print the summary and save
    print_summary(verified_df)
    save_results(verified_df, result_file, args.risky_only)
    print("\n✓ 验证完成!")


if __name__ == "__main__":
    main()