#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Verify high-confidence unmatched records.

Compares keyword_matcher results against the original Excel file, finds
high-confidence rows that were not matched, and calls an LLM for a second
verification pass.

Usage:
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
|
||
import pandas as pd
|
||
|
||
# 可选依赖
|
||
try:
    # OpenAI SDK: only required for the openai/dmx/ollama backends.
    import openai

    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False

try:
    # requests/urllib3: only required for the Dify backend.
    import requests
    import urllib3

    # DifyVerifier posts with verify=False (TLS verification disabled), so
    # suppress the per-request InsecureRequestWarning noise up front.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False
|
||
|
||
# ========== 常量与配置 ==========
|
||
# Confidence levels whose unmatched rows get sent to the LLM for re-checking.
CONFIDENCE_LEVELS = ["High", "Medium"]
# Default pause between consecutive API calls, in seconds.
REQUEST_DELAY = 0.5

# Environment-variable mapping: api_type -> (key_env, url_env, model_env, default_model).
# A None entry means that backend does not use the variable (e.g. ollama needs no key).
ENV_MAPPING = {
    "openai": ("OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL", "gpt-4o-mini"),
    "dmx": ("DMX_API_KEY", "DMX_BASE_URL", "DMX_MODEL", "gpt-4o-mini"),
    "dify": ("DIFY_API_KEY", "DIFY_BASE_URL", "DIFY_MODEL", "dify-chatflow"),
    "ollama": (None, "OLLAMA_BASE_URL", "OLLAMA_MODEL", "qwen2.5:7b"),
}

# System prompt sent to the LLM. Kept in Chinese: it is runtime payload
# consumed by the model, not a code comment.
SYSTEM_PROMPT = """你是一位化学品风险识别专家。请分析文本内容,判断是否涉及管制化学品、毒品前体或非法药物交易。

请以 JSON 格式回答,包含以下字段:
- is_risky: 布尔值,是否涉及风险
- substances: 数组,涉及的具体物质名称或CAS号
- risk_level: 字符串,风险等级(高/中/低)
- reason: 字符串,判定理由(简要)

示例输出:
{"is_risky": true, "substances": ["甲基苯丙胺", "CAS 537-46-2"], "risk_level": "高", "reason": "文本中明确提到毒品名称和交易信息"}
"""

# User prompt template, filled by build_prompt() with the (possibly truncated)
# image-analysis text and original text of one row.
USER_PROMPT_TEMPLATE = """请分析以下内容是否涉及管制化学品或毒品:

【图片分析结果】
{raw_response}

【原始文本】
{original_text}

请以 JSON 格式输出分析结果。"""
|
||
|
||
|
||
# ========== 数据类 ==========
|
||
@dataclass
class VerifyConfig:
    """Connection settings for the LLM verification backend."""

    # Backend selector: "openai", "dmx", "dify", "ollama", or "mock".
    api_type: str = "openai"
    # API key; may stay empty for ollama (no auth) and mock.
    api_key: str = ""
    # Optional endpoint override; None means the client's built-in default.
    base_url: Optional[str] = None
    # Model identifier passed to the backend.
    model: str = "gpt-4o-mini"
    # End-user id forwarded to Dify (the "user" field of chat-messages).
    user_id: str = "default-user"
|
||
|
||
|
||
@dataclass
class VerificationResult:
    """Outcome of one LLM verification call, flattenable into Excel columns."""

    is_risky: Optional[bool] = None
    substances: List[str] = field(default_factory=list)
    risk_level: str = ""
    reason: str = ""
    raw_response: str = ""

    def to_columns(self) -> dict:
        """Map this result onto the llm_-prefixed output columns."""
        substances_cell = " | ".join(self.substances) if self.substances else ""
        return dict(
            llm_is_risky=self.is_risky,
            llm_substances=substances_cell,
            llm_risk_level=self.risk_level,
            llm_reason=self.reason,
            llm_raw_response=self.raw_response,
        )
