7 Commits

Author SHA1 Message Date
5825cf81b7 modify: cleanup project structure and docs 2025-11-02 08:40:28 +08:00
a5e3c4c1da modify: 更新了requirements.txt 2025-10-10 08:14:46 +08:00
69488904a0 modify: 删除隐私文件 2025-10-10 08:13:32 +08:00
00e766eaa7 modify: 删除隐私文件 2025-10-10 08:13:22 +08:00
af4285e147 modify: 删除隐私文件 2025-10-10 08:12:07 +08:00
c69419d816 modify: 删除隐私文件 2025-10-10 08:10:25 +08:00
a9845d084e modify: 增加了热点识别和策略建议功能 2025-10-10 07:54:45 +08:00
26 changed files with 1919 additions and 602 deletions

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
sample/
strategy_evaluation_results.csv
run_metadata.json
*.log
simulation.html
.DS_Store
overview_series.html
tmp/
__pycache__/
*.py[cod]
.venv/
.streamlit/

36
AGENTS.md Normal file
View File

@@ -0,0 +1,36 @@
# Repository Guidelines
## Project Structure & Module Organization
- Source: `app.py` (Streamlit UI, data processing, forecasting, anomaly detection, evaluation).
- Docs & outputs: `docs/`, `overview_series.html`, `strategy_evaluation_results.csv`.
- Samples: `sample/` for example data only; avoid sensitive content.
- Meta: `requirements.txt`, `readme.md`, `LICENSE`, `CHANGELOG.md`.
## Build, Test, and Development Commands
- Create env: `python -m venv .venv && source .venv/bin/activate` (or follow conda steps in `readme.md`).
- Install deps: `pip install -r requirements.txt`.
- Run app: `streamlit run app.py` then open `http://localhost:8501`.
- Export artifacts: charts save as HTML (Plotly); forecasts may be written to CSV as noted in `readme.md`.
## Coding Style & Naming Conventions
- Python ≥3.8; 4-space indentation; UTF-8.
- Names: functions/variables `snake_case`; classes `PascalCase`; constants `UPPER_SNAKE_CASE`.
- Files: keep scope focused; use descriptive output names (e.g., `arima_forecast.csv`).
- Data handling: prefer pandas/NumPy vectorization; validate inputs; avoid global state except constants.
## Testing Guidelines
- Framework: pytest (recommended). Place tests under `tests/`.
- Naming: `test_<module>.py` and `test_<behavior>()`.
- Run: `pytest -q`. Focus on `load_and_clean_data`, aggregation, model selection, and metrics.
- Keep tests fast and deterministic; avoid large I/O. Use small DataFrame fixtures.
## Commit & Pull Request Guidelines
- Messages: concise, present tense. Prefixes seen: `modify:`, `Add`, `Update`.
- Include scope and reason: e.g., `modify: update requirements for statsmodels`.
- PRs: clear description, linked issues, repro steps/screenshots for UI, and notes on any schema or output changes.
## Security & Configuration Tips
- Do not commit real accident data or secrets. Use `sample/` for examples.
- Optional envs: `LOG_LEVEL=DEBUG`. Keep any API keys in environment variables, not in code.
- Validate Excel column names before processing; handle missing columns/rows defensively.

839
app.py
View File

@@ -1,23 +1,17 @@
from __future__ import annotations
import os
from datetime import datetime, timedelta
import json
import hashlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import IsolationForest
from sklearn.svm import SVR
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
import streamlit as st
import plotly.graph_objects as go
import plotly.express as px
# --- Optional deps (graceful fallback)
try:
@@ -40,179 +34,41 @@ except Exception:
HAS_OPENAI = False
# =======================
# 1. Data Integration
# =======================
@st.cache_data(show_spinner=False)
def load_and_clean_data(accident_file, strategy_file):
accident_df = pd.read_excel(accident_file, sheet_name=None)
accident_data = pd.concat(accident_df.values(), ignore_index=True)
from services.io import (
load_and_clean_data,
aggregate_daily_data,
aggregate_daily_data_by_region,
load_accident_records,
)
from services.forecast import (
arima_forecast_with_grid_search,
knn_forecast_counterfactual,
fit_and_extrapolate,
)
from services.strategy import (
evaluate_strategy_effectiveness,
generate_output_and_recommendations,
)
from services.metrics import evaluate_models
accident_data['事故时间'] = pd.to_datetime(accident_data['事故时间'])
accident_data = accident_data.dropna(subset=['事故时间', '所在街道', '事故类型'])
strategy_df = pd.read_excel(strategy_file)
strategy_df['发布时间'] = pd.to_datetime(strategy_df['发布时间'])
strategy_df = strategy_df.dropna(subset=['发布时间', '交通策略类型'])
severity_map = {'财损': 1, '伤人': 2, '亡人': 4}
accident_data['severity'] = accident_data['事故类型'].map(severity_map).fillna(1)
accident_data = accident_data[['事故时间', '所在街道', '事故类型', 'severity']] \
.rename(columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'})
strategy_df = strategy_df[['发布时间', '交通策略类型']] \
.rename(columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'})
return accident_data, strategy_df
@st.cache_data(show_spinner=False)
def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
# City-level aggregation
accident_data = accident_data.copy()
strategy_data = strategy_data.copy()
accident_data['date'] = accident_data['date_time'].dt.date
daily_accidents = accident_data.groupby('date').agg(
accident_count=('date_time', 'count'),
severity=('severity', 'sum')
try:
from ui_sections import (
render_overview,
render_forecast,
render_model_eval,
render_strategy_eval,
render_hotspot,
)
daily_accidents.index = pd.to_datetime(daily_accidents.index)
strategy_data['date'] = strategy_data['date_time'].dt.date
daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list)
daily_strategies.index = pd.to_datetime(daily_strategies.index)
combined = daily_accidents.join(daily_strategies, how='left')
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
combined = combined.asfreq('D')
combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0)
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
return combined
@st.cache_data(show_spinner=False)
def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
    """Aggregate accidents per (region, day) and attach the day's strategies.

    Accident rows are grouped by ``region`` and calendar date into
    ``accident_count`` (row count) and ``severity`` (sum). The result is
    re-indexed onto the full region x day grid spanning the first to the last
    observed date, with missing cells filled by 0. Strategies carry no region
    here, so each day's strategy list is broadcast to every region; days
    without a strategy get an empty list.
    """
    accidents = accident_data.copy()
    accidents['date'] = accidents['date_time'].dt.date
    per_region = accidents.groupby(['region', 'date']).agg(
        accident_count=('date_time', 'count'),
        severity=('severity', 'sum'),
    )
    # The date level holds plain ``date`` objects; convert it to timestamps so
    # it lines up with the ``date_range`` built below.
    region_level = per_region.index.levels[0]
    day_level = pd.to_datetime(per_region.index.levels[1])
    per_region.index = per_region.index.set_levels([region_level, day_level])
    per_region = per_region.sort_index()

    # One list of strategy types per calendar day.
    strategies = strategy_data.copy()
    strategies['date'] = strategies['date_time'].dt.date
    daily_strategies = strategies.groupby('date')['strategy_type'].apply(list)
    daily_strategies.index = pd.to_datetime(daily_strategies.index)

    # Broadcast onto the complete region x day grid.
    observed_days = per_region.index.get_level_values(1)
    all_regions = per_region.index.get_level_values(0).unique()
    all_days = pd.date_range(observed_days.min(), observed_days.max(), freq='D')
    grid = pd.MultiIndex.from_product([all_regions, all_days], names=['region', 'date'])
    per_region = per_region.reindex(grid).fillna(0)

    by_day = daily_strategies.to_dict()
    strategy_column = [by_day.get(day, []) for day in per_region.index.get_level_values('date')]
    return per_region.assign(strategy_type=strategy_column)
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tools.sm_exceptions import ValueWarning
import warnings
def evaluate_arima_model(series, arima_order):
    """Score one ARIMA order by AIC.

    Returns the AIC of the fitted model, or ``float("inf")`` when the fit
    fails for any reason, so callers can compare scores without handling
    exceptions themselves.
    """
    try:
        fitted = ARIMA(series, order=arima_order).fit()
        score = fitted.aic
    except Exception:
        # Any failure makes this order the worst possible candidate.
        score = float("inf")
    return score
def arima_forecast_with_grid_search(accident_series: pd.Series, start_date: pd.Timestamp,
                                    horizon: int = 30, p_values: list = range(0, 6),
                                    d_values: list = range(0, 2), q_values: list = range(0, 6)) -> pd.DataFrame:
    """Pick an ARIMA order by AIC grid search and forecast ``horizon`` days.

    Args:
        accident_series: Date-indexed series; it is re-sampled to daily
            frequency and gaps are filled with 0 before fitting.
        start_date: Date used as the first label of the returned index.
        horizon: Number of daily forecast steps.
        p_values: Candidate AR orders.
        d_values: Candidate differencing orders.
        q_values: Candidate MA orders.

    Returns:
        The forecast ``summary_frame()`` with its index replaced by a daily
        range starting at ``start_date`` (index name ``'date'``) and the
        ``'mean'`` column renamed to ``'forecast'``.

    Raises:
        ValueError: If none of the candidate orders could be fitted.
    """
    series = accident_series.asfreq('D').fillna(0)
    start_date = pd.to_datetime(start_date)

    # Scope the warning suppression to this call instead of installing a
    # process-wide filter that silently affects unrelated code.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ValueWarning)

        best_score, best_cfg = float("inf"), None
        for p in p_values:
            for d in d_values:
                for q in q_values:
                    order = (p, d, q)
                    # evaluate_arima_model already maps failures to +inf.
                    aic = evaluate_arima_model(series, order)
                    if aic < best_score:
                        best_score, best_cfg = aic, order

        if best_cfg is None:
            # Fail with a clear message rather than fitting ARIMA(order=None).
            raise ValueError("No ARIMA order in the search grid could be fitted.")

        fit = ARIMA(series, order=best_cfg).fit()
        res = fit.get_forecast(steps=horizon)
        df = res.summary_frame()

    df.index = pd.date_range(start=start_date, periods=horizon, freq='D')
    df.index.name = 'date'
    df.rename(columns={'mean': 'forecast'}, inplace=True)
    return df
# Example usage:
# dataframe = your_data_frame_here
# forecast_df = arima_forecast_with_grid_search(dataframe['accident_count'], start_date=pd.Timestamp('YYYY-MM-DD'), horizon=30)
def knn_forecast_counterfactual(accident_series: pd.Series,
                                intervention_date: pd.Timestamp,
                                lookback: int = 14,
                                horizon: int = 30):
    """Forecast the no-intervention trajectory with a lag-feature KNN.

    A 5-neighbour regressor is trained on the ``lookback`` previous daily
    values using only data strictly before ``intervention_date``, then rolled
    forward ``horizon`` days by feeding each prediction back in as history.

    Returns:
        ``(predictions, None)``, where ``predictions`` is a Series named
        ``'knn_pred'`` indexed from ``intervention_date``; or ``(None, None)``
        when fewer than 5 training rows (or fewer than ``lookback`` history
        values) are available.
    """
    daily = accident_series.asfreq('D').fillna(0)
    cutoff = pd.to_datetime(intervention_date).normalize()
    last_train_day = cutoff - pd.Timedelta(days=1)

    # Column 'y' plus lag_1 .. lag_<lookback>, in that order.
    lag_columns = {f'lag_{k}': daily.shift(k) for k in range(1, lookback + 1)}
    frame = pd.DataFrame({'y': daily, **lag_columns})

    train = frame.loc[:last_train_day].dropna()
    if len(train) < 5:
        return None, None

    model = KNeighborsRegressor(n_neighbors=5)
    model.fit(train.filter(like='lag_').values, train['y'].values)

    history = frame.loc[:last_train_day, 'y'].tolist()
    preds = []
    for _ in range(horizon):
        if len(history) < lookback:
            return None, None
        # Most recent value first, matching the lag_1 .. lag_k column order.
        features = np.array(history[-lookback:][::-1]).reshape(1, -1)
        next_value = model.predict(features)[0]
        preds.append(next_value)
        history.append(next_value)

    pred_index = pd.date_range(cutoff, periods=horizon, freq='D')
    return pd.Series(preds, index=pred_index, name='knn_pred'), None
except Exception: # pragma: no cover - fallback to inline logic
render_overview = None
render_forecast = None
render_model_eval = None
render_strategy_eval = None
render_hotspot = None
def detect_anomalies(series: pd.Series, contamination: float = 0.1):
series = series.asfreq('D').fillna(0)
iso = IsolationForest(contamination=contamination, random_state=42)
iso = IsolationForest(n_estimators=50, contamination=contamination, random_state=42, n_jobs=-1)
yhat = iso.fit_predict(series.values.reshape(-1, 1))
anomaly_mask = (yhat == -1)
anomaly_indices = series.index[anomaly_mask]
@@ -250,130 +106,11 @@ def intervention_model(series: pd.Series,
return Y_t, Z_t
def fit_and_extrapolate(series: pd.Series,
                        intervention_date: pd.Timestamp,
                        days: int = 30):
    """Fit pre-intervention trend models and extrapolate them forward.

    Args:
        series: Date-indexed series; re-sampled to daily frequency with gaps
            filled by 0.
        intervention_date: First day of the extrapolated window; only data
            strictly before it is used for fitting.
        days: Number of days to extrapolate.

    Returns:
        ``(glm_pred, svr_pred, residuals)``, three Series indexed by the
        ``days`` dates starting at ``intervention_date``. ``residuals`` is the
        observed value minus the SVR prediction and is NaN for dates that are
        absent from ``series``. Returns ``(None, None, None)`` when fewer than
        5 pre-intervention points exist.
    """
    series = series.asfreq('D').fillna(0)

    # Normalise both the index and the intervention date to tz-naive midnight
    # timestamps so label-based slicing lines up.
    series.index = pd.to_datetime(series.index).tz_localize(None).normalize()
    intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize()

    pre = series.loc[:intervention_date - pd.Timedelta(days=1)]
    if len(pre) < 5:
        return None, None, None

    x_pre = np.arange(len(pre))
    x_future = np.arange(len(pre), len(pre) + days)

    # 1) Poisson GLM on a quadratic time trend.
    # has_constant='add' forces the intercept column: with the default 'skip',
    # a single-row design matrix (days == 1) looks "constant", no intercept is
    # added, and predict() fails on a shape mismatch.
    X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2]), has_constant='add')
    glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson())
    glm_res = glm.fit()
    X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2]), has_constant='add')
    glm_pred = glm_res.predict(X_future_glm)

    # 2) SVR on the standardised time index (RBF kernel; 'linear' is an alternative).
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import make_pipeline

    svr = make_pipeline(
        StandardScaler(),
        SVR(kernel='rbf', C=10, gamma=0.1)
    )
    svr.fit(x_pre.reshape(-1, 1), pre.values)
    svr_pred = svr.predict(x_future.reshape(-1, 1))

    # The target window may extend past the observed history, so use reindex
    # (NaN for missing dates) rather than .loc, which would raise.
    post_index = pd.date_range(intervention_date, periods=days, freq='D')

    glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred')
    svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred')

    post = series.reindex(post_index)
    residuals = pd.Series(post.values - svr_pred.values,
                          index=post_index, name='residual')

    return glm_pred, svr_pred, residuals
