# Source file: chem-risk-detect/scripts/verify_high_confidence.py
# (web-viewer chrome removed during extraction; snapshot dated 2026-01-18)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Verify high-confidence unmatched records.

Compares keyword_matcher output against the original Excel sheet, finds
high-confidence rows that the keyword matcher missed, and sends them to an
LLM for a second-pass verification.

Usage:
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
    python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5
"""
import argparse
import json
import os
import sys
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import pandas as pd

# Optional dependencies: each LLM backend degrades gracefully when its
# client library is missing (the corresponding verifier raises ImportError).
try:
    import openai
    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False
try:
    import requests
    import urllib3
    # DifyVerifier posts with verify=False, so silence the TLS warning spam.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False
# ========== Constants & configuration ==========
CONFIDENCE_LEVELS = ["High", "Medium"]  # default confidence levels to re-verify
REQUEST_DELAY = 0.5  # seconds to sleep between API calls

# Environment variable mapping: api_type -> (key_env, url_env, model_env, default_model)
ENV_MAPPING = {
    "openai": ("OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL", "gpt-4o-mini"),
    "dmx": ("DMX_API_KEY", "DMX_BASE_URL", "DMX_MODEL", "gpt-4o-mini"),
    "dify": ("DIFY_API_KEY", "DIFY_BASE_URL", "DIFY_MODEL", "dify-chatflow"),
    "ollama": (None, "OLLAMA_BASE_URL", "OLLAMA_MODEL", "qwen2.5:7b"),
}

# NOTE(review): several Chinese label characters (presumably 高/中/低) look lost
# in extraction — e.g. the "risk_level: ...//" line and the empty "risk_level"
# value in the example below. Confirm against the original file before editing
# these prompt strings; they are reproduced here byte-for-byte.
SYSTEM_PROMPT = """你是一位化学品风险识别专家。请分析文本内容,判断是否涉及管制化学品、毒品前体或非法药物交易。
请以 JSON 格式回答包含以下字段
- is_risky: 布尔值是否涉及风险
- substances: 数组涉及的具体物质名称或CAS号
- risk_level: 字符串风险等级//
- reason: 字符串判定理由简要
示例输出
{"is_risky": true, "substances": ["甲基苯丙胺", "CAS 537-46-2"], "risk_level": "", "reason": "文本中明确提到毒品名称和交易信息"}
"""

USER_PROMPT_TEMPLATE = """请分析以下内容是否涉及管制化学品或毒品:
图片分析结果
{raw_response}
原始文本
{original_text}
请以 JSON 格式输出分析结果"""
# ========== Data classes ==========
@dataclass
class VerifyConfig:
    """Resolved LLM connection settings (built by get_config from env vars)."""
    api_type: str = "openai"        # one of: openai / dmx / dify / ollama / mock
    api_key: str = ""               # may stay empty for ollama and mock
    base_url: Optional[str] = None  # API endpoint override; None = client default
    model: str = "gpt-4o-mini"
    user_id: str = "default-user"   # Dify conversation user id
@dataclass
class VerificationResult:
    """Outcome of a single LLM verification call."""
    is_risky: Optional[bool] = None           # None = verdict unavailable
    substances: List[str] = field(default_factory=list)
    risk_level: str = ""
    reason: str = ""
    raw_response: str = ""

    def to_columns(self) -> dict:
        """Flatten this result into llm_*-prefixed columns for the output sheet."""
        joined_substances = " | ".join(self.substances) if self.substances else ""
        return {
            "llm_is_risky": self.is_risky,
            "llm_substances": joined_substances,
            "llm_risk_level": self.risk_level,
            "llm_reason": self.reason,
            "llm_raw_response": self.raw_response,
        }
# ========== Utility functions ==========
def load_env_file(env_path: str) -> None:
    """Populate os.environ from a dotenv-style file; silently no-op if absent.

    Supports blank lines, "#" comments, an optional "export " prefix, and
    strips one layer of surrounding single or double quotes from values.
    """
    env_file = Path(env_path)
    if not env_file.exists():
        return
    print(f"加载环境配置: {env_file}")
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue  # skip blanks and comments
        if entry.startswith("export "):
            entry = entry[7:]
        if "=" not in entry:
            continue  # not a KEY=VALUE assignment
        key, _, value = entry.partition("=")
        os.environ[key.strip()] = value.strip().strip('"').strip("'")
def get_config() -> VerifyConfig:
    """Build a VerifyConfig from environment variables.

    VERIFY_*-prefixed variables win; otherwise the per-backend variables
    declared in ENV_MAPPING are consulted, then hard-coded defaults.
    """
    api_type = (os.getenv("VERIFY_API_TYPE") or os.getenv("LLM_API_TYPE") or "openai").lower()
    key_env, url_env, model_env, default_model = ENV_MAPPING.get(
        api_type, (None, None, None, "gpt-4o-mini")
    )

    def _env(name: Optional[str]) -> Optional[str]:
        # Read an env var whose name may itself be absent from the mapping.
        return os.getenv(name) if name else None

    return VerifyConfig(
        api_type=api_type,
        api_key=os.getenv("VERIFY_API_KEY") or _env(key_env) or "",
        base_url=os.getenv("VERIFY_BASE_URL") or _env(url_env),
        model=os.getenv("VERIFY_MODEL") or _env(model_env) or default_model,
        user_id=os.getenv("VERIFY_USER_ID") or os.getenv("DIFY_USER_ID") or "default-user",
    )
def parse_json_response(content: str) -> dict:
    """Extract a JSON object from an LLM reply, tolerating markdown fences.

    Strips an optional ```json/``` code fence, then parses the outermost
    {...} span. Returns a fixed fallback dict when no valid JSON is found.

    Bug fix: when the opening fence had no matching closing ``` the old code
    sliced with end == -1 and silently dropped the last character (usually
    the closing brace). An unterminated fence now extends to end-of-string.
    """
    for fence in ("```json", "```"):
        if fence in content:
            start = content.find(fence) + len(fence)
            end = content.find("```", start)
            content = content[start:end if end != -1 else None].strip()
            break
    try:
        start = content.find("{")
        end = content.rfind("}") + 1
        if start >= 0 and end > start:
            return json.loads(content[start:end])
    except json.JSONDecodeError:
        pass
    return {"is_risky": None, "substances": [], "risk_level": "未知", "reason": "JSON 解析失败"}