|
||
|
||
|
||
# ========== 工具函数 ==========
|
||
def load_env_file(env_path: str) -> None:
    """Load environment variables from a .env file, if it exists.

    Supports blank lines, `#` comments, and an optional `export ` prefix;
    one level of surrounding single/double quotes on values is stripped.
    """
    path = Path(env_path)
    if not path.exists():
        return
    print(f"加载环境配置: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            entry = raw_line.strip()
            # Skip blanks and comment lines.
            if not entry or entry.startswith("#"):
                continue
            # Tolerate shell-style "export KEY=VALUE" lines.
            if entry.startswith("export "):
                entry = entry[7:]
            if "=" not in entry:
                continue
            key, _, value = entry.partition("=")
            os.environ[key.strip()] = value.strip().strip('"').strip("'")
|
||
|
||
|
||
def get_config() -> VerifyConfig:
    """Build a VerifyConfig from environment variables.

    VERIFY_*-prefixed variables take priority; otherwise the backend's own
    variables from ENV_MAPPING apply, then the built-in defaults.
    """
    api_type = (os.getenv("VERIFY_API_TYPE") or os.getenv("LLM_API_TYPE") or "openai").lower()
    key_env, url_env, model_env, default_model = ENV_MAPPING.get(
        api_type, (None, None, None, "gpt-4o-mini")
    )

    def _env(name: Optional[str]) -> Optional[str]:
        # Read an env var only when the backend actually defines one.
        return os.getenv(name) if name else None

    return VerifyConfig(
        api_type=api_type,
        api_key=os.getenv("VERIFY_API_KEY") or _env(key_env) or "",
        base_url=os.getenv("VERIFY_BASE_URL") or _env(url_env),
        model=os.getenv("VERIFY_MODEL") or _env(model_env) or default_model,
        user_id=os.getenv("VERIFY_USER_ID") or os.getenv("DIFY_USER_ID") or "default-user",
    )
|
||
|
||
|
||
def parse_json_response(content: str) -> dict:
    """Extract a JSON object from an LLM response.

    Handles markdown code fences (```json ... ``` or ``` ... ```) and any
    chatter surrounding the JSON object. Returns a fallback dict with
    is_risky=None when no JSON object can be parsed.
    """
    # Strip a markdown code fence if present. Bug fix: when the closing
    # fence is missing, find() returns -1 and the old content[start:-1]
    # silently dropped the last character (often the closing brace) and
    # corrupted the JSON — now keep everything to the end instead.
    for fence in ("```json", "```"):
        if fence in content:
            start = content.find(fence) + len(fence)
            end = content.find("```", start)
            content = (content[start:end] if end != -1 else content[start:]).strip()
            break

    # Take the outermost {...} span and try to parse it.
    try:
        start = content.find("{")
        end = content.rfind("}") + 1
        if start >= 0 and end > start:
            return json.loads(content[start:end])
    except json.JSONDecodeError:
        pass

    # Fallback: signal that the response could not be parsed.
    return {"is_risky": None, "substances": [], "risk_level": "未知", "reason": "JSON 解析失败"}
|
||
|
||
|
||
def build_prompt(row: pd.Series, max_len: int = 3000) -> str:
    """Render the user prompt for one row, truncating oversized fields.

    Both the image-analysis text (raw_response) and the original text (文本)
    are capped at max_len characters with a truncation marker appended.
    """

    def _clip(value: object) -> str:
        text = str(value or "")
        return text if len(text) <= max_len else text[:max_len] + "...(截断)"

    return USER_PROMPT_TEMPLATE.format(
        raw_response=_clip(row.get("raw_response", "")),
        original_text=_clip(row.get("文本", "")),
    )
|
||
|
||
|
||
# ========== 验证器类 ==========
|
||
class LLMVerifier(ABC):
    """Abstract base class for LLM verifiers.

    Concrete backends (OpenAI-compatible, Dify, mock) implement verify()
    to judge a single spreadsheet row.
    """

    @abstractmethod
    def verify(self, row: pd.Series) -> VerificationResult:
        """Verify one row and return the structured LLM verdict."""
        pass
|
||
|
||
|
||
class OpenAIVerifier(LLMVerifier):
    """OpenAI-compatible API verifier (supports OpenAI, DMX, Ollama)."""

    def __init__(self, config: VerifyConfig):
        """Create the API client.

        Raises:
            ImportError: the openai package is not installed.
            ValueError: no API key for a backend that requires one.
        """
        if not HAS_OPENAI:
            raise ImportError("请安装 openai: pip install openai")
        # Ollama is the only backend that works without a key.
        if config.api_type != "ollama" and not config.api_key:
            raise ValueError("未提供 API Key")

        base_url = config.base_url
        if config.api_type == "ollama":
            # Ollama exposes an OpenAI-compatible endpoint under /v1.
            base_url = (config.base_url or "http://localhost:11434") + "/v1"

        self.client = openai.OpenAI(
            api_key=config.api_key or "ollama",  # SDK requires a non-empty key
            base_url=base_url,
        )
        self.model = config.model

    def verify(self, row: pd.Series) -> VerificationResult:
        """Call the chat-completions API for one row and parse the verdict.

        API errors and truncated responses are reported via risk_level="错误"
        instead of raising, so one bad row does not abort the batch.
        """
        prompt = build_prompt(row)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,  # near-deterministic classification
                max_tokens=500,
            )
            content = response.choices[0].message.content or ""
            # Anything other than a clean stop (e.g. length cutoff) means the
            # JSON verdict may be truncated — report instead of parsing.
            if response.choices[0].finish_reason != "stop":
                return VerificationResult(
                    risk_level="错误",
                    reason=f"响应不完整 (finish_reason={response.choices[0].finish_reason})",
                    raw_response=content,
                )
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                substances=parsed.get("substances", []),
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            # Deliberate catch-all: keep the batch running on any API failure.
            return VerificationResult(risk_level="错误", reason=f"API 调用失败: {e}", raw_response=str(e))