def evaluate_strategy_effectiveness(actual_series: pd.Series,
                                    counterfactual_series: pd.Series,
                                    severity_series: pd.Series,
                                    strategy_date: pd.Timestamp,
                                    window: int = 30):
    """Score a strategy against its counterfactual over a fixed window.

    Compares the ``window`` days starting at ``strategy_date`` with the
    counterfactual forecast (accident counts) and with the ``window`` days
    immediately before it (severity).

    Returns:
        ``(count_effective, severity_effective, (F1, F2), safety_state)``:
        - ``count_effective``: actual < counterfactual on at least half of
          the window's days.
        - ``severity_effective``: post-window severity sum is below the
          pre-window sum.
        - ``F1``: relative reduction of the actual total versus the
          counterfactual total (0.0 if that total is not positive).
        - ``F2``: relative reduction of severity versus the pre-window
          (0.0 if the pre-window sum is not positive).
        - ``safety_state``: '一级' if both F1 and F2 exceed 0.5, '二级' if
          F1 exceeds 0.3, otherwise '三级'.
    """
    strategy_date = pd.to_datetime(strategy_date)
    pre_start = strategy_date - pd.Timedelta(days=window)
    pre_end = strategy_date - pd.Timedelta(days=1)
    post_end = strategy_date + pd.Timedelta(days=window - 1)

    pre_sev = severity_series.loc[pre_start:pre_end].sum()
    post_sev = severity_series.loc[strategy_date:post_end].sum()

    actual_post = actual_series.loc[strategy_date:post_end]
    # Align the counterfactual with the actually observed days.
    counter_post = counterfactual_series.loc[strategy_date:post_end].reindex(actual_post.index)

    count_effective = (actual_post < counter_post).sum() >= (window / 2)
    severity_effective = post_sev < pre_sev

    cf_sum = counter_post.sum()
    if cf_sum > 0:
        F1 = (cf_sum - actual_post.sum()) / cf_sum
    else:
        F1 = 0.0
    if pre_sev > 0:
        F2 = (pre_sev - post_sev) / pre_sev
    else:
        F2 = 0.0

    safety_state = '三级'
    if F1 > 0.5 and F2 > 0.5:
        safety_state = '一级'
    elif F1 > 0.3:
        safety_state = '二级'

    return count_effective, severity_effective, (F1, F2), safety_state
def generate_output_and_recommendations(combined_data: pd.DataFrame,
                                        strategy_types: list,
                                        region: str = '全市',
                                        horizon: int = 30):
    """Evaluate every strategy type and recommend the best-scoring one.

    For each strategy, the first day on which it appears in
    ``combined_data['strategy_type']`` is taken as the intervention date, the
    SVR extrapolation serves as the counterfactual, and the effectiveness
    metrics are collected. Strategies that never appear, or that have too
    little pre-intervention data, are skipped.

    Side effects: writes ``strategy_evaluation_results.csv`` and
    ``recommendation.txt`` to the current working directory.

    Returns:
        ``(results, recommendation)``: a dict of per-strategy metric dicts and
        the recommendation sentence.
    """
    counts = combined_data['accident_count']
    severities = combined_data['severity']

    results = {}
    for strategy in strategy_types:
        flagged = combined_data['strategy_type'].apply(lambda tags: strategy in tags)
        if not flagged.any():
            continue

        intervention_date = flagged[flagged].index[0]
        _, svr_pred, residuals = fit_and_extrapolate(counts, intervention_date, days=horizon)
        if svr_pred is None:
            continue

        count_eff, sev_eff, scores, state = evaluate_strategy_effectiveness(
            actual_series=counts,
            counterfactual_series=svr_pred,
            severity_series=severities,
            strategy_date=intervention_date,
            window=horizon,
        )
        F1, F2 = scores
        results[strategy] = {
            'effect_strength': float(residuals.mean()),
            'adaptability': float(F1 + F2),
            'count_effective': bool(count_eff),
            'severity_effective': bool(sev_eff),
            'safety_state': state,
            'F1': float(F1),
            'F2': float(F2),
            'intervention_date': str(intervention_date.date()),
        }

    # The strategy with the highest combined F1 + F2 score wins.
    best_strategy = None
    if results:
        best_strategy = max(results, key=lambda name: results[name]['adaptability'])

    if best_strategy:
        recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}"
    else:
        recommendation = "无足够数据推荐策略"

    pd.DataFrame(results).T.to_csv('strategy_evaluation_results.csv', encoding='utf-8-sig')
    with open('recommendation.txt', 'w', encoding='utf-8') as out_file:
        out_file.write(recommendation)

    return results, recommendation
# =======================
# 3. UI Helpers
# =======================
def hash_like(obj: str) -> str:
    """Return the first 8 hex characters of the MD5 digest of ``obj`` (UTF-8 encoded)."""
    digest = hashlib.md5(obj.encode('utf-8')).hexdigest()
    return digest[:8]
def compute_kpis(df_city: pd.DataFrame, arima_df: pd.DataFrame | None,
def compute_kpis(df_city: pd.DataFrame, arima_df: Optional[pd.DataFrame],
today: pd.Timestamp, window:int=30):
# 今日/昨日
today_date = pd.to_datetime(today.date())
@@ -451,112 +188,69 @@ def save_fig_as_html(fig, filename):
f.write(html)
return filename
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from statsmodels.tsa.arima.model import ARIMA
# 依赖:已在脚本前面定义的 knn_forecast_counterfactual() 和 fit_and_extrapolate()
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from statsmodels.tsa.arima.model import ARIMA
# 依赖knn_forecast_counterfactual、fit_and_extrapolate 已存在
def evaluate_models(series: pd.Series,
                    horizon: int = 30,
                    lookback: int = 14,
                    p_values: range = range(0, 6),
                    d_values: range = range(0, 2),
                    q_values: range = range(0, 6)) -> pd.DataFrame:
    """Compare ARIMA / KNN / GLM / SVR on a hold-out window.

    The last ``horizon`` days of the (daily, zero-filled) series are held out
    for validation; every model is fitted on the data before that window.

    Args:
        series: Date-indexed series to evaluate on.
        horizon: Size of the hold-out window in days.
        lookback: Number of lag features used by the KNN model.
        p_values: Candidate ARIMA AR orders.
        d_values: Candidate ARIMA differencing orders.
        q_values: Candidate ARIMA MA orders.

    Returns:
        DataFrame indexed by model name with columns MAE, RMSE and MAPE,
        sorted by RMSE ascending and rounded to 3 decimals. Models that could
        not be fitted are omitted (a warning is emitted for each).

    Raises:
        ValueError: If the series is too short for the requested hold-out, or
            if no model at all could be evaluated.
    """
    import warnings  # local import: only used to report skipped models

    # Uniform daily frequency, gaps filled with 0.
    series = series.asfreq('D').fillna(0)
    if len(series) <= horizon + 10:
        raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon)

    train, test = series.iloc[:-horizon], series.iloc[-horizon:]
    forecast_start = train.index[-1] + pd.Timedelta(days=1)

    def _to_series_like(pred, a_index):
        # Align any prediction to a Series sharing the index of the actuals.
        if isinstance(pred, pd.Series):
            return pred.reindex(a_index)
        return pd.Series(pred, index=a_index)

    def _metrics(a: pd.Series, p) -> dict:
        p = _to_series_like(p, a.index).astype(float)
        a = a.astype(float)
        mae = mean_absolute_error(a, p)
        # Take the square root manually: the ``squared=`` keyword is deprecated
        # in recent scikit-learn releases and missing in very old ones.
        rmse = float(np.sqrt(mean_squared_error(a, p)))
        # Samples whose actual value is 0 are excluded from MAPE.
        mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100
        return {"MAE": mae, "RMSE": rmse, "MAPE": mape}

    results = {}

    # ---------- ARIMA ----------
    best_aic, best_order = float('inf'), None
    for p in p_values:
        for d in d_values:
            for q in q_values:
                try:
                    aic = ARIMA(train, order=(p, d, q)).fit().aic
                except Exception:
                    continue
                if aic < best_aic:
                    best_aic, best_order = aic, (p, d, q)

    if best_order is None:
        # Do not fall through to ARIMA(order=None); just leave ARIMA out.
        warnings.warn("ARIMA evaluation skipped: no candidate order could be fitted.")
    else:
        arima_pred = ARIMA(train, order=best_order).fit().forecast(steps=horizon)
        results['ARIMA'] = _metrics(test, arima_pred)

    # ---------- KNN ----------
    try:
        knn_pred, _ = knn_forecast_counterfactual(series,
                                                  forecast_start,
                                                  lookback=lookback,
                                                  horizon=horizon)
        if knn_pred is not None:
            results['KNN'] = _metrics(test, knn_pred)
    except Exception as exc:
        # Best effort: one failing model must not abort the comparison.
        warnings.warn(f"KNN evaluation skipped: {exc}")

    # ---------- GLM & SVR ----------
    try:
        glm_pred, svr_pred, _ = fit_and_extrapolate(series,
                                                    forecast_start,
                                                    days=horizon)
        if glm_pred is not None:
            results['GLM'] = _metrics(test, glm_pred)
        if svr_pred is not None:
            results['SVR'] = _metrics(test, svr_pred)
    except Exception as exc:
        warnings.warn(f"GLM/SVR evaluation skipped: {exc}")

    if not results:
        # Without this guard sort_values('RMSE') raises an opaque KeyError.
        raise ValueError("No model could be evaluated on the hold-out window.")

    return (pd.DataFrame(results)
            .T.sort_values('RMSE')
            .round(3))
# =======================
# 4. App
# =======================
# =======================
# 4. App
# =======================
def run_streamlit_app():
# Must be the first Streamlit command
st.set_page_config(page_title="Traffic Safety Analysis", layout="wide")
st.title("🚦 Traffic Safety Intervention Analysis System")
# Sidebar — Upload & Global Filters & Auto Refresh
st.sidebar.header("数据与筛选")
default_min_date = pd.to_datetime('2022-01-01').date()
default_max_date = pd.to_datetime('2022-12-31').date()
def clamp_date_range(requested, minimum, maximum):
    """Ensure the requested range stays within [minimum, maximum].

    ``requested`` may be a single date, a ``(start, end)`` pair, or the
    partial selection a range date picker can hand back while the user has
    picked only one date (1-element tuple) or none (empty tuple).

    Returns:
        ``(start, end)`` with ``start <= end``, both inside
        ``[minimum, maximum]``. If the request is empty or lies entirely
        outside the bounds, ``(minimum, maximum)`` is returned.
    """
    if not isinstance(requested, (list, tuple)):
        requested = (requested, requested)
    picked = tuple(requested)
    if not picked:
        # Nothing selected yet: fall back to the full range.
        return minimum, maximum
    if len(picked) == 1:
        # Only the start date was selected: treat it as a one-day range.
        picked = picked * 2
    start, end = picked
    if start > end:
        start, end = end, start
    if end < minimum or start > maximum:
        return minimum, maximum
    start = max(minimum, start)
    end = min(maximum, end)
    return start, end