def build_prompt(row: pd.Series, max_len: int = 3000) -> str:
    """Render the user prompt for one row, truncating over-long fields.

    Reads the "raw_response" and "文本" columns; either may be missing or
    falsy, in which case an empty string is substituted.
    """
    def _clip(value) -> str:
        # Coerce to str and cap at max_len with an explicit truncation marker.
        text = str(value or "")
        if len(text) > max_len:
            return text[:max_len] + "...(截断)"
        return text

    return USER_PROMPT_TEMPLATE.format(
        raw_response=_clip(row.get("raw_response", "")),
        original_text=_clip(row.get("文本", "")),
    )
# ========== Verifier classes ==========
class LLMVerifier(ABC):
    """Abstract base class for LLM-backed verifiers."""

    @abstractmethod
    def verify(self, row: pd.Series) -> VerificationResult:
        """Verify one DataFrame row and return a VerificationResult.

        Concrete implementations catch their own API errors and encode
        failures in the returned result instead of raising.
        """
        pass
class OpenAIVerifier(LLMVerifier):
    """OpenAI-compatible API verifier (supports OpenAI, DMX, Ollama)."""

    def __init__(self, config: VerifyConfig):
        """Create the chat client from *config*.

        Raises:
            ImportError: if the ``openai`` package is not installed.
            ValueError: if an API key is required (non-ollama) but missing.
        """
        if not HAS_OPENAI:
            raise ImportError("请安装 openai: pip install openai")
        if config.api_type != "ollama" and not config.api_key:
            raise ValueError("未提供 API Key")
        base_url = config.base_url
        if config.api_type == "ollama":
            # Ollama serves an OpenAI-compatible endpoint under /v1.
            base_url = (config.base_url or "http://localhost:11434") + "/v1"
        self.client = openai.OpenAI(
            # Ollama ignores the key, but the SDK requires a non-empty value.
            api_key=config.api_key or "ollama",
            base_url=base_url,
        )
        self.model = config.model

    def verify(self, row: pd.Series) -> VerificationResult:
        """Run one chat completion for *row* and parse its JSON verdict.

        Never raises: API failures and incomplete responses come back as a
        VerificationResult with risk_level == "错误".
        """
        prompt = build_prompt(row)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,  # near-deterministic output for classification
                max_tokens=500,
            )
            content = response.choices[0].message.content or ""
            # Truncated/filtered answers are reported as errors rather than
            # risking a parse of partial JSON.
            if response.choices[0].finish_reason != "stop":
                return VerificationResult(
                    risk_level="错误",
                    reason=f"响应不完整 (finish_reason={response.choices[0].finish_reason})",
                    raw_response=content,
                )
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                substances=parsed.get("substances", []),
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            return VerificationResult(risk_level="错误", reason=f"API 调用失败: {e}", raw_response=str(e))
class DifyVerifier(LLMVerifier):
    """Dify chat API verifier."""

    def __init__(self, config: VerifyConfig):
        """Store connection settings from *config*.

        Raises:
            ImportError: if the ``requests`` package is not installed.
            ValueError: if no Dify API key is configured.
        """
        if not HAS_REQUESTS:
            raise ImportError("请安装 requests: pip install requests")
        if not config.api_key:
            raise ValueError("未提供 Dify API Key")
        self.base_url = (config.base_url or "").rstrip("/")
        self.api_key = config.api_key
        self.user_id = config.user_id

    def verify(self, row: pd.Series) -> VerificationResult:
        """POST one blocking chat message to Dify and parse the JSON verdict.

        Never raises: HTTP and parsing failures come back as a
        VerificationResult with risk_level == "错误".
        """
        # Dify's chat-messages endpoint takes a single query string, so the
        # system prompt is inlined ahead of the user prompt.
        prompt = f"{SYSTEM_PROMPT}\n\n{build_prompt(row)}"
        try:
            resp = requests.post(
                f"{self.base_url}/v1/chat-messages",
                headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
                json={"inputs": {}, "query": prompt, "response_mode": "blocking", "user": self.user_id},
                timeout=120,
                # NOTE(review): TLS verification disabled — presumably for
                # self-signed internal deployments; confirm this is intended.
                verify=False,
            )
            resp.raise_for_status()
            content = resp.json().get("answer", "")
            parsed = parse_json_response(content)
            return VerificationResult(
                is_risky=parsed.get("is_risky"),
                substances=parsed.get("substances", []),
                risk_level=parsed.get("risk_level", ""),
                reason=parsed.get("reason", ""),
                raw_response=content,
            )
        except Exception as e:
            return VerificationResult(risk_level="错误", reason=f"Dify 调用失败: {e}", raw_response=str(e))