|
||
|
||
|
||
class DifyVerifier(LLMVerifier):
    """Verifier backed by the Dify chat-messages API."""

    def __init__(self, config: VerifyConfig):
        """Validate dependencies/credentials and store connection settings."""
        if not HAS_REQUESTS:
            raise ImportError("请安装 requests: pip install requests")
        if not config.api_key:
            raise ValueError("未提供 Dify API Key")
        self.base_url = (config.base_url or "").rstrip("/")
        self.api_key = config.api_key
        self.user_id = config.user_id

    def verify(self, row: pd.Series) -> VerificationResult:
        """POST one blocking chat message to Dify and parse the JSON verdict."""
        # Dify offers no separate system-message slot in this call, so the
        # system prompt is prepended to the user prompt.
        prompt = f"{SYSTEM_PROMPT}\n\n{build_prompt(row)}"
        endpoint = f"{self.base_url}/v1/chat-messages"
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        payload = {"inputs": {}, "query": prompt, "response_mode": "blocking", "user": self.user_id}
        try:
            # NOTE: verify=False skips TLS certificate checks (warnings are
            # silenced at import time) — acceptable only for trusted hosts.
            resp = requests.post(endpoint, headers=headers, json=payload, timeout=120, verify=False)
            resp.raise_for_status()
            content = resp.json().get("answer", "")
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                substances=parsed.get("substances", []),
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            # Deliberate catch-all: keep the batch running on any API failure.
            return VerificationResult(risk_level="错误", reason=f"Dify 调用失败: {e}", raw_response=str(e))