# Initialize session state to store processed data (before rendering controls)
if 'processed_data' not in st.session_state:
st.session_state['processed_data'] = {
'combined_city': None,
'combined_by_region': None,
'accident_data': None,
'accident_records': None,
'strategy_data': None,
'all_regions': ["全市"],
'all_strategy_types': [],
'min_date': default_min_date,
'max_date': default_max_date,
'region_sel': "全市",
'date_range': (default_min_date, default_max_date),
'strat_filter': [],
'accident_source_name': None,
}
sidebar_state = st.session_state['processed_data']
available_regions = sidebar_state['all_regions'] if sidebar_state['all_regions'] else ["全市"]
current_region = sidebar_state['region_sel'] if sidebar_state['region_sel'] in available_regions else available_regions[0]
available_strategies = sidebar_state['all_strategy_types']
current_strategies = [s for s in sidebar_state['strat_filter'] if s in available_strategies]
min_date = sidebar_state['min_date']
max_date = sidebar_state['max_date']
raw_start, raw_end = sidebar_state['date_range']
start_default = max(min_date, min(raw_start, max_date))
end_default = max(start_default, min(raw_end, max_date))
# Create a form for data inputs to batch updates
with st.sidebar.form(key="data_input_form"):
@@ -566,13 +260,24 @@ def run_streamlit_app():
# Global filters
st.markdown("---")
st.subheader("全局筛选器")
# Placeholder for region selection (will be populated after data is loaded)
region_sel = st.selectbox("区域", options=["全市"], index=0, key="region_select")
# Default date range (will be updated after data is loaded)
min_date = pd.to_datetime('2022-01-01').date()
max_date = pd.to_datetime('2022-12-31').date()
date_range = st.date_input("时间范围", value=(min_date, max_date), min_value=min_date, max_value=max_date)
strat_filter = st.multiselect("策略类型(过滤)", options=[], help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期")
region_sel = st.selectbox(
"区域",
options=available_regions,
index=available_regions.index(current_region),
key="region_select",
)
date_range = st.date_input(
"时间范围",
value=(start_default, end_default),
min_value=min_date,
max_value=max_date,
)
strat_filter = st.multiselect(
"策略类型(过滤)",
options=available_strategies,
default=current_strategies,
help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期",
)
# Apply button for data loading and filtering
apply_button = st.form_submit_button("应用数据与筛选")
@@ -590,29 +295,14 @@ def run_streamlit_app():
# Add OpenAI API key input in sidebar
st.sidebar.markdown("---")
st.sidebar.subheader("GPT API 配置")
# Pre-fill from the OPENAI_API_KEY environment variable (empty when unset)
# instead of shipping a key in source. The key previously committed on this
# line must be treated as leaked and revoked.
openai_api_key = st.sidebar.text_input("GPT API Key", value=os.environ.get("OPENAI_API_KEY", ""), type="password", help="用于GPT分析结果的API密钥")
open_ai_base_url = st.sidebar.text_input("GPT Base Url", value='https://az.gptplus5.com/v1', type='default')
# Initialize session state to store processed data
if 'processed_data' not in st.session_state:
st.session_state['processed_data'] = {
'combined_city': None,
'combined_by_region': None,
'accident_data': None,
'strategy_data': None,
'all_regions': ["全市"],
'all_strategy_types': [],
'min_date': min_date,
'max_date': max_date,
'region_sel': "全市",
'date_range': (min_date, max_date),
'strat_filter': []
}
# Pre-fill from the OPENAI_API_KEY environment variable (empty when unset)
# instead of shipping a key in source. The key previously committed on this
# line must be treated as leaked and revoked.
openai_api_key = st.sidebar.text_input("GPT API Key", value=os.environ.get("OPENAI_API_KEY", ""), type="password", help="用于GPT分析结果的API密钥")
open_ai_base_url = st.sidebar.text_input("GPT Base Url", value='https://aihubmix.com/v1', type='default')
# Process data only when Apply button is clicked
if apply_button and accident_file and strategy_file:
with st.spinner("数据载入中…"):
# Load and clean data
accident_records = load_accident_records(accident_file, require_location=True)
accident_data, strategy_data = load_and_clean_data(accident_file, strategy_file)
combined_city = aggregate_daily_data(accident_data, strategy_data)
combined_by_region = aggregate_daily_data_by_region(accident_data, strategy_data)
@@ -624,24 +314,35 @@ def run_streamlit_app():
max_date = combined_city.index.max().date()
# Store processed data in session state
sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date)
st.session_state['processed_data'].update({
'combined_city': combined_city,
'combined_by_region': combined_by_region,
'accident_data': accident_data,
'accident_records': accident_records,
'strategy_data': strategy_data,
'all_regions': all_regions,
'all_strategy_types': all_strategy_types,
'min_date': min_date,
'max_date': max_date,
'region_sel': region_sel,
'date_range': date_range,
'strat_filter': strat_filter
'date_range': (sanitized_start, sanitized_end),
'strat_filter': strat_filter,
'accident_source_name': getattr(accident_file, "name", "事故数据.xlsx"),
})
sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date)
# Persist the latest sidebar selections for display and downstream filtering
st.session_state['processed_data']['region_sel'] = region_sel
st.session_state['processed_data']['date_range'] = (sanitized_start, sanitized_end)
st.session_state['processed_data']['strat_filter'] = strat_filter
# Retrieve data from session state
combined_city = st.session_state['processed_data']['combined_city']
combined_by_region = st.session_state['processed_data']['combined_by_region']
accident_data = st.session_state['processed_data']['accident_data']
accident_records = st.session_state['processed_data']['accident_records']
strategy_data = st.session_state['processed_data']['strategy_data']
all_regions = st.session_state['processed_data']['all_regions']
all_strategy_types = st.session_state['processed_data']['all_strategy_types']
@@ -650,6 +351,7 @@ def run_streamlit_app():
region_sel = st.session_state['processed_data']['region_sel']
date_range = st.session_state['processed_data']['date_range']
strat_filter = st.session_state['processed_data']['strat_filter']
accident_source_name = st.session_state['processed_data']['accident_source_name']
# Update selectbox and multiselect options dynamically (outside the form for display)
st.sidebar.markdown("---")
@@ -700,127 +402,137 @@ def run_streamlit_app():
with meta_col2:
st.caption(f"🕒 最近刷新:{last_refresh.strftime('%Y-%m-%d %H:%M:%S')}")
# Tabs (add new tab for GPT analysis)
tab_dash, tab_pred, tab_eval, tab_anom, tab_strat, tab_comp, tab_sim, tab_gpt = st.tabs(
["🏠 总览", "📈 预测模型", "📊 模型评估", "⚠️ 异常检测", "📝 策略评估", "⚖️ 策略对比", "🧪 情景模拟", "🔍 GPT 分析"]
tab_labels = [
"🏠 总览",
"📈 预测模型",
"📊 模型评估",
"⚠️ 异常检测",
"📝 策略评估",
"⚖️ 策略对比",
"🧪 情景模拟",
"🔍 GPT 分析",
"📍 事故热点",
]
default_tab = st.session_state.get("active_tab", tab_labels[0])
if default_tab not in tab_labels:
default_tab = tab_labels[0]
selected_tab = st.radio(
"功能分区",
tab_labels,
index=tab_labels.index(default_tab),
horizontal=True,
label_visibility="collapsed",
)
st.session_state["active_tab"] = selected_tab
# --- Tab 1: 总览页
with tab_dash:
fig_line = go.Figure()
fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines'))
fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
st.plotly_chart(fig_line, use_container_width=True)
fname = save_fig_as_html(fig_line, "overview_series.html")
st.download_button("下载图表 HTML", data=open(fname, 'rb').read(),
file_name="overview_series.html", mime="text/html")
st.dataframe(base, use_container_width=True)
csv_bytes = base.to_csv(index=True).encode('utf-8-sig')
st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv")
meta = {
"region": region_sel,
"date_range": [str(start_dt.date()), str(end_dt.date())],
"strategy_filter": strat_filter,
"rows": int(len(base)),
"min_date": str(base.index.min().date()) if len(base) else None,
"max_date": str(base.index.max().date()) if len(base) else None
}
with open("run_metadata.json", "w", encoding="utf-8") as f:
json.dump(meta, f, ensure_ascii=False, indent=2)
st.download_button("下载运行参数 JSON", data=open("run_metadata.json", "rb").read(),
file_name="run_metadata.json", mime="application/json")
# --- Tab 2: 预测模型
with tab_pred:
st.subheader("多模型预测比较")
# 使用表单封装交互组件
with st.form(key="predict_form"):
default_date = base.index.max() - pd.Timedelta(days=60) if len(base) else pd.Timestamp('2022-01-01')
selected_date = st.date_input("选择干预日期 / 预测起点", value=default_date)
horizon = st.number_input("预测天数", min_value=7, max_value=90, value=30, step=1)
submit_predict = st.form_submit_button("应用预测参数")
if submit_predict and len(base.loc[:pd.to_datetime(selected_date)]) >= 10:
first_date = pd.to_datetime(selected_date)
try:
train_series = base['accident_count'].loc[:first_date]
arima30 = arima_forecast_with_grid_search(
train_series,
start_date=first_date + pd.Timedelta(days=1),
horizon=horizon
)
except Exception as e:
st.warning(f"ARIMA 运行失败:{e}")
arima30 = None
knn_pred, _ = knn_forecast_counterfactual(base['accident_count'],
first_date,
horizon=horizon)
glm_pred, svr_pred, residuals = fit_and_extrapolate(base['accident_count'],
first_date,
days=horizon)
fig_pred = go.Figure()
fig_pred.add_trace(go.Scatter(x=base.index, y=base['accident_count'],
name="实际", mode="lines"))
if arima30 is not None:
fig_pred.add_trace(go.Scatter(x=arima30.index, y=arima30['forecast'],
name="ARIMA", mode="lines"))
if knn_pred is not None:
fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred,
name="KNN", mode="lines"))
if glm_pred is not None:
fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred,
name="GLM", mode="lines"))
if svr_pred is not None:
fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred,
name="SVR", mode="lines"))
fig_pred.update_layout(
title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon} 天)",
xaxis_title="日期", yaxis_title="事故数"
)
st.plotly_chart(fig_pred, use_container_width=True)
col_dl1, col_dl2 = st.columns(2)
if arima30 is not None:
col_dl1.download_button("下载 ARIMA 预测 CSV",
data=arima30.to_csv().encode("utf-8-sig"),
file_name="arima_forecast.csv",
mime="text/csv")
elif submit_predict:
st.info("⚠️ 干预前数据较少,可能影响拟合质量。")
if selected_tab == "📍 事故热点":
if render_hotspot is not None:
render_hotspot(accident_records, accident_source_name)
else:
st.info("请设置预测参数并点击“应用预测参数”按钮")
st.warning("事故热点模块未能加载,请检查 `ui_sections/hotspot.py`")
elif selected_tab == "🏠 总览":
if render_overview is not None:
render_overview(base, region_sel, start_dt, end_dt, strat_filter)
else:
st.warning("概览模块未能加载,请检查 `ui_sections/overview.py`。")
elif selected_tab == "📈 预测模型":
if render_forecast is not None:
render_forecast(base)
else:
st.subheader("多模型预测比较")
# 使用表单封装交互组件
with st.form(key="predict_form"):
# 缩短默认回溯窗口,提升首次渲染速度
default_date = base.index.max() - pd.Timedelta(days=30) if len(base) else pd.Timestamp('2022-01-01')
selected_date = st.date_input("选择干预日期 / 预测起点", value=default_date)
horizon = st.number_input("预测天数", min_value=7, max_value=90, value=30, step=1)
submit_predict = st.form_submit_button("应用预测参数")
if submit_predict and len(base.loc[:pd.to_datetime(selected_date)]) >= 10:
first_date = pd.to_datetime(selected_date)
try:
train_series = base['accident_count'].loc[:first_date]
arima30 = arima_forecast_with_grid_search(
train_series,
start_date=first_date + pd.Timedelta(days=1),
horizon=horizon
)
except Exception as e:
st.warning(f"ARIMA 运行失败:{e}")
arima30 = None
knn_pred, _ = knn_forecast_counterfactual(base['accident_count'],
first_date,
horizon=horizon)
glm_pred, svr_pred, residuals = fit_and_extrapolate(base['accident_count'],
first_date,
days=horizon)
fig_pred = go.Figure()
fig_pred.add_trace(go.Scatter(x=base.index, y=base['accident_count'],
name="实际", mode="lines"))
if arima30 is not None:
fig_pred.add_trace(go.Scatter(x=arima30.index, y=arima30['forecast'],
name="ARIMA", mode="lines"))
if knn_pred is not None:
fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred,
name="KNN", mode="lines"))
if glm_pred is not None:
fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred,
name="GLM", mode="lines"))
if svr_pred is not None:
fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred,
name="SVR", mode="lines"))
fig_pred.update_layout(
title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon} 天)",
xaxis_title="日期", yaxis_title="事故数"
)
st.plotly_chart(fig_pred, use_container_width=True)
col_dl1, col_dl2 = st.columns(2)
if arima30 is not None:
col_dl1.download_button("下载 ARIMA 预测 CSV",
data=arima30.to_csv().encode("utf-8-sig"),
file_name="arima_forecast.csv",
mime="text/csv")
elif submit_predict:
st.info("⚠️ 干预前数据较少,可能影响拟合质量。")
else:
st.info("请设置预测参数并点击“应用预测参数”按钮。")
# --- Tab 3: 模型评估
with tab_eval:
st.subheader("模型预测效果对比")
with st.form(key="model_eval_form"):
horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
submit_eval = st.form_submit_button("应用评估参数")
if submit_eval:
try:
df_metrics = evaluate_models(base['accident_count'], horizon=horizon_sel)
st.dataframe(df_metrics, use_container_width=True)
best_model = df_metrics['RMSE'].idxmin()
st.success(f"过去 {horizon_sel} 天中RMSE 最低的模型是:**{best_model}**")
st.download_button(
"下载评估结果 CSV",
data=df_metrics.to_csv().encode('utf-8-sig'),
file_name="model_evaluation.csv",
mime="text/csv"
)
except ValueError as err:
st.warning(str(err))
elif selected_tab == "📊 模型评估":
if render_model_eval is not None:
render_model_eval(base)
else:
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
st.subheader("模型预测效果对比")
with st.form(key="model_eval_form"):
horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
submit_eval = st.form_submit_button("应用评估参数")
if submit_eval:
try:
df_metrics = evaluate_models(base['accident_count'], horizon=horizon_sel)
st.dataframe(df_metrics, use_container_width=True)
best_model = df_metrics['RMSE'].idxmin()
st.success(f"过去 {horizon_sel} 天中RMSE 最低的模型是:**{best_model}**")
st.download_button(
"下载评估结果 CSV",
data=df_metrics.to_csv().encode('utf-8-sig'),
file_name="model_evaluation.csv",
mime="text/csv"
)
except ValueError as err:
st.warning(str(err))
else:
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
# --- Tab 4: 异常检测
with tab_anom:
elif selected_tab == "⚠️ 异常检测":
anomalies, anomaly_fig = detect_anomalies(base['accident_count'])
st.plotly_chart(anomaly_fig, use_container_width=True)
st.write(f"检测到异常点:{len(anomalies)}")
@@ -829,25 +541,14 @@ def run_streamlit_app():
file_name="anomalies.csv", mime="text/csv")
# --- Tab 5: 策略评估
with tab_strat:
st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
if all_strategy_types:
results, recommendation = generate_output_and_recommendations(base, all_strategy_types,
region=region_sel if region_sel!='全市' else '全市')
st.subheader("各策略指标")
df_res = pd.DataFrame(results).T
st.dataframe(df_res, use_container_width=True)
st.success(f"⭐ 推荐:{recommendation}")
st.download_button("下载策略评估 CSV",
data=df_res.to_csv().encode('utf-8-sig'),
file_name="strategy_evaluation_results.csv", mime="text/csv")
with open('recommendation.txt','r',encoding='utf-8') as f:
st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt")
elif selected_tab == "📝 策略评估":
if render_strategy_eval is not None:
render_strategy_eval(base, all_strategy_types, region_sel)
else:
st.warning("数据中没有检测到策略。")
st.warning("策略评估模块不可用,请检查 `ui_sections/strategy_eval.py`")
# --- Tab 6: 策略对比
with tab_comp:
elif selected_tab == "⚖️ 策略对比":
def strategy_metrics(strategy):
mask = base['strategy_type'].apply(lambda x: strategy in x)
if not mask.any():
@@ -915,7 +616,7 @@ def run_streamlit_app():
st.warning("没有策略可供对比。")
# --- Tab 7: 情景模拟
with tab_sim:
elif selected_tab == "🧪 情景模拟":
st.subheader("情景模拟")
st.write("选择一个日期与策略,模拟“在该日期上线该策略”的影响:")
with st.form(key="simulation_form"):
@@ -952,7 +653,7 @@ def run_streamlit_app():
st.info("请设置模拟参数并点击“应用模拟参数”按钮。")
# --- New Tab 8: GPT 分析
with tab_gpt:
elif selected_tab == "🔍 GPT 分析":
from openai import OpenAI
st.subheader("GPT 数据分析与改进建议")
# open_ai_key = "<REDACTED>"  # never commit real API keys; supply the key at runtime instead
@@ -983,27 +684,29 @@ def run_streamlit_app():
提供数据结果的详细分析,以及改进思路和建议。
数据:{str(data_str)}
""")
#st.text_area(prompt)
if st.button("上传数据至 GPT 并获取分析"):
try:
client = OpenAI(
base_url=open_ai_base_url,
# sk-xxx替换为自己的key
api_key=openai_api_key
)
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "You are a helpful assistant that analyzes traffic safety data."},
{"role": "user", "content": prompt}
],
stream=False
)
gpt_response = response.choices[0].message.content
st.markdown("### GPT 分析结果与改进思路")
st.markdown(gpt_response, unsafe_allow_html=True)
except Exception as e:
st.error(f"调用 OpenAI API 失败:{str(e)}")
if False:
st.info("请将 GPT Base Url 更新为实际可访问的接口地址。")
else:
try:
client = OpenAI(
base_url=open_ai_base_url,
# sk-xxx替换为自己的key
api_key=openai_api_key
)
response = client.chat.completions.create(
model="gpt-5-mini",
messages=[
{"role": "system", "content": "You are a helpful assistant that analyzes traffic safety data."},
{"role": "user", "content": prompt}
],
stream=False
)
gpt_response = response.choices[0].message.content
st.markdown("### GPT 分析结果与改进思路")
st.markdown(gpt_response, unsafe_allow_html=True)
except Exception as e:
st.error(f"调用 OpenAI API 失败:{str(e)}")
else:
st.warning("没有策略数据可供分析。")
@@ -1014,4 +717,4 @@ def run_streamlit_app():
st.info("请先在左侧上传事故数据与策略数据,并点击“应用数据与筛选”按钮。")
if __name__ == "__main__":
run_streamlit_app()
run_streamlit_app()

21
config/settings.py Normal file
View File

@@ -0,0 +1,21 @@
from __future__ import annotations
# Default configuration values and feature flags shared by the service modules.
# Forecasting
# Candidate ARIMA orders: services/forecast.py grid-searches every (p, d, q)
# combination drawn from these ranges and keeps the order with the lowest AIC.
ARIMA_P = range(0, 4)
ARIMA_D = range(0, 2)
ARIMA_Q = range(0, 4)
# Default horizons in days. NOTE(review): presumably the initial values of the
# forecast / evaluation controls; no consumer is visible here -- confirm.
DEFAULT_HORIZON_PREDICT = 30
DEFAULT_HORIZON_EVAL = 14
# Bounds on the number of pre-intervention days. MAX_PRE_DAYS caps the history
# window in services/forecast.fit_and_extrapolate; MIN_PRE_DAYS is imported by
# services/strategy.py (its use is not visible here).
MIN_PRE_DAYS = 5
MAX_PRE_DAYS = 120
# Anomaly detection
# NOTE(review): the names suggest the estimator count and contamination rate of
# the anomaly detector; the consumer is not visible here -- confirm.
ANOMALY_N_ESTIMATORS = 50
ANOMALY_CONTAMINATION = 0.10
# Performance flags
# NOTE(review): no reader of FAST_MODE is visible in this chunk -- confirm what it toggles.
FAST_MODE = True

View File

@@ -0,0 +1,78 @@
# Installation Guide
This document explains how to set up TrafficSafeAnalyzer for local development and exploration. The application runs on Streamlit and officially supports Python 3.8.
## Prerequisites
- Python 3.8 (3.9+ is not yet validated; use 3.8 to avoid dependency issues)
- Git
- `pip` (bundled with Python)
- Optional: Conda (for environment management) or Docker (for container-based runs)
## 1. Obtain the source code
```bash
git clone https://github.com/tongnian0613/TrafficSafeAnalyzer.git
cd TrafficSafeAnalyzer
```
If you already have the repository, pull the latest changes instead:
```bash
git pull origin main
```
## 2. Create a dedicated environment
### Option A: Built-in virtual environment
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
```
### Option B: Conda environment
```bash
conda create -n trafficsa python=3.8 -y
conda activate trafficsa
```
## 3. Install project dependencies
Install the full dependency set listed in `requirements.txt`:
```bash
pip install -r requirements.txt
```
If you prefer a minimal installation before pulling in extras, install the core stack first:
```bash
pip install streamlit pandas numpy matplotlib plotly scikit-learn statsmodels scipy
```
Then add optional packages as needed (Excel readers, auto-refresh, OpenAI integration):
```bash
pip install streamlit-autorefresh openpyxl xlrd cryptography openai
```
## 4. Verify the setup
1. Ensure the environment is still active (`which python` should point to `.venv` or the conda env).
2. Launch the Streamlit app:
```bash
streamlit run app.py
```
3. Open `http://localhost:8501` in your browser. The home page should load without import errors.
## Troubleshooting tips
- **Missing package**: Re-run `pip install -r requirements.txt`.
- **Python version mismatch**: Confirm `python --version` reports 3.8.x inside your environment.
- **OpenSSL or cryptography errors** (macOS/Linux): Update the system OpenSSL libraries and reinstall `cryptography`.
- **Taking too long to install**: if a dependency download stalls due to a firewall, retry using a mirror (`-i https://pypi.tuna.tsinghua.edu.cn/simple`) consistent with your environment policy.
After a successful launch, continue with the usage guide in `docs/usage.md` to load data and explore forecasts.

View File

@@ -0,0 +1,73 @@
# Usage Guide
TrafficSafeAnalyzer delivers accident analytics and decision support through a Streamlit interface. This guide walks through the daily workflow, expected inputs, and where to find generated artefacts.
## Start the app
1. Activate your virtual or conda environment.
2. From the project root, run:
```bash
streamlit run app.py
```
3. Open `http://localhost:8501`. Keep the terminal running while you work in the browser.
## Load input data
Use the sidebar form labelled “数据与筛选”.
- **Accident data (`.xlsx`)** — columns should include at minimum:
- `事故时间` (timestamp)
- `所在街道` (region or district)
- `事故类型`
- `事故数`/`accident_count` (if absent, the loader aggregates counts)
- **Strategy data (`.xlsx`)** — include:
- `发布时间`
- `交通策略类型`
- optional descriptors such as `策略名称`, `策略内容`
- Select the global filters (region, date window, strategy filter) and click `应用数据与筛选`.
- Uploaded files are cached. Upload a new file or press “Rerun” to refresh after making edits.
- Sample datasets for rapid smoke testing live in `sample/事故/*.xlsx` (accidents) and `sample/交通策略/*.xlsx` (strategies); copy them before making modifications.
> Tip: `services/io.py` performs validation; rows missing key columns are dropped with a warning in the Streamlit log.
## Navigate the workspace
- **🏠 总览 (Overview)** — KPI cards, time-series plot, filtered table, and download buttons for HTML (`overview_series.html`), CSV (`filtered_view.csv`), and run metadata (`run_metadata.json`).
- **📈 预测模型 (Forecast)** — choose an intervention date and horizon, compare ARIMA / KNN / GLM / SVR forecasts, and export `arima_forecast.csv`. After you submit, the results are kept for the current dataset, so you can adjust other controls without losing them.
- **📊 模型评估 (Model evaluation)** — run rolling-window backtests, inspect RMSE/MAE/MAPE, and download `model_evaluation.csv`.
- **⚠️ 异常检测 (Anomaly detection)** — isolation forest marks outliers on the accident series; tweak contamination via the main page controls.
- **📝 策略评估 (Strategy evaluation)** — Aggregates metrics per strategy type, recommends the best option, writes `strategy_evaluation_results.csv`, and updates `recommendation.txt`.
- **⚖️ 策略对比 (Strategy comparison)** — side-by-side metrics for selected strategies, useful for “what worked best last month” reviews.
- **🧪 情景模拟 (Scenario simulation)** — apply intervention models (persistent/decay, lagged effects) to test potential roll-outs.
- **🔍 GPT 分析** — enter your own OpenAI-compatible API key and base URL in the sidebar to generate narrative insights. Keys are read at runtime only.
- **📍 事故热点 (Hotspot)** — reuse the already uploaded accident data to identify high-risk intersections and produce targeted mitigation ideas; no separate hotspot upload is required.
Each tab remembers the active filters from the sidebar so results stay consistent.
## Downloaded artefacts
Generated files are saved to the project root unless you override paths in the code:
- `overview_series.html`
- `filtered_view.csv`
- `run_metadata.json`
- `arima_forecast.csv`
- `model_evaluation.csv`
- `strategy_evaluation_results.csv`
- `recommendation.txt`
After a session, review and archive these outputs under `docs/` or a dated folder as needed.
## Operational tips
- **Auto refresh**: enable from the sidebar (requires `streamlit-autorefresh`). Set the interval in seconds for live dashboards.
- **Logging**: set `LOG_LEVEL=DEBUG` before launch to see detailed diagnostics in the terminal and Streamlit log.
- **Reset filters**: choose “全市” and the full date span, then re-run the sidebar form.
- **Common warnings**:
- *“数据中没有检测到策略”*: verify the strategy Excel file and column names.
- *ARIMA failures*: shorten the horizon or ensure at least 10 historical data points before the intervention date.
- *Hotspot data issues*: ensure the accident workbook includes `事故时间`, `所在街道`, `事故类型`, and `事故具体地点` so intersections can be resolved.
Need deeper integration or batch automation? Extract the core functions from `services/` and orchestrate them in a notebook or scheduled job.

26
environment.yml Normal file
View File

@@ -0,0 +1,26 @@
name: trafficsa
channels:
- conda-forge
- defaults
dependencies:
- python=3.12
- pip
- streamlit>=1.20
- pandas>=1.3
- numpy>=1.21
- matplotlib>=3.4
- plotly>=5
- scikit-learn>=1.0
- statsmodels>=0.13
- scipy>=1.7
- pyarrow>=7
- python-dateutil>=2.8.2
- pytz>=2021.3
- openpyxl>=3.0.9
- xlrd>=2.0.1
- cryptography>=3.4.7
- requests
- pip:
- streamlit-autorefresh>=0.1.5
- openai>=2.0.0
- jieba>=0.42.1

File diff suppressed because one or more lines are too long

View File

@@ -8,6 +8,8 @@
- 使用 ARIMA、KNN、GLM、SVR 等模型预测事故趋势
- 检测异常事故点
- 评估交通策略效果并提供推荐
- 识别事故热点路口并生成风险分级与整治建议
- 支持 GPT 分析生成自然语言洞察
## 安装步骤
@@ -29,7 +31,7 @@ cd TrafficSafeAnalyzer
2. 创建虚拟环境(推荐):
```bash
conda create -n trafficsa python=3.12 -y
conda create -n trafficsa python=3.8 -y
conda activate trafficsa
pip install -r requirements.txt
streamlit run app.py
@@ -85,11 +87,20 @@ openai>=2.0.0
## 配置参数
- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel
- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel;事故热点分析会直接复用事故数据,无需额外上传。
- **环境变量**(可选):
- `LOG_LEVEL=DEBUG`:启用详细日志
- 示例:`export LOG_LEVEL=DEBUG`(Linux/macOS)或 `set LOG_LEVEL=DEBUG`(Windows)
## 示例数据
`sample/` 目录提供了脱敏示例数据,便于快速体验:
- `sample/事故/*.xlsx`:按年份划分的事故记录
- `sample/交通策略/*.xlsx`:策略发布记录
使用前建议复制到临时位置再进行编辑。
## 输入输出格式
### 输入
@@ -118,8 +129,11 @@ streamlit run app.py
**问题**:数据加载失败
**解决**:确保 Excel 文件格式正确,检查列名是否匹配
**问题**:`NameError: name 'strategy_metrics' is not defined`
**解决**:确保 `strategy_metrics` 函数定义在 `app.py` 中,且位于 `run_streamlit_app` 函数内
**问题**:预测模型页面点击后图表未显示
**解决**:确认干预日期之前至少有 10 条历史记录,或缩短预测天数重新提交
**问题**:热点分析提示“请上传事故数据”
**解决**:侧边栏上传事故数据后点击“应用数据与筛选”,热点模块会复用相同数据集
## 日志分析

View File

@@ -27,5 +27,8 @@ cryptography>=3.4.7
# OpenAI
openai
# jieba for Chinese text segmentation
jieba
# Note: hashlib and json are part of Python standard library
# Note: os and datetime are part of Python standard library

View File

@@ -1,11 +0,0 @@
{
"region": "全市",
"date_range": [
"2022-01-01",
"2022-12-31"
],
"strategy_filter": [],
"rows": 365,
"min_date": "2022-01-01",
"max_date": "2022-12-31"
}

Binary file not shown.

Binary file not shown.

139
services/forecast.py Normal file
View File

@@ -0,0 +1,139 @@
from __future__ import annotations
import warnings
import pandas as pd
import numpy as np
import streamlit as st
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tools.sm_exceptions import ValueWarning
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from config.settings import ARIMA_P, ARIMA_D, ARIMA_Q, MAX_PRE_DAYS
@st.cache_data(show_spinner=False)
def evaluate_arima_model(series, arima_order):
    """Return the AIC of an ARIMA model fitted to ``series`` with ``arima_order``.

    Any exception raised while constructing or fitting the model is swallowed
    and reported as ``float("inf")``, so a caller comparing scores treats the
    failed order as the worst candidate.
    """
    try:
        return ARIMA(series, order=arima_order).fit().aic
    except Exception:
        return float("inf")
@st.cache_data(show_spinner=False)
def arima_forecast_with_grid_search(accident_series: pd.Series,
                                    start_date: pd.Timestamp,
                                    horizon: int = 30,
                                    p_values: list = tuple(ARIMA_P),
                                    d_values: list = tuple(ARIMA_D),
                                    q_values: list = tuple(ARIMA_Q)) -> pd.DataFrame:
    """Grid-search an ARIMA order by AIC and forecast ``horizon`` daily steps.

    The series is resampled to daily frequency (gaps filled with 0), every
    (p, d, q) combination is scored with ``evaluate_arima_model`` and the order
    with the lowest AIC is refitted on the whole resampled series.

    Args:
        accident_series: Date-indexed accident counts.
        start_date: First date used to label the forecast rows. The model is
            fitted on all of ``accident_series``; this function does not
            truncate the series at ``start_date``.
        horizon: Number of daily steps to forecast.
        p_values, d_values, q_values: Candidate ARIMA order components.

    Returns:
        The statsmodels forecast summary frame, indexed by ``date`` with the
        ``mean`` column renamed to ``forecast``.

    Raises:
        ValueError: If no candidate order could be fitted (including empty
            candidate ranges). Previously ``best_cfg`` stayed ``None`` and the
            final ``ARIMA(..., order=None)`` call failed with an unrelated error.
    """
    series = accident_series.asfreq('D').fillna(0)
    start_date = pd.to_datetime(start_date)
    warnings.filterwarnings("ignore", category=ValueWarning)

    best_score, best_cfg = float("inf"), None
    for p in p_values:
        for d in d_values:
            for q in q_values:
                order = (p, d, q)
                try:
                    aic = evaluate_arima_model(series, order)
                except Exception:
                    # Best effort: a candidate that cannot be scored is skipped.
                    continue
                if aic < best_score:
                    best_score, best_cfg = aic, order

    if best_cfg is None:
        # Every candidate scored inf (fit failure) or the grid was empty.
        raise ValueError("ARIMA 网格搜索未找到可拟合的模型,请检查输入序列或调整参数范围。")

    fit = ARIMA(series, order=best_cfg).fit()
    forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D')
    df = fit.get_forecast(steps=horizon).summary_frame()
    df.index = forecast_index
    df.index.name = 'date'
    df.rename(columns={'mean': 'forecast'}, inplace=True)
    return df
def knn_forecast_counterfactual(accident_series: pd.Series,
                                intervention_date: pd.Timestamp,
                                lookback: int = 14,
                                horizon: int = 30):
    """Forecast the post-intervention period with a KNN model on lagged counts.

    The daily series (gaps filled with 0) is turned into ``lookback`` lag
    features; a 5-neighbour regressor is trained on rows strictly before
    ``intervention_date`` and then rolled forward ``horizon`` days, feeding
    each prediction back in as history.

    Returns:
        ``(predictions, None)`` where ``predictions`` is a Series named
        ``knn_pred`` starting at ``intervention_date``; ``(None, None)`` when
        fewer than 5 complete training rows exist or the history is shorter
        than ``lookback``.
    """
    daily = accident_series.asfreq('D').fillna(0)
    intervention_date = pd.to_datetime(intervention_date).normalize()
    last_pre_day = intervention_date - pd.Timedelta(days=1)

    frame = pd.DataFrame({'y': daily})
    for lag in range(1, lookback + 1):
        frame[f'lag_{lag}'] = frame['y'].shift(lag)

    training = frame.loc[:last_pre_day].dropna()
    if len(training) < 5:
        return None, None

    model = KNeighborsRegressor(n_neighbors=5)
    model.fit(training.filter(like='lag_').values, training['y'].values)

    window = frame.loc[:last_pre_day, 'y'].tolist()
    forecasts = []
    for _ in range(horizon):
        if len(window) < lookback:
            return None, None
        # Most recent value first, matching the lag_1..lag_n column order.
        features = np.array(window[-lookback:][::-1]).reshape(1, -1)
        step = model.predict(features)[0]
        forecasts.append(step)
        window.append(step)

    forecast_index = pd.date_range(intervention_date, periods=horizon, freq='D')
    return pd.Series(forecasts, index=forecast_index, name='knn_pred'), None
def fit_and_extrapolate(series: pd.Series,
                        intervention_date: pd.Timestamp,
                        days: int = 30,
                        max_pre_days: int = MAX_PRE_DAYS):
    """Fit GLM and SVR trends on pre-intervention data and extrapolate them.

    The daily series (gaps filled with 0) is cut strictly before
    ``intervention_date`` and limited to the last ``max_pre_days`` points. A
    Poisson GLM on time and time squared, and an RBF SVR on time, are fitted
    independently; a failure of either model yields ``None`` for that model.

    Returns:
        ``(glm_pred, svr_pred, residuals)``. Predictions are Series over the
        ``days`` dates starting at ``intervention_date``; ``residuals`` is the
        observed series minus the SVR prediction, or ``None`` without an SVR
        prediction. ``(None, None, None)`` when fewer than 3 pre-intervention
        points exist.
    """
    daily = series.asfreq('D').fillna(0)
    daily.index = pd.to_datetime(daily.index).tz_localize(None).normalize()
    intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize()

    history = daily.loc[:intervention_date - pd.Timedelta(days=1)]
    if len(history) > max_pre_days:
        history = history.iloc[-max_pre_days:]
    if len(history) < 3:
        return None, None, None

    n_pre = len(history)
    t_pre = np.arange(n_pre)
    t_post = np.arange(n_pre, n_pre + days)

    try:
        design_pre = sm.add_constant(np.column_stack([t_pre, t_pre**2]))
        poisson_fit = sm.GLM(history.values, design_pre, family=sm.families.Poisson()).fit()
        design_post = sm.add_constant(np.column_stack([t_post, t_post**2]))
        glm_pred = poisson_fit.predict(design_post)
    except Exception:
        glm_pred = None

    try:
        regressor = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10, gamma=0.1))
        regressor.fit(t_pre.reshape(-1, 1), history.values)
        svr_pred = regressor.predict(t_post.reshape(-1, 1))
    except Exception:
        svr_pred = None

    post_index = pd.date_range(intervention_date, periods=days, freq='D')
    if glm_pred is not None:
        glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred')
    if svr_pred is not None:
        svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred')

    observed = daily.reindex(post_index)
    if svr_pred is None:
        residuals = None
    else:
        residuals = pd.Series(observed.values - svr_pred[:len(observed)],
                              index=post_index, name='residual')
    return glm_pred, svr_pred, residuals

227
services/hotspot.py Normal file
View File

@@ -0,0 +1,227 @@
from __future__ import annotations
from datetime import datetime
from typing import Iterable
import numpy as np
import pandas as pd
# Keywords that _extract_road_info looks for (before AREA_KEYWORDS) to pick the
# address fragment that names a road or junction.
# NOTE(review): several entries are empty strings. "" is a substring of every
# string, so such an entry matches any location text. They look like
# single-character keywords that were lost -- TODO confirm against the
# original file.
LOCATION_KEYWORDS: tuple[str, ...] = (
    "",
    "",
    "",
    "",
    "路口",
    "交叉口",
    "大道",
    "公路",
    "",
)
# Second keyword group tried by _extract_road_info after LOCATION_KEYWORDS.
# NOTE(review): presumably local area / road-name prefixes -- confirm.
AREA_KEYWORDS: tuple[str, ...] = (
    "新城",
    "临城",
    "千岛",
    "翁山",
    "海天",
    "海宇",
    "定沈",
    "滨海",
    "港岛",
    "体育",
    "长升",
    "金岛",
    "桃湾",
)
# Exact-match replacements applied to the extracted location text in
# prepare_hotspot_dataset (via Series.replace) so that spelling variants of the
# same place collapse onto one canonical name.
LOCATION_MAPPING: dict[str, str] = {
    "新城千岛路": "千岛路",
    "千岛路海天大道": "千岛路海天大道口",
    "海天大道千岛路": "千岛路海天大道口",
    "新城翁山路": "翁山路",
    "翁山路金岛路": "翁山路金岛路口",
    "海天大道临长路": "海天大道临长路口",
    "定沈路卫生医院门口": "定沈路医院段",
    "翁山路海城路西口": "翁山路海城路口",
    "海宇道路口": "海宇道",
    "海天大道路口": "海天大道",
    "定沈路交叉路口": "定沈路",
    "千岛路路口": "千岛路",
    "体育路路口": "体育路",
    "金岛路路口": "金岛路",
}
# Severity weight per accident type; prepare_hotspot_dataset maps 事故类型 through
# this table (unknown types fall back to 1) when no `severity` column exists.
SEVERITY_MAP: dict[str, int] = {"财损": 1, "伤人": 2, "亡人": 4}
def _extract_road_info(location: str | float | None) -> str:
    """Return the fragment of a free-text location that names a road or area.

    The text is split on whitespace and commas (ASCII and full-width). For the
    first keyword from ``LOCATION_KEYWORDS + AREA_KEYWORDS`` that occurs in the
    text, the first fragment containing that keyword is returned.

    Args:
        location: Raw location text; may be missing (``None`` / NaN).

    Returns:
        ``"未知路段"`` for missing input, the matching fragment when a keyword
        is found, otherwise the text truncated to 20 characters.
    """
    if pd.isna(location):
        return "未知路段"
    text = str(location)
    # Split once, outside the keyword loop. Replacing the empty string would
    # insert a space between every character, so only real separators are
    # normalised; "\uff0c" is the full-width comma.
    words = text.replace("\uff0c", " ").replace(",", " ").split()
    for keyword in LOCATION_KEYWORDS + AREA_KEYWORDS:
        # An empty keyword is a substring of everything and would short-circuit
        # the search on the first fragment, so it is ignored.
        if not keyword or keyword not in text:
            continue
        for word in words:
            if keyword in word:
                return word
        return text
    return text[:20]
def prepare_hotspot_dataset(accident_records: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the accident records prepared for hotspot analysis.

    Missing descriptive columns are created with placeholder values (existing
    ones have their gaps filled), a ``severity`` column is derived from
    ``事故类型`` when absent, rows whose ``事故时间`` cannot be parsed are
    dropped, rows are sorted by time, and ``standardized_location`` is derived
    from ``事故具体地点``.
    """
    prepared = accident_records.copy()

    fallback_values: dict[str, str] = {
        "道路类型": "未知道路类型",
        "路口路段类型": "未知路段",
        "事故具体地点": "未知路段",
        "事故类型": "财损",
        "所在街道": "未知街道",
    }
    for name, fallback in fallback_values.items():
        if name in prepared.columns:
            prepared[name] = prepared[name].fillna(fallback)
        else:
            prepared[name] = fallback

    if "severity" not in prepared.columns:
        weights = prepared["事故类型"].map(SEVERITY_MAP)
        prepared["severity"] = weights.fillna(1).astype(int)

    prepared["事故时间"] = pd.to_datetime(prepared["事故时间"], errors="coerce")
    prepared = prepared.dropna(subset=["事故时间"])
    prepared = prepared.sort_values("事故时间")
    prepared = prepared.reset_index(drop=True)

    road_names = prepared["事故具体地点"].apply(_extract_road_info)
    prepared["standardized_location"] = road_names.replace(LOCATION_MAPPING)
    return prepared