class MockVerifier(LLMVerifier):
    """Offline verifier for testing: flags rows containing known risk keywords."""

    RISK_KEYWORDS = [
        "毒品", "非法", "管制", "药物", "化学品", "CAS", "阿片", "芬太尼",
        "冰毒", "大麻", "可卡因", "海洛因", "摇头丸", "麻黄碱",
        "fentanyl", "methamphetamine", "cocaine", "heroin", "mdma",
        "ketamine", "lsd", "precursor", "controlled",
    ]

    def verify(self, row: pd.Series) -> VerificationResult:
        """Substring-match RISK_KEYWORDS against the row's text fields."""
        haystack = f"{row.get('raw_response', '')} {row.get('文本', '')}".lower()
        hits = [kw for kw in self.RISK_KEYWORDS if kw.lower() in haystack]
        if hits:
            return VerificationResult(
                is_risky=True,
                substances=hits[:5],
                # NOTE(review): level text (likely 高) appears lost in
                # extraction — reproduced as-is; confirm against original.
                risk_level="",
                reason=f"Mock模式 - 发现关键词: {hits[:3]}",
                raw_response="(mock)",
            )
        return VerificationResult(
            is_risky=False,
            substances=[],
            risk_level="",
            reason="Mock模式 - 未发现风险关键词",
            raw_response="(mock)",
        )