|
||
|
||
|
||
class MockVerifier(LLMVerifier):
    """Offline verifier for testing: flags rows by keyword lookup only."""

    RISK_KEYWORDS = [
        "毒品", "非法", "管制", "药物", "化学品", "CAS", "阿片", "芬太尼",
        "冰毒", "大麻", "可卡因", "海洛因", "摇头丸", "麻黄碱",
        "fentanyl", "methamphetamine", "cocaine", "heroin", "mdma",
        "ketamine", "lsd", "precursor", "controlled",
    ]

    def verify(self, row: pd.Series) -> VerificationResult:
        """Scan both text fields for risk keywords (case-insensitive)."""
        haystack = f"{row.get('raw_response', '')} {row.get('文本', '')}".lower()
        hits = [keyword for keyword in self.RISK_KEYWORDS if keyword.lower() in haystack]
        if hits:
            return VerificationResult(
                is_risky=True,
                substances=hits[:5],
                risk_level="中",
                reason=f"Mock模式 - 发现关键词: {hits[:3]}",
                raw_response="(mock)",
            )
        return VerificationResult(
            is_risky=False,
            substances=[],
            risk_level="低",
            reason="Mock模式 - 未发现风险关键词",
            raw_response="(mock)",
        )
|
||
|
||
|
||
def create_verifier(config: VerifyConfig) -> LLMVerifier:
    """Instantiate the verifier matching config.api_type.

    Raises:
        ValueError: unknown api_type.
    """
    api_type = config.api_type
    if api_type == "mock":
        return MockVerifier()
    if api_type == "dify":
        return DifyVerifier(config)
    if api_type in ("openai", "dmx", "ollama"):
        return OpenAIVerifier(config)
    raise ValueError(f"不支持的 API 类型: {api_type}")
|
||
|
||
|
||
# ========== 数据处理 ==========
|
||
def load_excel(file_path: Path) -> pd.DataFrame:
    """Read an Excel workbook, raising FileNotFoundError when it is absent."""
    if file_path.exists():
        return pd.read_excel(file_path)
    raise FileNotFoundError(f"文件不存在: {file_path}")
|
||
|
||
|
||
def find_unmatched(
    original_df: pd.DataFrame,
    matched_df: pd.DataFrame,
    confidence_col: str = "confidence",
    confidence_levels: Optional[List[str]] = None,  # was mis-annotated List[str] = None
) -> pd.DataFrame:
    """Return high-confidence rows of original_df that are absent from matched_df.

    Rows are selected by case-insensitive membership of confidence_col in
    confidence_levels (default CONFIDENCE_LEVELS); rows whose index already
    appears in matched_df are dropped. Comparison statistics are printed.

    Returns an empty DataFrame when the confidence column is missing or when
    nothing is left to verify.
    """
    levels = confidence_levels or CONFIDENCE_LEVELS

    if confidence_col not in original_df.columns:
        print(f"警告: 原始文件中不存在 '{confidence_col}' 列")
        print(f"可用列: {original_df.columns.tolist()}")
        return pd.DataFrame()

    # Case-insensitive confidence match, keyed by row index.
    conf_lower = original_df[confidence_col].astype(str).str.lower()
    levels_lower = [level.lower() for level in levels]
    high_conf_idx = set(original_df[conf_lower.isin(levels_lower)].index)
    matched_idx = set(matched_df.index)
    unmatched_idx = high_conf_idx - matched_idx

    # Comparison statistics.
    print(f"\n{'='*50}")
    print("数据比对统计")
    print(f"{'='*50}")
    print(f"原始数据总行数: {len(original_df)}")
    print(f"高置信度 ({'/'.join(levels)}) 行数: {len(high_conf_idx)}")
    print(f"关键词匹配到的行数: {len(matched_idx)}")
    print(f"高置信度中已匹配: {len(high_conf_idx & matched_idx)}")
    print(f"高置信度中未匹配 (需验证): {len(unmatched_idx)}")
    print(f"{'='*50}\n")

    if not unmatched_idx:
        return pd.DataFrame()
    # Bug fix: iterate the index set in sorted order — list(set) iteration
    # order is unstable across runs, which made the output row order
    # nondeterministic.
    return original_df.loc[sorted(unmatched_idx)].copy()