def analyze_hotspot_frequency(df: pd.DataFrame, time_window: str = "7D") -> pd.DataFrame:
    """Aggregate accident statistics per ``standardized_location``.

    Overall statistics cover every row; "recent" statistics cover rows whose
    ``事故时间`` falls within ``time_window`` of the latest accident in ``df``.

    Returns:
        A frame indexed by location with overall and recent counts, dominant
        types, severity totals, ``trend_ratio`` (recent / overall count),
        ``days_since_last`` and ``avg_severity``, sorted by recent count and
        then overall count, both descending.
    """
    latest = df["事故时间"].max()
    recent_cutoff = latest - pd.Timedelta(time_window)

    overall = df.groupby("standardized_location").agg(
        accident_count=("事故时间", "count"),
        last_accident=("事故时间", "max"),
        main_accident_type=("事故类型", _mode_fallback),
        main_road_type=("道路类型", _mode_fallback),
        main_intersection_type=("路口路段类型", _mode_fallback),
        total_severity=("severity", "sum"),
    )

    recent_rows = df[df["事故时间"] >= recent_cutoff]
    recent = recent_rows.groupby("standardized_location").agg(
        recent_count=("事故时间", "count"),
        recent_accident_type=("事故类型", _mode_fallback),
        recent_severity=("severity", "sum"),
    )

    merged = overall.merge(recent, left_index=True, right_index=True, how="left")
    # Locations without recent accidents get zero counts and an empty type label.
    merged = merged.fillna({"recent_count": 0, "recent_severity": 0})
    merged = merged.fillna("")

    merged["recent_count"] = merged["recent_count"].astype(int)
    merged["trend_ratio"] = merged["recent_count"] / merged["accident_count"]
    elapsed = latest - merged["last_accident"]
    merged["days_since_last"] = elapsed.dt.days.astype(int)
    merged["avg_severity"] = merged["total_severity"] / merged["accident_count"]

    return merged.sort_values(["recent_count", "accident_count"], ascending=False)