def create_verifier(config: VerifyConfig) -> LLMVerifier:
    """Instantiate the verifier implementation matching config.api_type.

    Raises:
        ValueError: for an unrecognized api_type.
    """
    factories = {
        "mock": lambda _cfg: MockVerifier(),  # mock takes no configuration
        "dify": DifyVerifier,
        "openai": OpenAIVerifier,
        "dmx": OpenAIVerifier,
        "ollama": OpenAIVerifier,
    }
    factory = factories.get(config.api_type)
    if factory is None:
        raise ValueError(f"不支持的 API 类型: {config.api_type}")
    return factory(config)
# ========== Data processing ==========
def load_excel(file_path: Path) -> pd.DataFrame:
    """Read an Excel workbook into a DataFrame.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if file_path.exists():
        return pd.read_excel(file_path)
    raise FileNotFoundError(f"文件不存在: {file_path}")
def find_unmatched(
    original_df: pd.DataFrame,
    matched_df: pd.DataFrame,
    confidence_col: str = "confidence",
    confidence_levels: Optional[List[str]] = None,  # fixed: was annotated List[str] with None default
) -> pd.DataFrame:
    """Find high-confidence rows that the keyword matcher did not match.

    Rows are "high confidence" when *confidence_col* (case-insensitively)
    matches one of *confidence_levels* (default: CONFIDENCE_LEVELS). Matched
    rows are identified by matched_df sharing original_df's index values.

    Returns:
        A copy of the unmatched high-confidence rows, or an empty DataFrame
        when the confidence column is absent or nothing is unmatched.
    """
    levels = confidence_levels or CONFIDENCE_LEVELS
    if confidence_col not in original_df.columns:
        print(f"警告: 原始文件中不存在 '{confidence_col}'")
        print(f"可用列: {original_df.columns.tolist()}")
        return pd.DataFrame()
    # Case-insensitive comparison of confidence labels.
    conf_lower = original_df[confidence_col].astype(str).str.lower()
    levels_lower = [lvl.lower() for lvl in levels]
    high_conf_idx = set(original_df[conf_lower.isin(levels_lower)].index)
    matched_idx = set(matched_df.index)
    unmatched_idx = high_conf_idx - matched_idx
    # Comparison statistics
    print(f"\n{'='*50}")
    print("数据比对统计")
    print(f"{'='*50}")
    print(f"原始数据总行数: {len(original_df)}")
    print(f"高置信度 ({'/'.join(levels)}) 行数: {len(high_conf_idx)}")
    print(f"关键词匹配到的行数: {len(matched_idx)}")
    print(f"高置信度中已匹配: {len(high_conf_idx & matched_idx)}")
    print(f"高置信度中未匹配 (需验证): {len(unmatched_idx)}")
    print(f"{'='*50}\n")
    if not unmatched_idx:
        return pd.DataFrame()
    # Fixed: iterating a set gave nondeterministic row order between runs;
    # sorting keeps the output stable (assumes a sortable index, e.g. RangeIndex).
    return original_df.loc[sorted(unmatched_idx)].copy()
def verify_batch(df: pd.DataFrame, verifier: LLMVerifier, delay: float = REQUEST_DELAY, limit: int = 0) -> pd.DataFrame:
    """Run *verifier* over every row of *df* and append llm_* result columns.

    Args:
        df: rows to verify; the original index is preserved in the output.
        verifier: any LLMVerifier implementation.
        delay: seconds to sleep between calls (skipped after the last row).
        limit: if > 0, only the first *limit* rows are verified.

    Returns:
        A copy of the (possibly truncated) input with verification columns
        added; an empty input is returned unchanged.
    """
    if limit > 0:
        df = df.head(limit)
    # Fixed: an empty frame used to crash below in
    # pd.DataFrame([]).set_index("original_index") with a KeyError.
    if df.empty:
        return df.copy()
    total = len(df)
    print(f"开始 LLM 验证,共 {total} 条记录...")
    print("-" * 50)
    results = []
    start_time = time.time()
    for i, (idx, row) in enumerate(df.iterrows()):
        # Progress line on the first, last, and every 10th record.
        if (i + 1) % 10 == 0 or i == 0 or i == total - 1:
            elapsed = time.time() - start_time
            speed = (i + 1) / elapsed if elapsed > 0 else 0
            print(f"进度: {i + 1}/{total} ({(i+1)/total*100:.1f}%) - 速度: {speed:.1f} 条/秒")
        result = verifier.verify(row)
        results.append({"original_index": idx, **result.to_columns()})
        if delay > 0 and i < total - 1:
            time.sleep(delay)
    # Align results back onto the input rows via the preserved original index.
    results_df = pd.DataFrame(results).set_index("original_index")
    verified_df = df.copy()
    for col in results_df.columns:
        verified_df[col] = results_df[col]
    return verified_df
# ========== Result output ==========
def save_results(df: pd.DataFrame, output_file: Path, risky_only: bool = False) -> None:
    """Write *df* to an Excel workbook, optionally keeping only risky rows."""
    keep_risky_subset = risky_only and "llm_is_risky" in df.columns
    if keep_risky_subset:
        df = df.loc[df["llm_is_risky"] == True]  # noqa: E712 — pandas boolean mask
    df.to_excel(output_file, index=False, engine="openpyxl")
    print(f"\n已保存 {len(df)} 条记录到: {output_file}")
def print_summary(df: pd.DataFrame) -> None:
    """Print a verification summary (risk verdict counts and level distribution).

    Fixed: an empty frame that still had the llm_is_risky column raised
    ZeroDivisionError in the percentage lines; it now takes the short path.
    """
    print(f"\n{'='*50}")
    print("验证结果摘要")
    print(f"{'='*50}")
    total = len(df)
    if "llm_is_risky" not in df.columns or total == 0:
        print(f"总记录数: {total}")
        return
    risky = (df["llm_is_risky"] == True).sum()
    not_risky = (df["llm_is_risky"] == False).sum()
    unknown = total - risky - not_risky  # None / NaN verdicts
    print(f"总验证数: {total}")
    print(f" ├─ LLM 判定有风险: {risky} ({risky/total*100:.1f}%)")
    print(f" ├─ LLM 判定无风险: {not_risky} ({not_risky/total*100:.1f}%)")
    if unknown > 0:
        print(f" └─ 判定失败/未知: {unknown}")
    if "llm_risk_level" in df.columns:
        print(f"\n风险等级分布:")
        for level, count in df["llm_risk_level"].value_counts().items():
            print(f" - {level}: {count}")
    print(f"{'='*50}")
# ========== CLI ==========
def parse_args():
    """Parse command-line arguments (see epilog for usage examples)."""
    parser = argparse.ArgumentParser(
        description="验证高置信度未匹配记录",
        formatter_class=argparse.RawDescriptionHelpFormatter,  # preserve epilog line breaks
        epilog="""