|
||
|
||
|
||
def verify_batch(df: pd.DataFrame, verifier: LLMVerifier, delay: float = REQUEST_DELAY, limit: int = 0) -> pd.DataFrame:
    """Run verifier.verify on every row of df and append the llm_* columns.

    Args:
        df: rows to verify (index is preserved in the output).
        verifier: backend used to verify each row.
        delay: pause in seconds between calls (skipped after the last row).
        limit: if > 0, verify only the first `limit` rows.

    Returns a copy of df (possibly truncated by limit) with the verifier's
    columns merged in. An empty input is returned as-is.
    """
    if limit > 0:
        df = df.head(limit)

    total = len(df)
    print(f"开始 LLM 验证,共 {total} 条记录...")
    print("-" * 50)

    # Bug fix: with no rows, pd.DataFrame([]).set_index("original_index")
    # below raises KeyError — bail out early instead.
    if total == 0:
        return df.copy()

    results = []
    start_time = time.time()

    for i, (idx, row) in enumerate(df.iterrows()):
        # Progress line on the first, last, and every 10th record.
        if (i + 1) % 10 == 0 or i == 0 or i == total - 1:
            elapsed = time.time() - start_time
            speed = (i + 1) / elapsed if elapsed > 0 else 0
            print(f"进度: {i + 1}/{total} ({(i+1)/total*100:.1f}%) - 速度: {speed:.1f} 条/秒")

        result = verifier.verify(row)
        results.append({"original_index": idx, **result.to_columns()})

        # Throttle between calls, but not after the final row.
        if delay > 0 and i < total - 1:
            time.sleep(delay)

    # Align results to the original index and merge the columns in.
    results_df = pd.DataFrame(results).set_index("original_index")
    verified_df = df.copy()
    for col in results_df.columns:
        verified_df[col] = results_df[col]
    return verified_df
|
||
|
||
|
||
# ========== 结果输出 ==========
|
||
def save_results(df: pd.DataFrame, output_file: Path, risky_only: bool = False) -> None:
    """Write the verification results to an Excel file.

    When risky_only is set, only rows the LLM flagged as risky are written.
    """
    keep = df
    if risky_only and "llm_is_risky" in keep.columns:
        keep = keep[keep["llm_is_risky"] == True]  # noqa: E712 — element-wise pandas compare
    keep.to_excel(output_file, index=False, engine="openpyxl")
    print(f"\n已保存 {len(keep)} 条记录到: {output_file}")
|
||
|
||
|
||
def print_summary(df: pd.DataFrame) -> None:
    """Print an on-screen summary of the LLM verification outcome."""
    divider = "=" * 50
    print(f"\n{divider}")
    print("验证结果摘要")
    print(f"{divider}")

    total = len(df)
    if "llm_is_risky" not in df.columns:
        # No verification columns present — only a row count can be shown.
        print(f"总记录数: {total}")
        return

    verdicts = df["llm_is_risky"]
    risky = (verdicts == True).sum()  # noqa: E712 — element-wise pandas compare
    not_risky = (verdicts == False).sum()  # noqa: E712
    unknown = total - risky - not_risky

    print(f"总验证数: {total}")
    print(f" ├─ LLM 判定有风险: {risky} ({risky/total*100:.1f}%)")
    print(f" ├─ LLM 判定无风险: {not_risky} ({not_risky/total*100:.1f}%)")
    if unknown > 0:
        print(f" └─ 判定失败/未知: {unknown}")

    if "llm_risk_level" in df.columns:
        print(f"\n风险等级分布:")
        for level, count in df["llm_risk_level"].value_counts().items():
            print(f" - {level}: {count}")
    print(f"{divider}")