def calculate_hotspot_risk_score(hotspot_df: pd.DataFrame) -> pd.DataFrame:
    """Score each hotspot and assign a risk level.

    The score is the sum of four components: frequency (0-40, relative to the
    busiest location), trend (0-30, from ``trend_ratio``), severity (5 / 15 /
    20 by ``main_accident_type``, default 5) and urgency (0-10, decreasing over
    30 days since the last accident). Levels: >= 70 高风险, >= 50 中风险,
    >= 30 低风险, otherwise 一般风险.

    Returns:
        A scored copy sorted by ``risk_score`` descending; an empty input is
        returned as an unmodified copy.
    """
    scored = hotspot_df.copy()
    if scored.empty:
        return scored

    busiest = scored["accident_count"].max()
    scored["frequency_score"] = (scored["accident_count"] / busiest * 40).clip(0, 40)
    scored["trend_score"] = (scored["trend_ratio"] * 30).clip(0, 30)

    type_points = {"财损": 5, "伤人": 15, "亡人": 20}
    scored["severity_score"] = scored["main_accident_type"].map(type_points).fillna(5)
    scored["urgency_score"] = ((30 - scored["days_since_last"]) / 30 * 10).clip(0, 10)

    scored["risk_score"] = (
        scored["frequency_score"]
        + scored["trend_score"]
        + scored["severity_score"]
        + scored["urgency_score"]
    )

    # Ordered from the highest threshold down: np.select takes the first match.
    tiers = [(70, "高风险"), (50, "中风险"), (30, "低风险")]
    scored["risk_level"] = np.select(
        [scored["risk_score"] >= cutoff for cutoff, _ in tiers],
        [label for _, label in tiers],
        default="一般风险",
    )
    return scored.sort_values("risk_score", ascending=False)
def generate_hotspot_strategies(
    hotspot_df: pd.DataFrame, time_period: str = "本周"
) -> list[dict[str, str | float]]:
    """Build one textual mitigation recommendation per hotspot row.

    Measures are chosen from the dominant intersection type and accident type,
    extended by risk level and by a recent-trend ratio above 0.4.

    Args:
        hotspot_df: Frame indexed by location name. ``accident_count`` is
            required; ``recent_count``, ``main_accident_type``,
            ``main_intersection_type``, ``trend_ratio`` and ``risk_level`` are
            optional and fall back to defaults.
        time_period: Label prefixed to every recommendation.

    Returns:
        A list of dicts with ``location``, ``strategy``, ``risk_level``,
        ``accident_count`` and ``recent_count``.
    """
    strategies: list[dict[str, str | float]] = []

    for location_name, location_data in hotspot_df.iterrows():
        accident_count = float(location_data["accident_count"])
        recent_count = float(location_data.get("recent_count", 0))
        accident_type = str(location_data.get("main_accident_type", "财损"))
        intersection_type = str(location_data.get("main_intersection_type", "普通路段"))
        trend_ratio = float(location_data.get("trend_ratio", 0))
        risk_level = str(location_data.get("risk_level", "一般风险"))

        # Every branch adds at least one measure, so the list is never empty.
        strategy_parts: list[str] = []
        if "信号灯" in intersection_type:
            if accident_type == "财损":
                strategy_parts.extend(["加强闯红灯查处", "优化信号配时", "整治不按规定让行"])
            else:
                strategy_parts.extend(["完善人行过街设施", "加强非机动车管理", "设置警示标志"])
        elif "普通路段" in intersection_type:
            strategy_parts.extend(["加强巡逻管控", "整治违法停车", "设置限速标志"])
        else:
            strategy_parts.extend(["分析事故成因", "制定综合整治方案"])

        if risk_level == "高风险":
            strategy_parts.extend(["列为重点整治路段", "开展专项整治行动"])
        elif risk_level == "中风险":
            strategy_parts.append("加强日常监管")

        if trend_ratio > 0.4:
            strategy_parts.append("近期重点监控")

        data_support = (
            f"(近期{int(recent_count)}起,累计{int(accident_count)}起,{accident_type}为主)"
        )
        # Close the bracket around the location and separate the measures so the
        # sentence stays readable instead of running the phrases together.
        strategy_text = (
            f"{time_period}对【{location_name}】:"
            + "、".join(strategy_parts)
            + data_support
        )

        strategies.append(
            {
                "location": location_name,
                "strategy": strategy_text,
                "risk_level": risk_level,
                "accident_count": accident_count,
                "recent_count": recent_count,
            }
        )

    return strategies