示例:
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --mock --limit 5
python3 verify_high_confidence.py -o original.xlsx -m matched.xlsx --api dmx --model gpt-4o-mini
""",
    )
    # Required inputs
    parser.add_argument("-o", "--original", required=True, help="原始 Excel 文件路径")
    parser.add_argument("-m", "--matched", required=True, help="keyword_matcher 匹配结果文件路径")
    parser.add_argument("-r", "--result", help="输出结果文件路径 (默认: 原始文件名_llm_verified.xlsx)")
    # Configuration overrides (take precedence over env vars in main())
    parser.add_argument("--env-file", help="环境变量文件路径 (默认: ../.env)")
    parser.add_argument("--api", choices=["openai", "dmx", "dify", "ollama"], help="LLM API 类型")
    parser.add_argument("--model", help="LLM 模型名称")
    parser.add_argument("--base-url", help="API base URL")
    parser.add_argument("--api-key", help="API Key")
    # NOTE(review): the closing "）" of this help string looks lost in extraction.
    parser.add_argument("--mock", action="store_true", help="使用 mock 模式(不调用 API")
    # Selection and execution tuning
    parser.add_argument("--confidence", nargs="+", default=["High", "Medium"], help="需要验证的置信度级别")
    parser.add_argument("--confidence-col", default="confidence", help="置信度列名")
    parser.add_argument("--delay", type=float, default=REQUEST_DELAY, help="API 请求间隔秒数")
    parser.add_argument("--limit", type=int, default=0, help="限制验证条数 (0=全部)")
    parser.add_argument("--risky-only", action="store_true", help="只保存有风险的记录")
    return parser.parse_args()
def main():
    """CLI entry point: load data, diff against matches, verify, and save."""
    args = parse_args()
    # Load .env (defaults to ../.env relative to this script)
    base_dir = Path(__file__).resolve().parent
    env_file = args.env_file or str(base_dir.parent / ".env")
    load_env_file(env_file)
    # Resolve configuration from environment variables
    config = get_config()
    # Command-line arguments override environment values
    if args.mock:
        config.api_type = "mock"
    elif args.api:
        config.api_type = args.api
    if args.model:
        config.model = args.model
    if args.base_url:
        config.base_url = args.base_url
    if args.api_key:
        config.api_key = args.api_key
    # Resolve file paths
    original_file = Path(args.original)
    matched_file = Path(args.matched)
    result_file = Path(args.result) if args.result else original_file.parent / f"{original_file.stem}_llm_verified.xlsx"
    print("=" * 60)
    print("高置信度未匹配记录验证")
    print("=" * 60)
    print(f"原始文件: {original_file}")
    print(f"匹配结果: {matched_file}")
    print(f"输出文件: {result_file}")
    print(f"置信度级别: {args.confidence}")
    print(f"API 类型: {config.api_type}")
    print(f"模型: {config.model}")
    if config.base_url:
        print(f"Base URL: {config.base_url}")
    # Load data
    print("\n正在加载数据...")
    original_df = load_excel(original_file)
    matched_df = load_excel(matched_file)
    # Find high-confidence rows the keyword matcher missed
    unmatched_df = find_unmatched(original_df, matched_df, args.confidence_col, args.confidence)
    if unmatched_df.empty:
        print("\n所有高置信度行都已被关键词匹配,无需验证。")
        return
    # Build the verifier; exit non-zero on unusable configuration
    try:
        verifier = create_verifier(config)
    except (ImportError, ValueError) as e:
        print(f"\n错误: {e}")
        sys.exit(1)
    # Run verification
    verified_df = verify_batch(unmatched_df, verifier, delay=args.delay, limit=args.limit)
    # Summarize and save
    print_summary(verified_df)
    save_results(verified_df, result_file, args.risky_only)
    print("\n✓ 验证完成!")
# Script entry point.
if __name__ == "__main__":
    main()