|
||
|
||
|
||
# ========== CLI ==========
|
||
def parse_args():
    """Parse command-line arguments for the verification CLI."""
    parser = argparse.ArgumentParser(
        description="验证高置信度未匹配记录",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
  python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5
  python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --api dmx --model gpt-4o-mini
""",
    )

    # Input/output files.
    parser.add_argument("-o", "--original", required=True, help="原始 Excel 文件路径")
    parser.add_argument("-m", "--matched", required=True, help="keyword_matcher 匹配结果文件路径")
    parser.add_argument("-r", "--result", help="输出结果文件路径 (默认: 原始文件名_llm_verified.xlsx)")

    # Backend selection (overrides environment configuration).
    parser.add_argument("--env-file", help="环境变量文件路径 (默认: ../.env)")
    parser.add_argument("--api", choices=["openai", "dmx", "dify", "ollama"], help="LLM API 类型")
    parser.add_argument("--model", help="LLM 模型名称")
    parser.add_argument("--base-url", help="API base URL")
    parser.add_argument("--api-key", help="API Key")
    parser.add_argument("--mock", action="store_true", help="使用 mock 模式(不调用 API)")

    # Row selection.
    parser.add_argument("--confidence", nargs="+", default=["High", "Medium"], help="需要验证的置信度级别")
    parser.add_argument("--confidence-col", default="confidence", help="置信度列名")

    # Execution behaviour.
    parser.add_argument("--delay", type=float, default=REQUEST_DELAY, help="API 请求间隔秒数")
    parser.add_argument("--limit", type=int, default=0, help="限制验证条数 (0=全部)")
    parser.add_argument("--risky-only", action="store_true", help="只保存有风险的记录")

    return parser.parse_args()
|
||
|
||
|
||
def main():
    """CLI entry point: load data, find unmatched rows, verify, save."""
    args = parse_args()

    # Load .env (defaults to ../.env relative to this script).
    base_dir = Path(__file__).resolve().parent
    env_file = args.env_file or str(base_dir.parent / ".env")
    load_env_file(env_file)

    # Build config from environment variables.
    config = get_config()

    # Command-line flags override the environment.
    if args.mock:
        config.api_type = "mock"
    elif args.api:
        config.api_type = args.api
    if args.model:
        config.model = args.model
    if args.base_url:
        config.base_url = args.base_url
    if args.api_key:
        config.api_key = args.api_key

    # Resolve file paths.
    original_file = Path(args.original)
    matched_file = Path(args.matched)
    result_file = Path(args.result) if args.result else original_file.parent / f"{original_file.stem}_llm_verified.xlsx"

    print("=" * 60)
    print("高置信度未匹配记录验证")
    print("=" * 60)
    print(f"原始文件: {original_file}")
    print(f"匹配结果: {matched_file}")
    print(f"输出文件: {result_file}")
    print(f"置信度级别: {args.confidence}")
    print(f"API 类型: {config.api_type}")
    print(f"模型: {config.model}")
    if config.base_url:
        print(f"Base URL: {config.base_url}")

    # Load both spreadsheets.
    print("\n正在加载数据...")
    original_df = load_excel(original_file)
    matched_df = load_excel(matched_file)

    # High-confidence rows the keyword matcher missed.
    unmatched_df = find_unmatched(original_df, matched_df, args.confidence_col, args.confidence)

    if unmatched_df.empty:
        print("\n所有高置信度行都已被关键词匹配,无需验证。")
        return

    # Instantiate the verifier backend.
    try:
        verifier = create_verifier(config)
    except (ImportError, ValueError) as e:
        print(f"\n错误: {e}")
        sys.exit(1)

    # Run the verification loop.
    verified_df = verify_batch(unmatched_df, verifier, delay=args.delay, limit=args.limit)

    # Print summary and persist the results.
    print_summary(verified_df)
    save_results(verified_df, result_file, args.risky_only)
    print("\n✓ 验证完成!")
|
||
|
||
|
||
# Run the CLI when executed directly.
if __name__ == "__main__":
    main()
|