def serialise_datetime_columns(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame:
    """Return a copy of ``df`` with the given datetime columns formatted as text.

    Columns that are absent or not of a datetime dtype are left untouched;
    the rest are rendered as ``YYYY-MM-DD HH:MM:SS`` strings. The input frame
    is not modified.
    """
    converted = df.copy()
    for name in columns:
        if name not in converted.columns:
            continue
        if not pd.api.types.is_datetime64_any_dtype(converted[name]):
            continue
        converted[name] = converted[name].dt.strftime("%Y-%m-%d %H:%M:%S")
    return converted
def _mode_fallback(series: pd.Series) -> str:
    """Return the most frequent value of ``series`` as a string.

    An empty series yields ``""``. When ``Series.mode`` returns nothing, the
    first element of the series is used instead.
    """
    if series.empty:
        return ""
    modes = series.mode()
    if modes.empty:
        return str(series.iloc[0])
    return str(modes.iloc[0])

270
services/io.py Normal file
View File

@@ -0,0 +1,270 @@
from __future__ import annotations
import re
from typing import Iterable, Mapping
import pandas as pd
import streamlit as st
# Alternative header spellings mapped to the canonical column names;
# _prepare_sheet renames every alias that is present in a sheet.
COLUMN_ALIASES: Mapping[str, str] = {
    '事故发生时间': '事故时间',
    '发生时间': '事故时间',
    '时间': '事故时间',
    '街道': '所在街道',
    '所属街道': '所在街道',
    '所属辖区': '所在区县',
    '辖区街道': '所在街道',
    '事故发生地点': '事故地点',
    '事故地址': '事故地点',
    '事故位置': '事故地点',
    '事故具体地址': '事故具体地点',
    '案件类型': '事故类型',
    '事故类别': '事故类型',
    '事故性质': '事故类型',
    '事故类型1': '事故类型',
}
# Accident-type spellings collapsed onto the three categories that the severity
# mapping in load_accident_records understands (财损 / 伤人 / 亡人).
# NOTE(review): the '' key looks like a placeholder token that lost its
# character -- TODO confirm against the original file.
ACCIDENT_TYPE_NORMALIZATION: Mapping[str, str] = {
    '财产损失': '财损',
    '财产损失事故': '财损',
    '一般程序': '伤人',
    '一般程序事故': '伤人',
    '伤人事故': '伤人',
    '造成人员受伤': '伤人',
    '造成人员死亡': '亡人',
    '死亡事故': '亡人',
    '亡人事故': '亡人',
    '亡人死亡': '亡人',
    '': '财损',
}
# Matches 2-8 characters in the U+4E00..U+9FA5 range followed by an
# administrative suffix; group 1 (name plus suffix) is what
# _infer_region_from_location returns.
REGION_FROM_LOCATION_PATTERN = re.compile(r'([一-龥]{2,8}(街道|新区|开发区|镇|区))')
# Region spellings mapped to canonical street names by _normalise_region_series;
# values that are not listed pass through unchanged.
REGION_NORMALIZATION: Mapping[str, str] = {
    '临城中队': '临城街道',
    '临城新区': '临城街道',
    '临城': '临城街道',
    '新城': '临城街道',
    '千岛中队': '千岛街道',
    '千岛新区': '千岛街道',
    '千岛': '千岛街道',
    '沈家门中队': '沈家门街道',
    '沈家门': '沈家门街道',
    '普陀城区': '沈家门街道',
    '普陀': '沈家门街道',
}
def _clean_text(series: pd.Series) -> pd.Series:
    """Return the values as stripped strings with null placeholders set to missing."""
    # NOTE(review): the repeated '' entries look like placeholder tokens that
    # lost their characters -- TODO confirm against the original file.
    null_tokens = {'', 'nan', 'NaN', 'None', 'NULL', '<NA>', '', ''}
    stripped = series.astype(str).str.strip()
    is_placeholder = stripped.isin(null_tokens)
    return stripped.where(~is_placeholder)
def _maybe_seek_start(file_obj) -> None:
    """Rewind ``file_obj`` to position 0 when it exposes ``seek``.

    Objects without ``seek`` are ignored, and any error raised by ``seek`` is
    deliberately swallowed (best effort).
    """
    if not hasattr(file_obj, "seek"):
        return
    try:
        file_obj.seek(0)
    except Exception:  # pragma: no cover - guard against non file-likes
        pass
def _prepare_sheet(df: pd.DataFrame) -> pd.DataFrame:
    """Standardise the column headers of one sheet of the accident workbook.

    Header names are stripped. When neither accident-time header is present,
    the rows are scanned for one that contains a time header; that row becomes
    the header and only the rows below it are kept. Known aliases are then
    renamed to their canonical names. ``None`` or an empty sheet yields an
    empty frame.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    sheet = df.copy()
    sheet.columns = [str(col).strip() for col in sheet.columns]

    time_headers = ('事故时间', '事故发生时间')
    if not any(header in sheet.columns for header in time_headers):
        # The real header may sit inside the data: find the first row that
        # carries one of the time header labels.
        header_markers = {'事故时间', '事故发生时间', '报警时间'}
        header_row = None
        for idx, row in sheet.iterrows():
            cells = {str(cell).strip() for cell in row.tolist()}
            if not header_markers.isdisjoint(cells):
                header_row = idx
                break
        if header_row is not None:
            sheet.columns = [str(x).strip() for x in sheet.iloc[header_row].tolist()]
            sheet = sheet.iloc[header_row + 1:].reset_index(drop=True)

    # Apply aliases only after the header may have been relocated.
    renames = {src: dst for src, dst in COLUMN_ALIASES.items() if src in sheet.columns}
    return sheet.rename(columns=renames)
def _coalesce_columns(df: pd.DataFrame, columns: Iterable[str]) -> pd.Series:
    """Return, row by row, the first non-missing cleaned value among ``columns``.

    Columns are consulted in the given order; names absent from ``df`` are
    skipped. Rows with no value in any listed column stay missing.
    """
    merged = pd.Series(pd.NA, index=df.index, dtype="object")
    present = (col for col in columns if col in df.columns)
    for col in present:
        merged = merged.fillna(_clean_text(df[col]))
    return merged
def _infer_region_from_location(location: str) -> str | None:
    """Extract a region name from free-text location, or return ``None``.

    Missing or blank input and text without a match for
    ``REGION_FROM_LOCATION_PATTERN`` all yield ``None``; otherwise the first
    capture group of the match is returned.
    """
    if pd.isna(location):
        return None
    text = str(location).strip()
    if not text:
        return None
    found = REGION_FROM_LOCATION_PATTERN.search(text)
    return found.group(1) if found else None
def _normalise_region_series(series: pd.Series) -> pd.Series:
    """Map each region through ``REGION_NORMALIZATION``; missing values pass through."""
    def _canonical(value):
        if pd.isna(value):
            return value
        return REGION_NORMALIZATION.get(value, value)

    return series.map(_canonical)
def load_accident_records(accident_file, *, require_location: bool = False) -> pd.DataFrame:
    """
    Load and normalise accident records from an Excel workbook.

    Every sheet is read and standardised with ``_prepare_sheet`` (which also
    handles a header row that sits inside the data); non-empty sheets are
    concatenated. Time, region, accident type and location are then
    harmonised and a ``severity`` score is added.

    Args:
        accident_file: Path or file-like object accepted by ``pd.read_excel``;
            file-likes are rewound first.
        require_location: When True, a location column is mandatory and rows
            without a location are dropped.

    Returns:
        Records sorted by ``事故时间`` with normalised ``所在街道``, ``事故类型``
        and ``事故具体地点`` columns and an integer ``severity`` column.

    Raises:
        ValueError: If no sheet yields records, if neither ``事故时间`` nor
            ``报警时间`` exists, or if ``require_location`` is set and no
            location column is available.
    """
    _maybe_seek_start(accident_file)
    # sheet_name=None makes pandas return a dict of all sheets.
    sheets = pd.read_excel(accident_file, sheet_name=None)
    if isinstance(sheets, dict):
        frames = [frame for frame in ( _prepare_sheet(df) for df in sheets.values() ) if not frame.empty]
    else:  # pragma: no cover - pandas only returns dict when sheet_name=None, but keep guard
        frames = [_prepare_sheet(sheets)]
    if not frames:
        raise ValueError("未在上传的事故数据中检测到有效的事故记录,请确认文件内容。")
    accident_df = pd.concat(frames, ignore_index=True)
    # Normalise columns of interest; 报警时间 stands in when 事故时间 is absent.
    if '事故时间' not in accident_df.columns and '报警时间' in accident_df.columns:
        accident_df['事故时间'] = accident_df['报警时间']
    if '事故时间' not in accident_df.columns:
        raise ValueError("事故数据缺少“事故时间”字段,请确认模板是否为最新版本。")
    # Unparseable timestamps become NaT and are dropped with the core-field check below.
    accident_df['事故时间'] = pd.to_datetime(accident_df['事故时间'], errors='coerce')
    # Location harmonisation (used for both region inference and hotspot analysis)
    location_columns_available = [col for col in ['事故具体地点', '事故地点'] if col in accident_df.columns]
    location_series = _coalesce_columns(accident_df, ['事故具体地点', '事故地点'])
    # Region handling: first non-missing value among the candidate columns.
    region = _coalesce_columns(accident_df, ['所在街道', '所属街道', '所在区县', '辖区中队'])
    # Infer region from location fields when still missing; as a last resort
    # the cleaned location text itself is used as the region.
    if region.isna().any():
        inferred = location_series.map(_infer_region_from_location)
        region = region.fillna(inferred)
        region = region.fillna(_clean_text(location_series))
    region_clean = _clean_text(region)
    accident_df['所在街道'] = _normalise_region_series(region_clean)
    # Accident type normalisation; records without a type default to 财损.
    accident_type = _coalesce_columns(accident_df, ['事故类型', '事故类别', '事故性质'])
    accident_type = accident_type.replace(ACCIDENT_TYPE_NORMALIZATION)
    accident_type = _clean_text(accident_type).replace(ACCIDENT_TYPE_NORMALIZATION)
    accident_df['事故类型'] = accident_type.fillna('财损')
    # Location column harmonisation
    if require_location and not location_columns_available and location_series.isna().all():
        raise ValueError("事故数据缺少“事故具体地点”字段,请确认模板是否与 sample/事故处理 中示例一致。")
    accident_df['事故具体地点'] = _clean_text(location_series)
    # Drop records with missing core fields
    subset = ['事故时间', '所在街道', '事故类型']
    if require_location:
        subset.append('事故具体地点')
    accident_df = accident_df.dropna(subset=subset)
    # Severity score; types outside the map fall back to 1.
    severity_map = {'财损': 1, '伤人': 2, '亡人': 4}
    accident_df['severity'] = accident_df['事故类型'].map(severity_map).fillna(1).astype(int)
    accident_df = accident_df.sort_values('事故时间').reset_index(drop=True)
    return accident_df
@st.cache_data(show_spinner=False)
def load_and_clean_data(accident_file, strategy_file):
    """Load the accident and strategy workbooks into two tidy frames.

    Returns:
        ``(accident_data, strategy_df)``. ``accident_data`` has the columns
        ``date_time``, ``region``, ``category`` and ``severity``;
        ``strategy_df`` has ``date_time`` and ``strategy_type``.

    Raises:
        ValueError: If the strategy sheet lacks ``发布时间`` or
            ``交通策略类型``, or for the reasons documented on
            ``load_accident_records``.
    """
    accidents = load_accident_records(accident_file).rename(
        columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'}
    )

    _maybe_seek_start(strategy_file)
    strategies = pd.read_excel(strategy_file)
    strategies = strategies.rename(columns=lambda col: str(col).strip())

    if '发布时间' not in strategies.columns:
        raise ValueError("策略数据缺少“发布时间”字段,请确认文件格式。")
    strategies['发布时间'] = pd.to_datetime(strategies['发布时间'], errors='coerce')

    if '交通策略类型' not in strategies.columns:
        raise ValueError("策略数据缺少“交通策略类型”字段,请确认文件格式。")
    strategies['交通策略类型'] = _clean_text(strategies['交通策略类型'])

    # Rows without a parseable date or a strategy type cannot be used.
    strategies = strategies.dropna(subset=['发布时间', '交通策略类型'])

    accidents = accidents[['date_time', 'region', 'category', 'severity']]
    strategies = strategies[['发布时间', '交通策略类型']].rename(
        columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'}
    )
    return accidents, strategies
@st.cache_data(show_spinner=False)
def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
    """Aggregate accidents and strategies into one row per calendar day.

    Args:
        accident_data: Frame with ``date_time`` and ``severity`` columns.
        strategy_data: Frame with ``date_time`` and ``strategy_type`` columns.

    Returns:
        A frame on a continuous daily index spanning the first to the last
        accident date, with ``accident_count`` and ``severity`` (0 on days
        without accidents) and ``strategy_type``, the list of strategies
        published that day (empty list when none).
    """
    accident_data = accident_data.copy()
    strategy_data = strategy_data.copy()

    accident_data['date'] = accident_data['date_time'].dt.date
    daily_accidents = accident_data.groupby('date').agg(
        accident_count=('date_time', 'count'),
        severity=('severity', 'sum')
    )
    daily_accidents.index = pd.to_datetime(daily_accidents.index)

    strategy_data['date'] = strategy_data['date_time'].dt.date
    daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list)
    daily_strategies.index = pd.to_datetime(daily_strategies.index)

    # Fill the calendar gaps BEFORE attaching strategies. Joining first (as the
    # previous version did) dropped every strategy published on a day without
    # accidents, because that day was not yet part of the index.
    combined = daily_accidents.asfreq('D')
    combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0)

    # Same lookup as aggregate_daily_data_by_region: a fresh empty list for
    # days without any strategy.
    strat_map = daily_strategies.to_dict()
    combined['strategy_type'] = [strat_map.get(day, []) for day in combined.index]
    return combined
@st.cache_data(show_spinner=False)
def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
    """Aggregate accidents per region and day on a complete region x date grid.

    Returns:
        A frame indexed by ``(region, date)`` covering every region for every
        day between the first and last accident date, with ``accident_count``
        and ``severity`` (0 where no accidents occurred) and ``strategy_type``,
        the list of strategies published on that date (empty list when none).
    """
    accidents = accident_data.copy()
    accidents['date'] = accidents['date_time'].dt.date
    per_region = accidents.groupby(['region', 'date']).agg(
        accident_count=('date_time', 'count'),
        severity=('severity', 'sum')
    )
    # Convert the date level from datetime.date objects to Timestamps.
    region_level, date_level = per_region.index.levels[0], per_region.index.levels[1]
    per_region.index = per_region.index.set_levels([region_level, pd.to_datetime(date_level)])
    per_region = per_region.sort_index()

    strategies = strategy_data.copy()
    strategies['date'] = strategies['date_time'].dt.date
    daily_strategies = strategies.groupby('date')['strategy_type'].apply(list)
    daily_strategies.index = pd.to_datetime(daily_strategies.index)

    regions = per_region.index.get_level_values(0).unique()
    observed_dates = per_region.index.get_level_values(1)
    dates = pd.date_range(observed_dates.min(), observed_dates.max(), freq='D')
    full_index = pd.MultiIndex.from_product([regions, dates], names=['region', 'date'])
    per_region = per_region.reindex(full_index).fillna(0)

    strat_map = daily_strategies.to_dict()
    strategy_lists = [strat_map.get(d, []) for d in per_region.index.get_level_values('date')]
    per_region = per_region.assign(strategy_type=strategy_lists)
    return per_region

86
services/metrics.py Normal file
View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from statsmodels.tsa.arima.model import ARIMA
import streamlit as st
@st.cache_data(show_spinner=False)
def evaluate_models(series: pd.Series,
                    horizon: int = 30,
                    lookback: int = 14,
                    p_values: range = range(0, 4),
                    d_values: range = range(0, 2),
                    q_values: range = range(0, 4)) -> pd.DataFrame:
    """Compare ARIMA / KNN / GLM / SVR on a hold-out of the last ``horizon`` days.

    The series is resampled to daily frequency (gaps filled with 0). The last
    ``horizon`` days form the validation set; each model is fitted on the
    preceding data and scored on it.

    Args:
        series: Date-indexed accident counts.
        horizon: Size of the validation window in days (at least 1).
        lookback: Number of lag features for the KNN model.
        p_values, d_values, q_values: ARIMA order candidates, selected by AIC.

    Returns:
        One row per model that produced a forecast, with MAE, RMSE and MAPE
        (percent, computed over non-zero actuals), rounded to 3 decimals and
        sorted by RMSE ascending.

    Raises:
        ValueError: If ``horizon`` is smaller than 1 or the series has no more
            than ``horizon + 10`` points.
    """
    if horizon < 1:
        # iloc[:-0] would yield an empty training set instead of an error.
        raise ValueError("评估窗口必须至少为 1 天。")

    series = series.asfreq('D').fillna(0)
    if len(series) <= horizon + 10:
        raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon)

    train, test = series.iloc[:-horizon], series.iloc[-horizon:]

    def _to_series_like(pred, a_index):
        # Align Series predictions by date; wrap plain arrays positionally.
        if isinstance(pred, pd.Series):
            return pred.reindex(a_index)
        return pd.Series(pred, index=a_index)

    def _metrics(a: pd.Series, p) -> dict:
        p = _to_series_like(p, a.index).astype(float)
        a = a.astype(float)
        mae = mean_absolute_error(a, p)
        # Take the root explicitly instead of passing ``squared=False``, which
        # newer scikit-learn releases deprecate and then reject.
        rmse = float(np.sqrt(mean_squared_error(a, p)))
        # Days with zero accidents are excluded from MAPE (division by zero).
        mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100
        return {"MAE": mae, "RMSE": rmse, "MAPE": mape}

    results = {}

    # ARIMA: choose the order with the lowest AIC; (1, 0, 1) if nothing fits.
    best_aic, best_order = float('inf'), (1, 0, 1)
    for p in p_values:
        for d in d_values:
            for q in q_values:
                try:
                    aic = ARIMA(train, order=(p, d, q)).fit().aic
                    if aic < best_aic:
                        best_aic, best_order = aic, (p, d, q)
                except Exception:
                    continue

    arima_train = train.asfreq('D').fillna(0)
    arima_pred = ARIMA(arima_train, order=best_order).fit().forecast(steps=horizon)
    results['ARIMA'] = _metrics(test, arima_pred)

    # Import local utilities to avoid circular dependencies
    from services.forecast import knn_forecast_counterfactual, fit_and_extrapolate

    # The helpers only train on data before the given date, so passing the
    # full series does not leak the validation window.
    holdout_start = train.index[-1] + pd.Timedelta(days=1)

    try:
        knn_pred, _ = knn_forecast_counterfactual(series,
                                                  holdout_start,
                                                  lookback=lookback,
                                                  horizon=horizon)
        if knn_pred is not None:
            results['KNN'] = _metrics(test, knn_pred)
    except Exception:
        # Best effort: a failing optional model is left out of the table.
        pass

    try:
        glm_pred, svr_pred, _ = fit_and_extrapolate(series,
                                                    holdout_start,
                                                    days=horizon)
        if glm_pred is not None:
            results['GLM'] = _metrics(test, glm_pred)
        if svr_pred is not None:
            results['SVR'] = _metrics(test, svr_pred)
    except Exception:
        # Best effort: a failing optional model is left out of the table.
        pass

    return (pd.DataFrame(results)
            .T.sort_values('RMSE')
            .round(3))

157
services/strategy.py Normal file
View File

@@ -0,0 +1,157 @@
from __future__ import annotations
import pandas as pd
import streamlit as st
from services.forecast import fit_and_extrapolate, arima_forecast_with_grid_search
from config.settings import MIN_PRE_DAYS, MAX_PRE_DAYS
def _covered_days(index: pd.Index, start: pd.Timestamp, end: pd.Timestamp) -> int:
    """Count the calendar days of ``[start, end]`` that lie inside the span of ``index``."""
    if len(index) == 0:
        return 0
    lo = max(start, index.min())
    hi = min(end, index.max())
    return max(0, (hi - lo).days + 1)


def evaluate_strategy_effectiveness(actual_series: pd.Series,
                                    counterfactual_series: pd.Series,
                                    severity_series: pd.Series,
                                    strategy_date: pd.Timestamp,
                                    window: int = 30):
    """Score a strategy by comparing the post-intervention window to its baselines.

    Args:
        actual_series: Observed accident counts, indexed by date.
        counterfactual_series: Predicted counts without the strategy, indexed by date.
        severity_series: Observed severity, indexed by date.
        strategy_date: First day of the intervention.
        window: Length in days of both the pre and the post window.

    Returns:
        ``(count_effective, severity_effective, (F1, F2), safety_state)`` where

        * ``count_effective`` is true when the actual count is below the
          counterfactual on at least half of the comparable days,
        * ``severity_effective`` is true when post-window severity is below
          pre-window severity,
        * ``F1`` is the relative reduction of counts against the counterfactual,
        * ``F2`` is the relative reduction of severity against the pre window,
        * ``safety_state`` is '一级', '二级' or '三级' derived from F1 and F2.

        ``(False, False, (0.0, 0.0), '三级')`` is returned when no actual
        observation falls inside the post window.
    """
    strategy_date = pd.to_datetime(strategy_date)
    one_day = pd.Timedelta(days=1)
    pre_start = strategy_date - pd.Timedelta(days=window)
    pre_end = strategy_date - one_day
    window_end = strategy_date + pd.Timedelta(days=window - 1)

    pre_sev = severity_series.loc[pre_start:pre_end].sum()
    post_sev = severity_series.loc[strategy_date:window_end].sum()

    actual_post = actual_series.loc[strategy_date:window_end]
    counter_post = counterfactual_series.loc[strategy_date:window_end].reindex(actual_post.index)
    if len(actual_post) == 0:
        return False, False, (0.0, 0.0), '三级'

    # Only days that have both an actual and a counterfactual value can be
    # compared; otherwise a counterfactual that starts late (or ends early)
    # would be summed over fewer days than the actuals it is compared with.
    comparable = actual_post.notna() & counter_post.notna()
    n_comparable = int(comparable.sum())
    actual_cmp = actual_post[comparable]
    counter_cmp = counter_post[comparable]

    effective_days = int((actual_cmp < counter_cmp).sum())
    count_effective = n_comparable > 0 and effective_days >= n_comparable / 2

    # Scale post-window severity to the number of days actually available in
    # the pre window, so a truncated window (data starting or ending inside
    # it) does not bias the comparison. With two fully covered windows the
    # factor is exactly 1 and the result equals the plain sum comparison.
    pre_days = _covered_days(severity_series.index, pre_start, pre_end)
    post_days = _covered_days(severity_series.index, strategy_date, window_end)
    post_sev_scaled = post_sev * (pre_days / post_days) if post_days else 0.0
    severity_effective = post_sev_scaled < pre_sev

    cf_sum = counter_cmp.sum()
    F1 = (cf_sum - actual_cmp.sum()) / cf_sum if cf_sum > 0 else 0.0
    F2 = (pre_sev - post_sev_scaled) / pre_sev if pre_sev > 0 else 0.0

    if F1 > 0.5 and F2 > 0.5:
        safety_state = '一级'
    elif F1 > 0.3:
        safety_state = '二级'
    else:
        safety_state = '三级'
    return count_effective, severity_effective, (F1, F2), safety_state
def _score_strategy(acc: pd.Series, sev: pd.Series, intervention_date: pd.Timestamp, days: int):
    """Build a counterfactual for one intervention and score it.

    Returns the per-strategy result dict, or ``None`` when no counterfactual
    could be produced.
    """
    glm_pred, svr_pred, residuals = fit_and_extrapolate(acc, intervention_date, days=days)

    # Preference order for the counterfactual: SVR, then GLM, then ARIMA.
    if svr_pred is not None:
        counter = svr_pred
    elif glm_pred is not None:
        counter = glm_pred
    else:
        try:
            arima_df = arima_forecast_with_grid_search(
                acc.loc[:intervention_date],
                start_date=intervention_date + pd.Timedelta(days=1),
                horizon=days,
            )
            counter = pd.Series(arima_df['forecast'].values, index=arima_df.index, name='cf_arima')
            residuals = (acc.reindex(counter.index) - counter)
        except Exception:
            counter = None
    if counter is None:
        return None

    count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness(
        actual_series=acc,
        counterfactual_series=counter,
        severity_series=sev,
        strategy_date=intervention_date,
        window=days,
    )
    if residuals is not None:
        effect_strength = float(residuals.dropna().mean())
    else:
        effect_strength = 0.0
    return {
        'effect_strength': effect_strength,
        'adaptability': float(F1 + F2),
        'count_effective': bool(count_eff),
        'severity_effective': bool(sev_eff),
        'safety_state': state,
        'F1': float(F1),
        'F2': float(F2),
        'intervention_date': str(intervention_date.date()),
    }


@st.cache_data(show_spinner=False)
def generate_output_and_recommendations(combined_data: pd.DataFrame,
                                        strategy_types: list,
                                        region: str = '全市',
                                        horizon: int = 30):
    """Evaluate every strategy type and recommend the one with the highest adaptability.

    Args:
        combined_data: Date-indexed frame with ``accident_count``, ``severity``
            and ``strategy_type`` (list per day) columns.
        strategy_types: Strategy names to evaluate.
        region: Region label inserted into the recommendation text.
        horizon: Requested evaluation window in days.

    Returns:
        ``(results, recommendation)``: a dict of per-strategy metric dicts and
        a recommendation sentence (a fixed "not enough data" message when no
        strategy could be evaluated).
    """
    frame = combined_data.copy().asfreq('D')
    numeric_cols = ['accident_count', 'severity']
    frame[numeric_cols] = frame[numeric_cols].fillna(0)
    frame['strategy_type'] = frame['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])

    acc_full = frame['accident_count']
    sev_full = frame['severity']
    series_start = acc_full.index.min()
    fit_span = pd.Timedelta(days=max(horizon + 60, MAX_PRE_DAYS))
    one_day = pd.Timedelta(days=1)

    def _dates_with(strategy):
        # Dates on which the given strategy appears in the day's strategy list.
        flags = frame['strategy_type'].apply(lambda x: strategy in x)
        return flags[flags].index

    def _has_enough_history(dt):
        fit_from = max(series_start, dt - fit_span)
        return len(acc_full.loc[fit_from:dt - one_day]) >= MIN_PRE_DAYS

    results = {}
    for strategy in strategy_types:
        candidate_dates = _dates_with(strategy)
        if len(candidate_dates) == 0:
            continue

        # First occurrence with enough pre-intervention history, else the very first one.
        intervention_date = next(
            (dt for dt in candidate_dates if _has_enough_history(dt)),
            candidate_dates[0],
        )
        fit_from = max(series_start, intervention_date - fit_span)
        acc = acc_full.loc[fit_from:]
        sev = sev_full.loc[fit_from:]
        horizon_eff = max(7, min(horizon, len(acc.loc[intervention_date:])))

        outcome = _score_strategy(acc, sev, intervention_date, horizon_eff)
        if outcome is not None:
            results[strategy] = outcome

    # Secondary attempt with 14-day window if no results
    if not results:
        for strategy in strategy_types:
            candidate_dates = _dates_with(strategy)
            if len(candidate_dates) == 0:
                continue
            outcome = _score_strategy(acc_full, sev_full, candidate_dates[0], 14)
            if outcome is not None:
                results[strategy] = outcome

    if results:
        best_strategy = max(results, key=lambda x: results[x]['adaptability'])
    else:
        best_strategy = None
    if best_strategy:
        recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}"
    else:
        recommendation = "无足够数据推荐策略"
    return results, recommendation

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +0,0 @@
,effect_strength,adaptability,count_effective,severity_effective,safety_state,F1,F2,intervention_date
交通信息预警,-8.965321179202334,-0.7855379968058066,True,False,三级,0.2463091369521552,-1.0318471337579618,2022-01-13
交通整治活动,-2.651006128785241,-1.667254385637472,True,False,三级,0.08411173458110731,-1.7513661202185793,2022-01-11
交通管制措施,-10.70286313762653,0.19010392243197832,True,False,三级,0.2989387495766646,-0.1088348271446863,2022-01-20
政策制度实施,-2.6771799687750018,-5.1316650216481605,True,False,三级,0.07856225107911223,-5.2102272727272725,2022-01-06
1 effect_strength adaptability count_effective severity_effective safety_state F1 F2 intervention_date
2 交通信息预警 -8.965321179202334 -0.7855379968058066 True False 三级 0.2463091369521552 -1.0318471337579618 2022-01-13
3 交通整治活动 -2.651006128785241 -1.667254385637472 True False 三级 0.08411173458110731 -1.7513661202185793 2022-01-11
4 交通管制措施 -10.70286313762653 0.19010392243197832 True False 三级 0.2989387495766646 -0.1088348271446863 2022-01-20
5 政策制度实施 -2.6771799687750018 -5.1316650216481605 True False 三级 0.07856225107911223 -5.2102272727272725 2022-01-06

13
ui_sections/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from .overview import render_overview
from .forecast import render_forecast
from .model_eval import render_model_eval
from .strategy_eval import render_strategy_eval
from .hotspot import render_hotspot
# Names exported by this package: the render_* functions imported from the submodules above.
__all__ = [
'render_overview',
'render_forecast',
'render_model_eval',
'render_strategy_eval',
'render_hotspot',
]

185
ui_sections/forecast.py Normal file
View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from services.forecast import (
arima_forecast_with_grid_search,
knn_forecast_counterfactual,
fit_and_extrapolate,
)
def render_forecast(base: pd.DataFrame):
    """Render the multi-model forecast comparison section.

    Shows a form for the forecast start date and horizon, runs ARIMA, KNN and
    GLM/SVR forecasts when the form is submitted, keeps the latest results in
    ``st.session_state["forecast_state"]`` and plots them against the actual
    series. A model that fails or yields nothing is reported as a warning
    below the chart instead of aborting the section.

    Args:
        base: Date-indexed frame with an ``accident_count`` column; may be
            ``None`` or empty, in which case only an info message is shown.
    """
    st.subheader("多模型预测比较")
    if base is None or base.empty:
        st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
        st.session_state.setdefault(
            "forecast_state",
            {"results": None, "last_message": "暂无可用于预测的事故数据。"},
        )
        return

    forecast_state = st.session_state.setdefault(
        "forecast_state",
        {
            "selected_date": None,
            "horizon": 30,
            "results": None,
            "last_message": None,
            "data_signature": None,
        },
    )

    earliest_date = base.index.min().date()
    latest_date = base.index.max().date()
    fallback_date = max(
        (base.index.max() - pd.Timedelta(days=30)).date(),
        earliest_date,
    )
    current_signature = (
        earliest_date.isoformat(),
        latest_date.isoformat(),
        int(len(base)),
        float(base["accident_count"].sum()),
    )

    # Reset cached results if the underlying dataset has changed
    if forecast_state.get("data_signature") != current_signature:
        forecast_state.update(
            {
                "data_signature": current_signature,
                "results": None,
                "last_message": None,
                "selected_date": fallback_date,
            }
        )

    # Clamp the remembered date into the range the date picker accepts.
    default_date = forecast_state.get("selected_date") or fallback_date
    if default_date < earliest_date:
        default_date = earliest_date
    if default_date > latest_date:
        default_date = latest_date

    with st.form(key="predict_form"):
        selected_date = st.date_input(
            "选择干预日期 / 预测起点",
            value=default_date,
            min_value=earliest_date,
            max_value=latest_date,
        )
        horizon = st.number_input(
            "预测天数",
            min_value=7,
            max_value=90,
            value=int(forecast_state.get("horizon", 30)),
            step=1,
        )
        submit_predict = st.form_submit_button("应用预测参数")

    forecast_state["selected_date"] = selected_date
    forecast_state["horizon"] = int(horizon)

    if submit_predict:
        start_ts = pd.to_datetime(selected_date)
        horizon_days = int(horizon)
        history = base.loc[:start_ts]
        if len(history) < 10:
            forecast_state.update(
                {
                    "results": None,
                    "last_message": "干预前数据不足(至少需要 10 个观测点)。",
                }
            )
        else:
            with st.spinner("正在生成预测结果…"):
                # Collected per-model problems, shown below the chart.
                messages: list[str] = []

                try:
                    train_series = history["accident_count"]
                    arima_df = arima_forecast_with_grid_search(
                        train_series,
                        start_date=start_ts + pd.Timedelta(days=1),
                        horizon=horizon_days,
                    )
                except Exception as exc:
                    arima_df = None
                    messages.append(f"ARIMA 运行失败:{exc}")

                # KNN and GLM/SVR get the same treatment as ARIMA: one failing
                # model must not take down the whole comparison.
                try:
                    knn_pred, _ = knn_forecast_counterfactual(
                        base["accident_count"],
                        start_ts,
                        horizon=horizon_days,
                    )
                except Exception as exc:
                    knn_pred = None
                    messages.append(f"KNN 运行失败:{exc}")
                else:
                    if knn_pred is None:
                        messages.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")

                try:
                    glm_pred, svr_pred, _ = fit_and_extrapolate(
                        base["accident_count"],
                        start_ts,
                        days=horizon_days,
                    )
                except Exception as exc:
                    glm_pred, svr_pred = None, None
                    messages.append(f"GLM/SVR 运行失败:{exc}")
                else:
                    if glm_pred is None and svr_pred is None:
                        messages.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")

                forecast_state.update(
                    {
                        "results": {
                            "selected_date": selected_date,
                            "horizon": horizon_days,
                            "arima_df": arima_df,
                            "knn_pred": knn_pred,
                            "glm_pred": glm_pred,
                            "svr_pred": svr_pred,
                            "warnings": messages,
                        },
                        "last_message": None,
                    }
                )

    results = forecast_state.get("results")
    if not results:
        if forecast_state.get("last_message"):
            st.warning(forecast_state["last_message"])
        else:
            st.info("请设置预测参数并点击“应用预测参数”按钮。")
        return

    first_date = pd.to_datetime(results["selected_date"])
    horizon_days = int(results["horizon"])
    arima_df = results["arima_df"]
    knn_pred = results["knn_pred"]
    glm_pred = results["glm_pred"]
    svr_pred = results["svr_pred"]

    fig_pred = go.Figure()
    fig_pred.add_trace(
        go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
    )
    if arima_df is not None:
        fig_pred.add_trace(
            go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
        )
    if knn_pred is not None:
        fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines"))
    if glm_pred is not None:
        fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines"))
    if svr_pred is not None:
        fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines"))
    fig_pred.update_layout(
        title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
        xaxis_title="日期",
        yaxis_title="事故数",
    )
    st.plotly_chart(fig_pred, use_container_width=True)

    if arima_df is not None:
        st.download_button(
            "下载 ARIMA 预测 CSV",
            data=arima_df.to_csv().encode("utf-8-sig"),
            file_name="arima_forecast.csv",
            mime="text/csv",
        )

    for warning_text in results.get("warnings", []):
        st.warning(warning_text)

189
ui_sections/hotspot.py Normal file
View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import json
from datetime import datetime
import plotly.express as px
import streamlit as st
from services.hotspot import (
analyze_hotspot_frequency,
calculate_hotspot_risk_score,
generate_hotspot_strategies,
prepare_hotspot_dataset,
serialise_datetime_columns,
)
@st.cache_data(show_spinner=False)
def _prepare_hotspot_data(df):
    """Run ``prepare_hotspot_dataset`` on ``df`` behind Streamlit's data cache."""
    prepared = prepare_hotspot_dataset(df)
    return prepared
def render_hotspot(accident_records, accident_source_name: str | None) -> None:
    """Render the accident hotspot analysis section.

    Prepares the uploaded accident records, shows summary metrics and the
    analysis parameters and, once the analysis button is pressed, ranks the
    hotspots, lists strategy suggestions, draws two charts and offers CSV and
    JSON downloads.

    Args:
        accident_records: Raw accident records passed to
            ``prepare_hotspot_dataset``; ``None`` shows the upload hint only.
        accident_source_name: Label stored as ``data_source`` in the JSON
            report; a default label is used when it is empty or ``None``.
    """
    st.header("📍 事故多发路口分析")
    st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略。")

    if accident_records is None:
        st.info("请在左侧上传事故数据并点击“应用数据与筛选”后再执行热点分析。")
        st.markdown(
            """
### 📝 支持的数据格式要求:
- **文件格式**Excel (.xlsx)
- **必要字段**
- `事故时间`
- `事故类型`
- `事故具体地点`
- `所在街道`
- `道路类型`
- `路口路段类型`
"""
        )
        return

    with st.spinner("正在准备事故热点数据…"):
        hotspot_data = _prepare_hotspot_data(accident_records)

    # Without rows the min/max timestamps below are NaT and cannot be formatted.
    if len(hotspot_data) == 0:
        st.warning("⚠️ 事故数据为空,无法进行热点分析。")
        return

    st.success(f"✅ 成功加载数据:{len(hotspot_data)} 条事故记录")

    first_day = hotspot_data['事故时间'].min().strftime('%Y-%m-%d')
    last_day = hotspot_data['事故时间'].max().strftime('%Y-%m-%d')
    metric_cols = st.columns(3)
    with metric_cols[0]:
        # Separator between the two dates; without it they run together.
        st.metric("数据时间范围", f"{first_day} 至 {last_day}")
    with metric_cols[1]:
        st.metric(
            "事故类型分布",
            f"财损: {len(hotspot_data[hotspot_data['事故类型'] == '财损'])}",
        )
    with metric_cols[2]:
        st.metric("涉及区域", f"{hotspot_data['所在街道'].nunique()}个街道")

    st.subheader("🔧 分析参数设置")
    settings_cols = st.columns(3)
    with settings_cols[0]:
        time_window = st.selectbox(
            "统计时间窗口",
            options=["7D", "15D", "30D"],
            index=0,
            key="hotspot_window",
        )
    with settings_cols[1]:
        min_accidents = st.number_input(
            "最小事故数", min_value=1, max_value=50, value=3, key="hotspot_min_accidents"
        )
    with settings_cols[2]:
        top_n = st.slider("显示热点数量", min_value=3, max_value=20, value=8, key="hotspot_top_n")

    if not st.button("🚀 开始热点分析", type="primary"):
        return

    with st.spinner("正在分析事故热点分布…"):
        hotspots = analyze_hotspot_frequency(hotspot_data, time_window=time_window)
        hotspots = hotspots[hotspots["accident_count"] >= min_accidents]
        if hotspots.empty:
            st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数。")
            return
        # Score a wider candidate set (3x) before keeping the requested top N.
        hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3))
        top_hotspots = hotspots_with_risk.head(top_n)

    st.subheader("📊 事故多发路口排名(前{0}个)".format(top_n))
    display_columns = {
        "accident_count": "累计事故数",
        "recent_count": "近期事故数",
        "trend_ratio": "趋势比例",
        "main_accident_type": "主要类型",
        "main_intersection_type": "路口类型",
        "risk_score": "风险评分",
        "risk_level": "风险等级",
    }
    display_df = top_hotspots[list(display_columns.keys())].rename(columns=display_columns)
    styled_df = display_df.style.format({"趋势比例": "{:.2f}", "风险评分": "{:.1f}"}).background_gradient(
        subset=["风险评分"], cmap="Reds"
    )
    st.dataframe(styled_df, use_container_width=True)

    st.subheader("🎯 针对性策略建议")
    strategies = generate_hotspot_strategies(top_hotspots, time_period="本周")
    for index, strategy_info in enumerate(strategies, start=1):
        message = f"**{index}. {strategy_info['strategy']}**"
        risk_level = strategy_info["risk_level"]
        if risk_level == "高风险":
            st.error(f"🚨 {message}")
        elif risk_level == "中风险":
            st.warning(f"⚠️ {message}")
        else:
            st.info(f"{message}")

    st.subheader("📈 数据分析可视化")
    chart_cols = st.columns(2)
    with chart_cols[0]:
        fig_hotspots = px.bar(
            top_hotspots.head(10),
            x=top_hotspots.head(10).index,
            y=["accident_count", "recent_count"],
            title="事故频次TOP10分布",
            labels={"value": "事故数量", "variable": "类型", "index": "路口名称"},
            barmode="group",
        )
        fig_hotspots.update_layout(xaxis_tickangle=-45)
        st.plotly_chart(fig_hotspots, use_container_width=True)
    with chart_cols[1]:
        risk_distribution = top_hotspots["risk_level"].value_counts()
        fig_risk = px.pie(
            values=risk_distribution.values,
            names=risk_distribution.index,
            title="风险等级分布",
            color_discrete_map={"高风险": "red", "中风险": "orange", "低风险": "green"},
        )
        st.plotly_chart(fig_risk, use_container_width=True)

    st.subheader("💾 数据导出")
    download_cols = st.columns(2)
    with download_cols[0]:
        hotspot_csv = top_hotspots.to_csv().encode("utf-8-sig")
        st.download_button(
            "📥 下载热点数据CSV",
            data=hotspot_csv,
            file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv",
            mime="text/csv",
        )
    with download_cols[1]:
        serializable = serialise_datetime_columns(
            top_hotspots.reset_index(),
            columns=[col for col in top_hotspots.columns if "time" in col or "date" in col],
        )
        report_payload = {
            "analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "time_window": time_window,
            "data_source": accident_source_name or "事故数据",
            "total_records": int(len(hotspot_data)),
            "analysis_parameters": {"min_accidents": int(min_accidents), "top_n": int(top_n)},
            "top_hotspots": serializable.to_dict("records"),
            "recommended_strategies": strategies,
            "summary": {
                "high_risk_count": int((top_hotspots["risk_level"] == "高风险").sum()),
                "medium_risk_count": int((top_hotspots["risk_level"] == "中风险").sum()),
                "total_analyzed_locations": int(len(hotspots)),
                "most_dangerous_location": top_hotspots.index[0]
                if len(top_hotspots)
                else "",
            },
        }
        st.download_button(
            "📄 下载完整分析报告",
            data=json.dumps(report_payload, ensure_ascii=False, indent=2),
            file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json",
            mime="application/json",
        )

    with st.expander("📋 查看原始数据预览"):
        preview_cols = ["事故时间", "所在街道", "事故类型", "事故具体地点", "道路类型"]
        preview_df = hotspot_data[preview_cols].copy()
        st.dataframe(preview_df.head(10), use_container_width=True)

32
ui_sections/model_eval.py Normal file
View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import pandas as pd
import streamlit as st
from services.metrics import evaluate_models
def render_model_eval(base: pd.DataFrame):
    """Render the model comparison section.

    Asks for an evaluation window, runs ``evaluate_models`` on the
    ``accident_count`` column once the form is submitted, then shows the
    metric table, the model with the lowest RMSE and a CSV download. A
    ``ValueError`` raised along the way is displayed as a warning.
    """
    st.subheader("模型预测效果对比")

    with st.form(key="model_eval_form"):
        horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
        submitted = st.form_submit_button("应用评估参数")

    if submitted:
        window_days = int(horizon_sel)
        try:
            metrics_table = evaluate_models(base['accident_count'], horizon=window_days)
            st.dataframe(metrics_table, use_container_width=True)
            winner = metrics_table['RMSE'].idxmin()
            st.success(f"过去 {window_days} 天中RMSE 最低的模型是:**{winner}**")
            csv_payload = metrics_table.to_csv().encode('utf-8-sig')
            st.download_button(
                "下载评估结果 CSV",
                data=csv_payload,
                file_name="model_evaluation.csv",
                mime="text/csv",
            )
        except ValueError as err:
            st.warning(str(err))
    else:
        st.info("请设置评估窗口并点击“应用评估参数”按钮。")

33
ui_sections/overview.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import json
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
def render_overview(base: pd.DataFrame, region_sel: str, start_dt: pd.Timestamp, end_dt: pd.Timestamp,
                    strat_filter: list[str]):
    """Render the overview section.

    Draws the accident-count line chart, shows the filtered table and offers
    three downloads: the chart as HTML, the current view as CSV and the run
    parameters as JSON.
    """
    count_trace = go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines')
    fig_line = go.Figure(data=[count_trace])
    fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
    st.plotly_chart(fig_line, use_container_width=True)

    chart_html = fig_line.to_html(full_html=True, include_plotlyjs='cdn')
    st.download_button(
        "下载图表 HTML",
        data=chart_html.encode('utf-8'),
        file_name="overview_series.html",
        mime="text/html",
    )

    st.dataframe(base, use_container_width=True)
    st.download_button(
        "下载当前视图 CSV",
        data=base.to_csv(index=True).encode('utf-8-sig'),
        file_name="filtered_view.csv",
        mime="text/csv",
    )

    # The index bounds are only meaningful when the view has rows.
    row_count = len(base)
    if row_count:
        first_day = str(base.index.min().date())
        last_day = str(base.index.max().date())
    else:
        first_day = None
        last_day = None
    meta = {
        "region": region_sel,
        "date_range": [str(start_dt.date()), str(end_dt.date())],
        "strategy_filter": strat_filter,
        "rows": int(row_count),
        "min_date": first_day,
        "max_date": last_day,
    }
    meta_bytes = json.dumps(meta, ensure_ascii=False, indent=2).encode('utf-8')
    st.download_button(
        "下载运行参数 JSON",
        data=meta_bytes,
        file_name="run_metadata.json",
        mime="application/json",
    )

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import os
import pandas as pd
import streamlit as st
from services.strategy import generate_output_and_recommendations
def render_strategy_eval(base: pd.DataFrame, all_strategy_types: list[str], region_sel: str):
    """Render the strategy evaluation section.

    Lists the detected strategy types, asks for an evaluation window and, once
    the form is submitted, runs ``generate_output_and_recommendations`` and
    shows the per-strategy metrics, the recommendation and download buttons
    for both.

    Args:
        base: Frame handed to ``generate_output_and_recommendations``.
        all_strategy_types: Strategy names to evaluate.
        region_sel: Region label used in the recommendation text.
    """
    st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
    if not all_strategy_types:
        st.warning("数据中没有检测到策略。")
        return

    with st.form(key="strategy_eval_form"):
        horizon_eval = st.slider("评估窗口(天)", 7, 60, 14, step=1)
        submit_strat_eval = st.form_submit_button("应用评估参数")
    if not submit_strat_eval:
        st.info("请设置评估窗口并点击“应用评估参数”按钮。")
        return

    results, recommendation = generate_output_and_recommendations(
        base,
        all_strategy_types,
        region=region_sel,
        horizon=horizon_eval,
    )
    if not results:
        st.warning("⚠️ 未能完成策略评估。请尝试缩短评估窗口或扩大日期范围。")
        return

    st.subheader("各策略指标")
    df_res = pd.DataFrame(results).T
    st.dataframe(df_res, use_container_width=True)
    st.success(f"⭐ 推荐:{recommendation}")
    st.download_button(
        "下载策略评估 CSV",
        data=df_res.to_csv().encode('utf-8-sig'),
        file_name="strategy_evaluation_results.csv",
        mime="text/csv",
    )
    # Serve the recommendation that was just computed rather than whatever
    # recommendation.txt happens to exist in the working directory.
    st.download_button(
        "下载推荐文本",
        data=recommendation.encode('utf-8'),
        file_name="recommendation.txt",
        mime="text/plain",
    )