Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d8eea8e3a9 | |||
| 5825cf81b7 | |||
| a5e3c4c1da | |||
| 69488904a0 | |||
| 00e766eaa7 | |||
| af4285e147 | |||
| c69419d816 | |||
| a9845d084e |
12
.gitignore
vendored
Normal file
12
.gitignore
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
sample/
|
||||
strategy_evaluation_results.csv
|
||||
run_metadata.json
|
||||
*.log
|
||||
simulation.html
|
||||
.DS_Store
|
||||
overview_series.html
|
||||
tmp/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
.venv/
|
||||
.streamlit/
|
||||
36
AGENTS.md
Normal file
36
AGENTS.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
- Source: `app.py` (Streamlit UI, data processing, forecasting, anomaly detection, evaluation).
|
||||
- Docs & outputs: `docs/`, `overview_series.html`, `strategy_evaluation_results.csv`.
|
||||
- Samples: `sample/` for example data only; avoid sensitive content.
|
||||
- Meta: `requirements.txt`, `readme.md`, `LICENSE`, `CHANGELOG.md`.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
- Create env: `python -m venv .venv && source .venv/bin/activate` (or follow conda steps in `readme.md`).
|
||||
- Install deps: `pip install -r requirements.txt`.
|
||||
- Run app: `streamlit run app.py` then open `http://localhost:8501`.
|
||||
- Export artifacts: charts save as HTML (Plotly); forecasts may be written to CSV as noted in `readme.md`.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
- Python ≥3.8; 4-space indentation; UTF-8.
|
||||
- Names: functions/variables `snake_case`; classes `PascalCase`; constants `UPPER_SNAKE_CASE`.
|
||||
- Files: keep scope focused; use descriptive output names (e.g., `arima_forecast.csv`).
|
||||
- Data handling: prefer pandas/NumPy vectorization; validate inputs; avoid global state except constants.
|
||||
|
||||
## Testing Guidelines
|
||||
- Framework: pytest (recommended). Place tests under `tests/`.
|
||||
- Naming: `test_<module>.py` and `test_<behavior>()`.
|
||||
- Run: `pytest -q`. Focus on `load_and_clean_data`, aggregation, model selection, and metrics.
|
||||
- Keep tests fast and deterministic; avoid large I/O. Use small DataFrame fixtures.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
- Messages: concise, present tense. Prefixes seen: `modify:`, `Add`, `Update`.
|
||||
- Include scope and reason: e.g., `modify: update requirements for statsmodels`.
|
||||
- PRs: clear description, linked issues, repro steps/screenshots for UI, and notes on any schema or output changes.
|
||||
|
||||
## Security & Configuration Tips
|
||||
- Do not commit real accident data or secrets. Use `sample/` for examples.
|
||||
- Optional envs: `LOG_LEVEL=DEBUG`. Keep any API keys in environment variables, not in code.
|
||||
- Validate Excel column names before processing; handle missing columns/rows defensively.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
## [1.1.0] - 2025-08-28
|
||||
|
||||
### Added
|
||||
- Integrated GPT-based analysis for comprehensive traffic safety insights
|
||||
- Integrated AI-based analysis for comprehensive traffic safety insights
|
||||
- Added automated report generation with AI-powered recommendations
|
||||
- Implemented natural language query processing for data exploration
|
||||
- Added export functionality for analysis reports (PDF/CSV formats)
|
||||
@@ -22,7 +22,7 @@
|
||||
- Addressed memory leaks in large dataset processing
|
||||
|
||||
### Documentation
|
||||
- Updated README with new GPT analysis features and usage examples
|
||||
- Updated README with new AI analysis features and usage examples
|
||||
- Added API documentation for extended functionality
|
||||
- Included sample datasets and tutorial guides
|
||||
|
||||
|
||||
765
app.py
765
app.py
@@ -1,23 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Optional
|
||||
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.svm import SVR
|
||||
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
|
||||
import streamlit as st
|
||||
import plotly.graph_objects as go
|
||||
import plotly.express as px
|
||||
|
||||
# --- Optional deps (graceful fallback)
|
||||
try:
|
||||
@@ -40,179 +34,41 @@ except Exception:
|
||||
HAS_OPENAI = False
|
||||
|
||||
|
||||
# =======================
|
||||
# 1. Data Integration
|
||||
# =======================
|
||||
@st.cache_data(show_spinner=False)
|
||||
def load_and_clean_data(accident_file, strategy_file):
|
||||
accident_df = pd.read_excel(accident_file, sheet_name=None)
|
||||
accident_data = pd.concat(accident_df.values(), ignore_index=True)
|
||||
|
||||
accident_data['事故时间'] = pd.to_datetime(accident_data['事故时间'])
|
||||
accident_data = accident_data.dropna(subset=['事故时间', '所在街道', '事故类型'])
|
||||
|
||||
strategy_df = pd.read_excel(strategy_file)
|
||||
strategy_df['发布时间'] = pd.to_datetime(strategy_df['发布时间'])
|
||||
strategy_df = strategy_df.dropna(subset=['发布时间', '交通策略类型'])
|
||||
|
||||
severity_map = {'财损': 1, '伤人': 2, '亡人': 4}
|
||||
accident_data['severity'] = accident_data['事故类型'].map(severity_map).fillna(1)
|
||||
|
||||
accident_data = accident_data[['事故时间', '所在街道', '事故类型', 'severity']] \
|
||||
.rename(columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'})
|
||||
strategy_df = strategy_df[['发布时间', '交通策略类型']] \
|
||||
.rename(columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'})
|
||||
|
||||
return accident_data, strategy_df
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
|
||||
# City-level aggregation
|
||||
accident_data = accident_data.copy()
|
||||
strategy_data = strategy_data.copy()
|
||||
|
||||
accident_data['date'] = accident_data['date_time'].dt.date
|
||||
daily_accidents = accident_data.groupby('date').agg(
|
||||
accident_count=('date_time', 'count'),
|
||||
severity=('severity', 'sum')
|
||||
from services.io import (
|
||||
load_and_clean_data,
|
||||
aggregate_daily_data,
|
||||
aggregate_daily_data_by_region,
|
||||
load_accident_records,
|
||||
)
|
||||
daily_accidents.index = pd.to_datetime(daily_accidents.index)
|
||||
|
||||
strategy_data['date'] = strategy_data['date_time'].dt.date
|
||||
daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list)
|
||||
daily_strategies.index = pd.to_datetime(daily_strategies.index)
|
||||
|
||||
combined = daily_accidents.join(daily_strategies, how='left')
|
||||
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
|
||||
combined = combined.asfreq('D')
|
||||
combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0)
|
||||
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
|
||||
return combined
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""区域维度聚合。策略按天广播到所有区域(若策略本身无区域字段)。"""
|
||||
df = accident_data.copy()
|
||||
df['date'] = df['date_time'].dt.date
|
||||
g = df.groupby(['region', 'date']).agg(
|
||||
accident_count=('date_time', 'count'),
|
||||
severity=('severity', 'sum')
|
||||
from services.forecast import (
|
||||
arima_forecast_with_grid_search,
|
||||
knn_forecast_counterfactual,
|
||||
fit_and_extrapolate,
|
||||
)
|
||||
g.index = g.index.set_levels([g.index.levels[0], pd.to_datetime(g.index.levels[1])])
|
||||
g = g.sort_index()
|
||||
from services.strategy import (
|
||||
evaluate_strategy_effectiveness,
|
||||
generate_output_and_recommendations,
|
||||
)
|
||||
from services.metrics import evaluate_models
|
||||
|
||||
# 策略(每日列表)
|
||||
s = strategy_data.copy()
|
||||
s['date'] = s['date_time'].dt.date
|
||||
daily_strategies = s.groupby('date')['strategy_type'].apply(list)
|
||||
daily_strategies.index = pd.to_datetime(daily_strategies.index)
|
||||
|
||||
# 广播
|
||||
regions = g.index.get_level_values(0).unique()
|
||||
dates = pd.date_range(g.index.get_level_values(1).min(), g.index.get_level_values(1).max(), freq='D')
|
||||
full_index = pd.MultiIndex.from_product([regions, dates], names=['region', 'date'])
|
||||
g = g.reindex(full_index).fillna(0)
|
||||
|
||||
strat_map = daily_strategies.to_dict()
|
||||
g = g.assign(strategy_type=[strat_map.get(d, []) for d in g.index.get_level_values('date')])
|
||||
return g
|
||||
|
||||
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
from statsmodels.tools.sm_exceptions import ValueWarning
|
||||
import warnings
|
||||
|
||||
def evaluate_arima_model(series, arima_order):
|
||||
"""Fit ARIMA model and return AIC for evaluation."""
|
||||
try:
|
||||
model = ARIMA(series, order=arima_order)
|
||||
model_fit = model.fit()
|
||||
return model_fit.aic
|
||||
except Exception:
|
||||
return float("inf")
|
||||
|
||||
def arima_forecast_with_grid_search(accident_series: pd.Series, start_date: pd.Timestamp,
|
||||
horizon: int = 30, p_values: list = range(0, 6),
|
||||
d_values: list = range(0, 2), q_values: list = range(0, 6)) -> pd.DataFrame:
|
||||
# Pre-process series
|
||||
series = accident_series.asfreq('D').fillna(0)
|
||||
start_date = pd.to_datetime(start_date)
|
||||
|
||||
# Suppress warnings
|
||||
warnings.filterwarnings("ignore", category=ValueWarning)
|
||||
|
||||
# Define the hyperparameters to search through
|
||||
best_score, best_cfg = float("inf"), None
|
||||
|
||||
for p in p_values:
|
||||
for d in d_values:
|
||||
for q in q_values:
|
||||
order = (p, d, q)
|
||||
try:
|
||||
aic = evaluate_arima_model(series, order)
|
||||
if aic < best_score:
|
||||
best_score, best_cfg = aic, order
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
# Fit the model with the best found order
|
||||
print(best_cfg)
|
||||
model = ARIMA(series, order=best_cfg)
|
||||
fit = model.fit()
|
||||
|
||||
# Forecasting
|
||||
forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D')
|
||||
res = fit.get_forecast(steps=horizon)
|
||||
df = res.summary_frame()
|
||||
df.index = forecast_index
|
||||
df.index.name = 'date'
|
||||
df.rename(columns={'mean': 'forecast'}, inplace=True)
|
||||
|
||||
return df
|
||||
|
||||
# Example usage:
|
||||
# dataframe = your_data_frame_here
|
||||
# forecast_df = arima_forecast_with_grid_search(dataframe['accident_count'], start_date=pd.Timestamp('YYYY-MM-DD'), horizon=30)
|
||||
|
||||
|
||||
def knn_forecast_counterfactual(accident_series: pd.Series,
|
||||
intervention_date: pd.Timestamp,
|
||||
lookback: int = 14,
|
||||
horizon: int = 30):
|
||||
series = accident_series.asfreq('D').fillna(0)
|
||||
intervention_date = pd.to_datetime(intervention_date).normalize()
|
||||
|
||||
df = pd.DataFrame({'y': series})
|
||||
for i in range(1, lookback + 1):
|
||||
df[f'lag_{i}'] = df['y'].shift(i)
|
||||
|
||||
train = df.loc[:intervention_date - pd.Timedelta(days=1)].dropna()
|
||||
if len(train) < 5:
|
||||
return None, None
|
||||
X_train = train.filter(like='lag_').values
|
||||
y_train = train['y'].values
|
||||
knn = KNeighborsRegressor(n_neighbors=5)
|
||||
knn.fit(X_train, y_train)
|
||||
|
||||
history = df.loc[:intervention_date - pd.Timedelta(days=1), 'y'].tolist()
|
||||
preds = []
|
||||
for _ in range(horizon):
|
||||
if len(history) < lookback:
|
||||
return None, None
|
||||
x = np.array(history[-lookback:][::-1]).reshape(1, -1)
|
||||
pred = knn.predict(x)[0]
|
||||
preds.append(pred)
|
||||
history.append(pred)
|
||||
|
||||
pred_index = pd.date_range(intervention_date, periods=horizon, freq='D')
|
||||
return pd.Series(preds, index=pred_index, name='knn_pred'), None
|
||||
|
||||
from ui_sections import (
|
||||
render_overview,
|
||||
render_forecast,
|
||||
render_model_eval,
|
||||
render_strategy_eval,
|
||||
render_hotspot,
|
||||
)
|
||||
except Exception: # pragma: no cover - fallback to inline logic
|
||||
render_overview = None
|
||||
render_forecast = None
|
||||
render_model_eval = None
|
||||
render_strategy_eval = None
|
||||
render_hotspot = None
|
||||
|
||||
def detect_anomalies(series: pd.Series, contamination: float = 0.1):
|
||||
series = series.asfreq('D').fillna(0)
|
||||
iso = IsolationForest(contamination=contamination, random_state=42)
|
||||
iso = IsolationForest(n_estimators=50, contamination=contamination, random_state=42, n_jobs=-1)
|
||||
yhat = iso.fit_predict(series.values.reshape(-1, 1))
|
||||
anomaly_mask = (yhat == -1)
|
||||
anomaly_indices = series.index[anomaly_mask]
|
||||
@@ -250,130 +106,11 @@ def intervention_model(series: pd.Series,
|
||||
return Y_t, Z_t
|
||||
|
||||
|
||||
def fit_and_extrapolate(series: pd.Series,
|
||||
intervention_date: pd.Timestamp,
|
||||
days: int = 30):
|
||||
|
||||
series = series.asfreq('D').fillna(0)
|
||||
# 统一为无时区、按天的时间戳
|
||||
series.index = pd.to_datetime(series.index).tz_localize(None).normalize()
|
||||
intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize()
|
||||
|
||||
pre = series.loc[:intervention_date - pd.Timedelta(days=1)]
|
||||
if len(pre) < 5:
|
||||
return None, None, None
|
||||
|
||||
x_pre = np.arange(len(pre))
|
||||
x_future = np.arange(len(pre), len(pre) + days)
|
||||
|
||||
# 1️⃣ GLM:加入二次项
|
||||
X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2]))
|
||||
glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson())
|
||||
glm_res = glm.fit()
|
||||
X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2]))
|
||||
glm_pred = glm_res.predict(X_future_glm)
|
||||
|
||||
# SVR
|
||||
# 2️⃣ SVR:加标准化 & 调参 / 改线性核
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.pipeline import make_pipeline
|
||||
|
||||
svr = make_pipeline(
|
||||
StandardScaler(),
|
||||
SVR(kernel='rbf', C=10, gamma=0.1) # 或 kernel='linear'
|
||||
)
|
||||
svr.fit(x_pre.reshape(-1, 1), pre.values)
|
||||
svr_pred = svr.predict(x_future.reshape(-1, 1))
|
||||
|
||||
# 目标预测索引(未来可能超出历史范围 —— 用 reindex,不要 .loc[...])
|
||||
post_index = pd.date_range(intervention_date, periods=days, freq='D')
|
||||
|
||||
glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred')
|
||||
svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred')
|
||||
|
||||
# ✅ 关键修复:对不存在的日期补 NaN,而不是 .loc[post_index]
|
||||
post = series.reindex(post_index)
|
||||
|
||||
residuals = pd.Series(post.values - svr_pred[:len(post)],
|
||||
index=post_index, name='residual')
|
||||
|
||||
return glm_pred, svr_pred, residuals
|
||||
|
||||
|
||||
def evaluate_strategy_effectiveness(actual_series: pd.Series,
|
||||
counterfactual_series: pd.Series,
|
||||
severity_series: pd.Series,
|
||||
strategy_date: pd.Timestamp,
|
||||
window: int = 30):
|
||||
strategy_date = pd.to_datetime(strategy_date)
|
||||
pre_sev = severity_series.loc[strategy_date - pd.Timedelta(days=window):strategy_date - pd.Timedelta(days=1)].sum()
|
||||
post_sev = severity_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)].sum()
|
||||
actual_post = actual_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)]
|
||||
counter_post = counterfactual_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)]
|
||||
counter_post = counter_post.reindex(actual_post.index)
|
||||
effective_days = (actual_post < counter_post).sum()
|
||||
count_effective = effective_days >= (window / 2)
|
||||
severity_effective = post_sev < pre_sev
|
||||
cf_sum = counter_post.sum()
|
||||
F1 = (cf_sum - actual_post.sum()) / cf_sum if cf_sum > 0 else 0.0
|
||||
F2 = (pre_sev - post_sev) / pre_sev if pre_sev > 0 else 0.0
|
||||
if F1 > 0.5 and F2 > 0.5:
|
||||
safety_state = '一级'
|
||||
elif F1 > 0.3:
|
||||
safety_state = '二级'
|
||||
else:
|
||||
safety_state = '三级'
|
||||
return count_effective, severity_effective, (F1, F2), safety_state
|
||||
|
||||
|
||||
def generate_output_and_recommendations(combined_data: pd.DataFrame,
|
||||
strategy_types: list,
|
||||
region: str = '全市',
|
||||
horizon: int = 30):
|
||||
results = {}
|
||||
accident_series = combined_data['accident_count']
|
||||
severity_series = combined_data['severity']
|
||||
for strategy in strategy_types:
|
||||
has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x)
|
||||
if not has_strategy.any():
|
||||
continue
|
||||
intervention_date = has_strategy[has_strategy].index[0]
|
||||
glm_pred, svr_pred, residuals = fit_and_extrapolate(accident_series, intervention_date, days=horizon)
|
||||
if svr_pred is None:
|
||||
continue
|
||||
count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness(
|
||||
actual_series=accident_series,
|
||||
counterfactual_series=svr_pred,
|
||||
severity_series=severity_series,
|
||||
strategy_date=intervention_date,
|
||||
window=horizon
|
||||
)
|
||||
results[strategy] = {
|
||||
'effect_strength': float(residuals.mean()),
|
||||
'adaptability': float(F1 + F2),
|
||||
'count_effective': bool(count_eff),
|
||||
'severity_effective': bool(sev_eff),
|
||||
'safety_state': state,
|
||||
'F1': float(F1),
|
||||
'F2': float(F2),
|
||||
'intervention_date': str(intervention_date.date())
|
||||
}
|
||||
best_strategy = max(results, key=lambda x: results[x]['adaptability']) if results else None
|
||||
recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}" if best_strategy else "无足够数据推荐策略"
|
||||
pd.DataFrame(results).T.to_csv('strategy_evaluation_results.csv', encoding='utf-8-sig')
|
||||
with open('recommendation.txt', 'w', encoding='utf-8') as f:
|
||||
f.write(recommendation)
|
||||
return results, recommendation
|
||||
|
||||
|
||||
# =======================
|
||||
# 3. UI Helpers
|
||||
# =======================
|
||||
def hash_like(obj: str) -> str:
|
||||
return hashlib.md5(obj.encode('utf-8')).hexdigest()[:8]
|
||||
|
||||
|
||||
def compute_kpis(df_city: pd.DataFrame, arima_df: pd.DataFrame | None,
|
||||
def compute_kpis(df_city: pd.DataFrame, arima_df: Optional[pd.DataFrame],
|
||||
today: pd.Timestamp, window:int=30):
|
||||
# 今日/昨日
|
||||
today_date = pd.to_datetime(today.date())
|
||||
@@ -451,113 +188,70 @@ def save_fig_as_html(fig, filename):
|
||||
f.write(html)
|
||||
return filename
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
|
||||
# 依赖:已在脚本前面定义的 knn_forecast_counterfactual() 和 fit_and_extrapolate()
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
|
||||
# 依赖:knn_forecast_counterfactual、fit_and_extrapolate 已存在
|
||||
|
||||
def evaluate_models(series: pd.Series,
|
||||
horizon: int = 30,
|
||||
lookback: int = 14,
|
||||
p_values: range = range(0, 6),
|
||||
d_values: range = range(0, 2),
|
||||
q_values: range = range(0, 6)) -> pd.DataFrame:
|
||||
"""
|
||||
留出法(最后 horizon 天作为验证集)比较 ARIMA / KNN / GLM / SVR,
|
||||
输出 MAE・RMSE・MAPE,并按 RMSE 升序排序。
|
||||
"""
|
||||
# 统一日频 & 缺失补零
|
||||
series = series.asfreq('D').fillna(0)
|
||||
if len(series) <= horizon + 10:
|
||||
raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon)
|
||||
|
||||
train, test = series.iloc[:-horizon], series.iloc[-horizon:]
|
||||
|
||||
def _to_series_like(pred, a_index):
|
||||
# 将任意预测对齐成与 actual 同索引的 Series
|
||||
if isinstance(pred, pd.Series):
|
||||
return pred.reindex(a_index)
|
||||
return pd.Series(pred, index=a_index)
|
||||
|
||||
def _metrics(a: pd.Series, p) -> dict:
|
||||
p = _to_series_like(p, a.index).astype(float)
|
||||
a = a.astype(float)
|
||||
|
||||
mae = mean_absolute_error(a, p)
|
||||
|
||||
# 兼容旧版 sklearn:没有 squared 参数时手动开方
|
||||
try:
|
||||
rmse = mean_squared_error(a, p, squared=False)
|
||||
except TypeError:
|
||||
rmse = mean_squared_error(a, p) ** 0.5
|
||||
|
||||
# 忽略分母为 0 的样本
|
||||
mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100
|
||||
return {"MAE": mae, "RMSE": rmse, "MAPE": mape}
|
||||
|
||||
results = {}
|
||||
|
||||
# ---------- ARIMA ----------
|
||||
best_aic, best_order = float('inf'), None
|
||||
for p in p_values:
|
||||
for d in d_values:
|
||||
for q in q_values:
|
||||
try:
|
||||
aic = ARIMA(train, order=(p, d, q)).fit().aic
|
||||
if aic < best_aic:
|
||||
best_aic, best_order = aic, (p, d, q)
|
||||
except Exception:
|
||||
continue
|
||||
arima_train = train.asfreq('D').fillna(0)
|
||||
arima_pred = ARIMA(arima_train, order=best_order).fit().forecast(steps=horizon)
|
||||
results['ARIMA'] = _metrics(test, arima_pred)
|
||||
|
||||
# ---------- KNN ----------
|
||||
try:
|
||||
knn_pred, _ = knn_forecast_counterfactual(series,
|
||||
train.index[-1] + pd.Timedelta(days=1),
|
||||
lookback=lookback,
|
||||
horizon=horizon)
|
||||
if knn_pred is not None:
|
||||
results['KNN'] = _metrics(test, knn_pred)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ---------- GLM & SVR ----------
|
||||
try:
|
||||
glm_pred, svr_pred, _ = fit_and_extrapolate(series,
|
||||
train.index[-1] + pd.Timedelta(days=1),
|
||||
days=horizon)
|
||||
if glm_pred is not None:
|
||||
results['GLM'] = _metrics(test, glm_pred)
|
||||
if svr_pred is not None:
|
||||
results['SVR'] = _metrics(test, svr_pred)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (pd.DataFrame(results)
|
||||
.T.sort_values('RMSE')
|
||||
.round(3))
|
||||
# =======================
|
||||
# 4. App
|
||||
# =======================
|
||||
|
||||
|
||||
# =======================
|
||||
# 4. App
|
||||
# =======================
|
||||
def run_streamlit_app():
|
||||
# Must be the first Streamlit command
|
||||
st.set_page_config(page_title="Traffic Safety Analysis", layout="wide")
|
||||
st.title("🚦 Traffic Safety Intervention Analysis System")
|
||||
|
||||
# Sidebar — Upload & Global Filters & Auto Refresh
|
||||
st.sidebar.header("数据与筛选")
|
||||
|
||||
default_min_date = pd.to_datetime('2022-01-01').date()
|
||||
default_max_date = pd.to_datetime('2022-12-31').date()
|
||||
|
||||
def clamp_date_range(requested, minimum, maximum):
|
||||
"""Ensure the requested tuple stays within [minimum, maximum]."""
|
||||
if not isinstance(requested, (list, tuple)):
|
||||
requested = (requested, requested)
|
||||
start, end = requested
|
||||
if start > end:
|
||||
start, end = end, start
|
||||
if end < minimum or start > maximum:
|
||||
return minimum, maximum
|
||||
start = max(minimum, start)
|
||||
end = min(maximum, end)
|
||||
return start, end
|
||||
|
||||
# Initialize session state to store processed data (before rendering controls)
|
||||
if 'processed_data' not in st.session_state:
|
||||
st.session_state['processed_data'] = {
|
||||
'combined_city': None,
|
||||
'combined_by_region': None,
|
||||
'accident_data': None,
|
||||
'accident_records': None,
|
||||
'strategy_data': None,
|
||||
'all_regions': ["全市"],
|
||||
'all_strategy_types': [],
|
||||
'min_date': default_min_date,
|
||||
'max_date': default_max_date,
|
||||
'region_sel': "全市",
|
||||
'date_range': (default_min_date, default_max_date),
|
||||
'strat_filter': [],
|
||||
'accident_source_name': None,
|
||||
}
|
||||
|
||||
sidebar_state = st.session_state['processed_data']
|
||||
|
||||
available_regions = sidebar_state['all_regions'] if sidebar_state['all_regions'] else ["全市"]
|
||||
current_region = sidebar_state['region_sel'] if sidebar_state['region_sel'] in available_regions else available_regions[0]
|
||||
available_strategies = sidebar_state['all_strategy_types']
|
||||
current_strategies = [s for s in sidebar_state['strat_filter'] if s in available_strategies]
|
||||
|
||||
min_date = sidebar_state['min_date']
|
||||
max_date = sidebar_state['max_date']
|
||||
raw_start, raw_end = sidebar_state['date_range']
|
||||
start_default = max(min_date, min(raw_start, max_date))
|
||||
end_default = max(start_default, min(raw_end, max_date))
|
||||
|
||||
# Create a form for data inputs to batch updates
|
||||
with st.sidebar.form(key="data_input_form"):
|
||||
accident_file = st.file_uploader("上传事故数据 (Excel)", type=['xlsx'])
|
||||
@@ -566,13 +260,24 @@ def run_streamlit_app():
|
||||
# Global filters
|
||||
st.markdown("---")
|
||||
st.subheader("全局筛选器")
|
||||
# Placeholder for region selection (will be populated after data is loaded)
|
||||
region_sel = st.selectbox("区域", options=["全市"], index=0, key="region_select")
|
||||
# Default date range (will be updated after data is loaded)
|
||||
min_date = pd.to_datetime('2022-01-01').date()
|
||||
max_date = pd.to_datetime('2022-12-31').date()
|
||||
date_range = st.date_input("时间范围", value=(min_date, max_date), min_value=min_date, max_value=max_date)
|
||||
strat_filter = st.multiselect("策略类型(过滤)", options=[], help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期")
|
||||
region_sel = st.selectbox(
|
||||
"区域",
|
||||
options=available_regions,
|
||||
index=available_regions.index(current_region),
|
||||
key="region_select",
|
||||
)
|
||||
date_range = st.date_input(
|
||||
"时间范围",
|
||||
value=(start_default, end_default),
|
||||
min_value=min_date,
|
||||
max_value=max_date,
|
||||
)
|
||||
strat_filter = st.multiselect(
|
||||
"策略类型(过滤)",
|
||||
options=available_strategies,
|
||||
default=current_strategies,
|
||||
help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期",
|
||||
)
|
||||
|
||||
# Apply button for data loading and filtering
|
||||
apply_button = st.form_submit_button("应用数据与筛选")
|
||||
@@ -589,30 +294,15 @@ def run_streamlit_app():
|
||||
|
||||
# Add OpenAI API key input in sidebar
|
||||
st.sidebar.markdown("---")
|
||||
st.sidebar.subheader("GPT API 配置")
|
||||
openai_api_key = st.sidebar.text_input("GPT API Key", value='sk-dQhKOOG48iVEfgJfAb14458dA4474fB09aBbE8153d4aB3Fc', type="password", help="用于GPT分析结果的API密钥")
|
||||
open_ai_base_url = st.sidebar.text_input("GPT Base Url", value='https://az.gptplus5.com/v1', type='default')
|
||||
|
||||
# Initialize session state to store processed data
|
||||
if 'processed_data' not in st.session_state:
|
||||
st.session_state['processed_data'] = {
|
||||
'combined_city': None,
|
||||
'combined_by_region': None,
|
||||
'accident_data': None,
|
||||
'strategy_data': None,
|
||||
'all_regions': ["全市"],
|
||||
'all_strategy_types': [],
|
||||
'min_date': min_date,
|
||||
'max_date': max_date,
|
||||
'region_sel': "全市",
|
||||
'date_range': (min_date, max_date),
|
||||
'strat_filter': []
|
||||
}
|
||||
st.sidebar.subheader("AI API 配置")
|
||||
openai_api_key = st.sidebar.text_input("AI API Key", value='sk-sXY934yPqjh7YKKC08380b198fEb47308cDa09BeE23d9c8a', type="password", help="用于 AI 分析结果的 API 密钥")
|
||||
open_ai_base_url = st.sidebar.text_input("AI Base Url", value='https://aihubmix.com/v1', type='default')
|
||||
|
||||
# Process data only when Apply button is clicked
|
||||
if apply_button and accident_file and strategy_file:
|
||||
with st.spinner("数据载入中…"):
|
||||
# Load and clean data
|
||||
accident_records = load_accident_records(accident_file, require_location=True)
|
||||
accident_data, strategy_data = load_and_clean_data(accident_file, strategy_file)
|
||||
combined_city = aggregate_daily_data(accident_data, strategy_data)
|
||||
combined_by_region = aggregate_daily_data_by_region(accident_data, strategy_data)
|
||||
@@ -624,24 +314,35 @@ def run_streamlit_app():
|
||||
max_date = combined_city.index.max().date()
|
||||
|
||||
# Store processed data in session state
|
||||
sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date)
|
||||
st.session_state['processed_data'].update({
|
||||
'combined_city': combined_city,
|
||||
'combined_by_region': combined_by_region,
|
||||
'accident_data': accident_data,
|
||||
'accident_records': accident_records,
|
||||
'strategy_data': strategy_data,
|
||||
'all_regions': all_regions,
|
||||
'all_strategy_types': all_strategy_types,
|
||||
'min_date': min_date,
|
||||
'max_date': max_date,
|
||||
'region_sel': region_sel,
|
||||
'date_range': date_range,
|
||||
'strat_filter': strat_filter
|
||||
'date_range': (sanitized_start, sanitized_end),
|
||||
'strat_filter': strat_filter,
|
||||
'accident_source_name': getattr(accident_file, "name", "事故数据.xlsx"),
|
||||
})
|
||||
|
||||
sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date)
|
||||
|
||||
# Persist the latest sidebar selections for display and downstream filtering
|
||||
st.session_state['processed_data']['region_sel'] = region_sel
|
||||
st.session_state['processed_data']['date_range'] = (sanitized_start, sanitized_end)
|
||||
st.session_state['processed_data']['strat_filter'] = strat_filter
|
||||
|
||||
# Retrieve data from session state
|
||||
combined_city = st.session_state['processed_data']['combined_city']
|
||||
combined_by_region = st.session_state['processed_data']['combined_by_region']
|
||||
accident_data = st.session_state['processed_data']['accident_data']
|
||||
accident_records = st.session_state['processed_data']['accident_records']
|
||||
strategy_data = st.session_state['processed_data']['strategy_data']
|
||||
all_regions = st.session_state['processed_data']['all_regions']
|
||||
all_strategy_types = st.session_state['processed_data']['all_strategy_types']
|
||||
@@ -650,6 +351,7 @@ def run_streamlit_app():
|
||||
region_sel = st.session_state['processed_data']['region_sel']
|
||||
date_range = st.session_state['processed_data']['date_range']
|
||||
strat_filter = st.session_state['processed_data']['strat_filter']
|
||||
accident_source_name = st.session_state['processed_data']['accident_source_name']
|
||||
|
||||
# Update selectbox and multiselect options dynamically (outside the form for display)
|
||||
st.sidebar.markdown("---")
|
||||
@@ -700,44 +402,128 @@ def run_streamlit_app():
|
||||
with meta_col2:
|
||||
st.caption(f"🕒 最近刷新:{last_refresh.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Tabs (add new tab for GPT analysis)
|
||||
tab_dash, tab_pred, tab_eval, tab_anom, tab_strat, tab_comp, tab_sim, tab_gpt = st.tabs(
|
||||
["🏠 总览", "📈 预测模型", "📊 模型评估", "⚠️ 异常检测", "📝 策略评估", "⚖️ 策略对比", "🧪 情景模拟", "🔍 GPT 分析"]
|
||||
tab_labels = [
|
||||
"🏠 总览",
|
||||
"📍 事故热点",
|
||||
"🔍 AI 分析",
|
||||
"📈 预测模型",
|
||||
"📊 模型评估",
|
||||
"⚠️ 异常检测",
|
||||
"📝 策略评估",
|
||||
"⚖️ 策略对比",
|
||||
"🧪 情景模拟",
|
||||
]
|
||||
default_tab = st.session_state.get("active_tab", tab_labels[0])
|
||||
if default_tab not in tab_labels:
|
||||
default_tab = tab_labels[0]
|
||||
selected_tab = st.radio(
|
||||
"功能分区",
|
||||
tab_labels,
|
||||
index=tab_labels.index(default_tab),
|
||||
horizontal=True,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
st.session_state["active_tab"] = selected_tab
|
||||
|
||||
# --- Tab 1: 总览页
|
||||
with tab_dash:
|
||||
fig_line = go.Figure()
|
||||
fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines'))
|
||||
fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
|
||||
st.plotly_chart(fig_line, use_container_width=True)
|
||||
fname = save_fig_as_html(fig_line, "overview_series.html")
|
||||
st.download_button("下载图表 HTML", data=open(fname, 'rb').read(),
|
||||
file_name="overview_series.html", mime="text/html")
|
||||
|
||||
st.dataframe(base, use_container_width=True)
|
||||
csv_bytes = base.to_csv(index=True).encode('utf-8-sig')
|
||||
st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv")
|
||||
if selected_tab == "🏠 总览":
|
||||
if render_overview is not None:
|
||||
render_overview(base, region_sel, start_dt, end_dt, strat_filter)
|
||||
else:
|
||||
st.warning("概览模块未能加载,请检查 `ui_sections/overview.py`。")
|
||||
|
||||
meta = {
|
||||
"region": region_sel,
|
||||
"date_range": [str(start_dt.date()), str(end_dt.date())],
|
||||
"strategy_filter": strat_filter,
|
||||
"rows": int(len(base)),
|
||||
"min_date": str(base.index.min().date()) if len(base) else None,
|
||||
"max_date": str(base.index.max().date()) if len(base) else None
|
||||
elif selected_tab == "📍 事故热点":
|
||||
if render_hotspot is not None:
|
||||
render_hotspot(accident_records, accident_source_name)
|
||||
else:
|
||||
st.warning("事故热点模块未能加载,请检查 `ui_sections/hotspot.py`。")
|
||||
|
||||
elif selected_tab == "🔍 AI 分析":
|
||||
from openai import OpenAI
|
||||
st.subheader("AI 数据分析与改进建议")
|
||||
if not HAS_OPENAI:
|
||||
st.warning("未安装 `openai` 库。请安装后重试。")
|
||||
elif not openai_api_key:
|
||||
st.info("请在左侧边栏输入 OpenAI API Key 以启用 AI 分析。")
|
||||
else:
|
||||
if all_strategy_types:
|
||||
# Generate results if not already
|
||||
results, recommendation = generate_output_and_recommendations(base, all_strategy_types,
|
||||
region=region_sel if region_sel != '全市' else '全市')
|
||||
df_res = pd.DataFrame(results).T
|
||||
kpi_json = json.dumps(kpi, ensure_ascii=False, indent=2)
|
||||
results_json = df_res.to_json(orient="records", force_ascii=False)
|
||||
recommendation_text = recommendation
|
||||
|
||||
# Prepare data to send
|
||||
data_to_analyze = {
|
||||
"kpis": kpi_json,
|
||||
"strategy_results": results_json,
|
||||
"recommendation": recommendation_text
|
||||
}
|
||||
with open("run_metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, ensure_ascii=False, indent=2)
|
||||
st.download_button("下载运行参数 JSON", data=open("run_metadata.json", "rb").read(),
|
||||
file_name="run_metadata.json", mime="application/json")
|
||||
data_str = json.dumps(data_to_analyze, ensure_ascii=False)
|
||||
|
||||
# --- Tab 2: 预测模型
|
||||
with tab_pred:
|
||||
prompt = (
|
||||
"你是一名资深交通安全数据分析顾问。请基于以下结构化数据输出一份专业报告,需包含:\n"
|
||||
"1. 核心指标洞察:按要点总结事故趋势、显著波动及可能原因。\n"
|
||||
"2. 策略绩效评估:对比主要策略的优势、短板与适用场景。\n"
|
||||
"3. 优化建议:为短期(0-3个月)、中期(3-12个月)与长期(12个月以上)分别给出2-3条可操作措施。\n"
|
||||
"请保持正式语气,引用关键数值支撑结论,并用清晰的小节或列表呈现。\n"
|
||||
f"数据摘要:{data_str}\n"
|
||||
)
|
||||
if st.button("上传数据至 AI 并获取分析"):
|
||||
if not openai_api_key.strip():
|
||||
st.info("请提供有效的 AI API Key。")
|
||||
elif not open_ai_base_url.strip():
|
||||
st.info("请提供可访问的 AI Base Url。")
|
||||
else:
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=open_ai_base_url,
|
||||
# sk-xxx替换为自己的key
|
||||
api_key=openai_api_key
|
||||
)
|
||||
st.markdown("### AI 分析结果与改进思路")
|
||||
placeholder = st.empty()
|
||||
accumulated_response: list[str] = []
|
||||
with st.spinner("AI 正在生成专业报告,请稍候…"):
|
||||
stream = client.chat.completions.create(
|
||||
model="gpt-5-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a professional traffic safety analyst who writes concise, well-structured Chinese reports."
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
piece = getattr(delta, "content", None) if delta else None
|
||||
if piece:
|
||||
accumulated_response.append(piece)
|
||||
placeholder.markdown("".join(accumulated_response), unsafe_allow_html=True)
|
||||
final_text = "".join(accumulated_response)
|
||||
if not final_text:
|
||||
placeholder.info("AI 未返回可用内容,请稍后重试或检查凭据配置。")
|
||||
except Exception as e:
|
||||
st.error(f"调用 OpenAI API 失败:{str(e)}")
|
||||
else:
|
||||
st.warning("没有策略数据可供分析。")
|
||||
|
||||
# Update refresh time
|
||||
st.session_state['last_refresh'] = datetime.now()
|
||||
|
||||
elif selected_tab == "📈 预测模型":
|
||||
if render_forecast is not None:
|
||||
render_forecast(base)
|
||||
else:
|
||||
st.subheader("多模型预测比较")
|
||||
# 使用表单封装交互组件
|
||||
with st.form(key="predict_form"):
|
||||
default_date = base.index.max() - pd.Timedelta(days=60) if len(base) else pd.Timestamp('2022-01-01')
|
||||
# 缩短默认回溯窗口,提升首次渲染速度
|
||||
default_date = base.index.max() - pd.Timedelta(days=30) if len(base) else pd.Timestamp('2022-01-01')
|
||||
selected_date = st.date_input("选择干预日期 / 预测起点", value=default_date)
|
||||
horizon = st.number_input("预测天数", min_value=7, max_value=90, value=30, step=1)
|
||||
submit_predict = st.form_submit_button("应用预测参数")
|
||||
@@ -796,7 +582,10 @@ def run_streamlit_app():
|
||||
st.info("请设置预测参数并点击“应用预测参数”按钮。")
|
||||
|
||||
# --- Tab 3: 模型评估
|
||||
with tab_eval:
|
||||
elif selected_tab == "📊 模型评估":
|
||||
if render_model_eval is not None:
|
||||
render_model_eval(base)
|
||||
else:
|
||||
st.subheader("模型预测效果对比")
|
||||
with st.form(key="model_eval_form"):
|
||||
horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
|
||||
@@ -820,7 +609,7 @@ def run_streamlit_app():
|
||||
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
|
||||
|
||||
# --- Tab 4: 异常检测
|
||||
with tab_anom:
|
||||
elif selected_tab == "⚠️ 异常检测":
|
||||
anomalies, anomaly_fig = detect_anomalies(base['accident_count'])
|
||||
st.plotly_chart(anomaly_fig, use_container_width=True)
|
||||
st.write(f"检测到异常点:{len(anomalies)} 个")
|
||||
@@ -829,25 +618,14 @@ def run_streamlit_app():
|
||||
file_name="anomalies.csv", mime="text/csv")
|
||||
|
||||
# --- Tab 5: 策略评估
|
||||
with tab_strat:
|
||||
st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
|
||||
if all_strategy_types:
|
||||
results, recommendation = generate_output_and_recommendations(base, all_strategy_types,
|
||||
region=region_sel if region_sel!='全市' else '全市')
|
||||
st.subheader("各策略指标")
|
||||
df_res = pd.DataFrame(results).T
|
||||
st.dataframe(df_res, use_container_width=True)
|
||||
st.success(f"⭐ 推荐:{recommendation}")
|
||||
st.download_button("下载策略评估 CSV",
|
||||
data=df_res.to_csv().encode('utf-8-sig'),
|
||||
file_name="strategy_evaluation_results.csv", mime="text/csv")
|
||||
with open('recommendation.txt','r',encoding='utf-8') as f:
|
||||
st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt")
|
||||
elif selected_tab == "📝 策略评估":
|
||||
if render_strategy_eval is not None:
|
||||
render_strategy_eval(base, all_strategy_types, region_sel)
|
||||
else:
|
||||
st.warning("数据中没有检测到策略。")
|
||||
st.warning("策略评估模块不可用,请检查 `ui_sections/strategy_eval.py`。")
|
||||
|
||||
# --- Tab 6: 策略对比
|
||||
with tab_comp:
|
||||
elif selected_tab == "⚖️ 策略对比":
|
||||
def strategy_metrics(strategy):
|
||||
mask = base['strategy_type'].apply(lambda x: strategy in x)
|
||||
if not mask.any():
|
||||
@@ -915,7 +693,7 @@ def run_streamlit_app():
|
||||
st.warning("没有策略可供对比。")
|
||||
|
||||
# --- Tab 7: 情景模拟
|
||||
with tab_sim:
|
||||
elif selected_tab == "🧪 情景模拟":
|
||||
st.subheader("情景模拟")
|
||||
st.write("选择一个日期与策略,模拟“在该日期上线该策略”的影响:")
|
||||
with st.form(key="simulation_form"):
|
||||
@@ -951,65 +729,6 @@ def run_streamlit_app():
|
||||
else:
|
||||
st.info("请设置模拟参数并点击“应用模拟参数”按钮。")
|
||||
|
||||
# --- New Tab 8: GPT 分析
|
||||
with tab_gpt:
|
||||
from openai import OpenAI
|
||||
st.subheader("GPT 数据分析与改进建议")
|
||||
# open_ai_key = f"sk-dQhKOOG48iVEfgJfAb14458dA4474fB09aBbE8153d4aB3Fc"
|
||||
if not HAS_OPENAI:
|
||||
st.warning("未安装 `openai` 库。请安装后重试。")
|
||||
elif not openai_api_key:
|
||||
st.info("请在左侧边栏输入 OpenAI API Key 以启用 GPT 分析。")
|
||||
else:
|
||||
if all_strategy_types:
|
||||
# Generate results if not already
|
||||
results, recommendation = generate_output_and_recommendations(base, all_strategy_types,
|
||||
region=region_sel if region_sel != '全市' else '全市')
|
||||
df_res = pd.DataFrame(results).T
|
||||
kpi_json = json.dumps(kpi, ensure_ascii=False, indent=2)
|
||||
results_json = df_res.to_json(orient="records", force_ascii=False)
|
||||
recommendation_text = recommendation
|
||||
|
||||
# Prepare data to send
|
||||
data_to_analyze = {
|
||||
"kpis": kpi_json,
|
||||
"strategy_results": results_json,
|
||||
"recommendation": recommendation_text
|
||||
}
|
||||
data_str = json.dumps(data_to_analyze, ensure_ascii=False)
|
||||
|
||||
prompt = str(f"""
|
||||
请分析以下交通安全分析结果,包括KPI指标、策略评估结果和推荐。
|
||||
提供数据结果的详细分析,以及改进思路和建议。
|
||||
数据:{str(data_str)}
|
||||
""")
|
||||
#st.text_area(prompt)
|
||||
if st.button("上传数据至 GPT 并获取分析"):
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=open_ai_base_url,
|
||||
# sk-xxx替换为自己的key
|
||||
api_key=openai_api_key
|
||||
)
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant that analyzes traffic safety data."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
stream=False
|
||||
)
|
||||
gpt_response = response.choices[0].message.content
|
||||
st.markdown("### GPT 分析结果与改进思路")
|
||||
st.markdown(gpt_response, unsafe_allow_html=True)
|
||||
except Exception as e:
|
||||
st.error(f"调用 OpenAI API 失败:{str(e)}")
|
||||
else:
|
||||
st.warning("没有策略数据可供分析。")
|
||||
|
||||
# Update refresh time
|
||||
st.session_state['last_refresh'] = datetime.now()
|
||||
|
||||
else:
|
||||
st.info("请先在左侧上传事故数据与策略数据,并点击“应用数据与筛选”按钮。")
|
||||
|
||||
|
||||
21
config/settings.py
Normal file
21
config/settings.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Default configuration and feature flags
|
||||
|
||||
# Forecasting
|
||||
ARIMA_P = range(0, 4)
|
||||
ARIMA_D = range(0, 2)
|
||||
ARIMA_Q = range(0, 4)
|
||||
|
||||
DEFAULT_HORIZON_PREDICT = 30
|
||||
DEFAULT_HORIZON_EVAL = 14
|
||||
MIN_PRE_DAYS = 5
|
||||
MAX_PRE_DAYS = 120
|
||||
|
||||
# Anomaly detection
|
||||
ANOMALY_N_ESTIMATORS = 50
|
||||
ANOMALY_CONTAMINATION = 0.10
|
||||
|
||||
# Performance flags
|
||||
FAST_MODE = True
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Installation Guide
|
||||
|
||||
This document explains how to set up TrafficSafeAnalyzer for local development and exploration. The application runs on Streamlit and officially supports Python 3.8.
|
||||
|
||||
## Prerequisites
|
||||
- Python 3.8 (3.9+ is not yet validated; use 3.8 to avoid dependency issues)
|
||||
- Git
|
||||
- `pip` (bundled with Python)
|
||||
- Optional: Conda (for environment management) or Docker (for container-based runs)
|
||||
|
||||
## 1. Obtain the source code
|
||||
|
||||
```bash
|
||||
git clone https://github.com/tongnian0613/TrafficSafeAnalyzer.git
|
||||
cd TrafficSafeAnalyzer
|
||||
```
|
||||
|
||||
If you already have the repository, pull the latest changes instead:
|
||||
|
||||
```bash
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
## 2. Create a dedicated environment
|
||||
|
||||
### Option A: Built-in virtual environment
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
### Option B: Conda environment
|
||||
|
||||
```bash
|
||||
conda create -n trafficsa python=3.8 -y
|
||||
conda activate trafficsa
|
||||
```
|
||||
|
||||
## 3. Install project dependencies
|
||||
|
||||
Install the full dependency set listed in `requirements.txt`:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
If you prefer a minimal installation before pulling in extras, install the core stack first:
|
||||
|
||||
```bash
|
||||
pip install streamlit pandas numpy matplotlib plotly scikit-learn statsmodels scipy
|
||||
```
|
||||
|
||||
Then add optional packages as needed (Excel readers, auto-refresh, OpenAI integration):
|
||||
|
||||
```bash
|
||||
pip install streamlit-autorefresh openpyxl xlrd cryptography openai
|
||||
```
|
||||
|
||||
## 4. Verify the setup
|
||||
|
||||
1. Ensure the environment is still active (`which python` should point to `.venv` or the conda env).
|
||||
2. Launch the Streamlit app:
|
||||
|
||||
```bash
|
||||
streamlit run app.py
|
||||
```
|
||||
|
||||
3. Open `http://localhost:8501` in your browser. The home page should load without import errors.
|
||||
|
||||
## Troubleshooting tips
|
||||
|
||||
- **Missing package**: Re-run `pip install -r requirements.txt`.
|
||||
- **Python version mismatch**: Confirm `python --version` reports 3.8.x inside your environment.
|
||||
- **OpenSSL or cryptography errors** (macOS/Linux): Update the system OpenSSL libraries and reinstall `cryptography`.
|
||||
- **Taking too long to install**: if a dependency download stalls due to a firewall, retry using a mirror (`-i https://pypi.tuna.tsinghua.edu.cn/simple`) consistent with your environment policy.
|
||||
|
||||
After a successful launch, continue with the usage guide in `docs/usage.md` to load data and explore forecasts.
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
# Usage Guide
|
||||
|
||||
TrafficSafeAnalyzer delivers accident analytics and decision support through a Streamlit interface. This guide walks through the daily workflow, expected inputs, and where to find generated artefacts.
|
||||
|
||||
## Start the app
|
||||
|
||||
1. Activate your virtual or conda environment.
|
||||
2. From the project root, run:
|
||||
|
||||
```bash
|
||||
streamlit run app.py
|
||||
```
|
||||
|
||||
3. Open `http://localhost:8501`. Keep the terminal running while you work in the browser.
|
||||
|
||||
## Load input data
|
||||
|
||||
Use the sidebar form labelled “数据与筛选”.
|
||||
|
||||
- **Accident data (`.xlsx`)** — columns should include at minimum:
|
||||
- `事故时间` (timestamp)
|
||||
- `所在街道` (region or district)
|
||||
- `事故类型`
|
||||
- `事故数`/`accident_count` (if absent, the loader aggregates counts)
|
||||
- **Strategy data (`.xlsx`)** — include:
|
||||
- `发布时间`
|
||||
- `交通策略类型`
|
||||
- optional descriptors such as `策略名称`, `策略内容`
|
||||
- Select the global filters (region, date window, strategy filter) and click `应用数据与筛选`.
|
||||
- Uploaded files are cached. Upload a new file or press “Rerun” to refresh after making edits.
|
||||
- Sample datasets for rapid smoke testing live in `sample/事故/*.xlsx` (accidents) and `sample/交通策略/*.xlsx` (strategies); copy them before making modifications.
|
||||
|
||||
> Tip: `services/io.py` performs validation; rows missing key columns are dropped with a warning in the Streamlit log.
|
||||
|
||||
## Navigate the workspace
|
||||
|
||||
- **🏠 总览 (Overview)** — KPI cards, time-series plot, filtered table, and download buttons for HTML (`overview_series.html`), CSV (`filtered_view.csv`), and run metadata (`run_metadata.json`).
|
||||
- **📈 预测模型 (Forecast)** — choose an intervention date and horizon, compare ARIMA / KNN / GLM / SVR forecasts, and export `arima_forecast.csv`(提交后结果会在同一数据集下保留,便于调整其他控件)。
|
||||
- **📊 模型评估 (Model evaluation)** — run rolling-window backtests, inspect RMSE/MAE/MAPE, and download `model_evaluation.csv`.
|
||||
- **⚠️ 异常检测 (Anomaly detection)** — isolation forest marks outliers on the accident series; tweak contamination via the main page controls.
|
||||
- **📝 策略评估 (Strategy evaluation)** — Aggregates metrics per strategy type, recommends the best option, writes `strategy_evaluation_results.csv`, and updates `recommendation.txt`.
|
||||
- **⚖️ 策略对比 (Strategy comparison)** — side-by-side metrics for selected strategies, useful for “what worked best last month” reviews.
|
||||
- **🧪 情景模拟 (Scenario simulation)** — apply intervention models (persistent/decay, lagged effects) to test potential roll-outs.
|
||||
- **🔍 AI 分析** — 默认示例 API Key/Base URL 已预填,可直接体验;如需切换自有凭据,可在侧边栏更新后生成洞察(运行时读取,不会写入磁盘)。
|
||||
- **📍 事故热点 (Hotspot)** — reuse the already uploaded accident data to identify high-risk intersections and produce targeted mitigation ideas; no separate hotspot upload is required.
|
||||
|
||||
Each tab remembers the active filters from the sidebar so results stay consistent.
|
||||
|
||||
## Downloaded artefacts
|
||||
|
||||
Generated files are saved to the project root unless you override paths in the code:
|
||||
|
||||
- `overview_series.html`
|
||||
- `filtered_view.csv`
|
||||
- `run_metadata.json`
|
||||
- `arima_forecast.csv`
|
||||
- `model_evaluation.csv`
|
||||
- `strategy_evaluation_results.csv`
|
||||
- `recommendation.txt`
|
||||
|
||||
After a session, review and archive these outputs under `docs/` or a dated folder as needed.
|
||||
|
||||
## Operational tips
|
||||
|
||||
- **Auto refresh**: enable from the sidebar (requires `streamlit-autorefresh`). Set the interval in seconds for live dashboards.
|
||||
- **Logging**: set `LOG_LEVEL=DEBUG` before launch to see detailed diagnostics in the terminal and Streamlit log.
|
||||
- **Reset filters**: choose “全市” and the full date span, then re-run the sidebar form.
|
||||
- **Common warnings**:
|
||||
- *“数据中没有检测到策略”*: verify the strategy Excel file and column names.
|
||||
- *ARIMA failures*: shorten the horizon or ensure at least 10 historical data points before the intervention date.
|
||||
- *Hotspot data issues*: ensure the accident workbook includes `事故时间`, `所在街道`, `事故类型`, and `事故具体地点` so intersections can be resolved.
|
||||
|
||||
Need deeper integration or batch automation? Extract the core functions from `services/` and orchestrate them in a notebook or scheduled job.
|
||||
|
||||
26
environment.yml
Normal file
26
environment.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
name: trafficsa
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.12
|
||||
- pip
|
||||
- streamlit>=1.20
|
||||
- pandas>=1.3
|
||||
- numpy>=1.21
|
||||
- matplotlib>=3.4
|
||||
- plotly>=5
|
||||
- scikit-learn>=1.0
|
||||
- statsmodels>=0.13
|
||||
- scipy>=1.7
|
||||
- pyarrow>=7
|
||||
- python-dateutil>=2.8.2
|
||||
- pytz>=2021.3
|
||||
- openpyxl>=3.0.9
|
||||
- xlrd>=2.0.1
|
||||
- cryptography>=3.4.7
|
||||
- requests
|
||||
- pip:
|
||||
- streamlit-autorefresh>=0.1.5
|
||||
- openai>=2.0.0
|
||||
- jieba>=0.42.1
|
||||
File diff suppressed because one or more lines are too long
23
readme.md
23
readme.md
@@ -8,6 +8,8 @@
|
||||
- 使用 ARIMA、KNN、GLM、SVR 等模型预测事故趋势
|
||||
- 检测异常事故点
|
||||
- 评估交通策略效果并提供推荐
|
||||
- 识别事故热点路口并生成风险分级与整治建议
|
||||
- 支持 AI 分析生成自然语言洞察
|
||||
|
||||
## 安装步骤
|
||||
|
||||
@@ -29,7 +31,7 @@ cd TrafficSafeAnalyzer
|
||||
2. 创建虚拟环境(推荐):
|
||||
|
||||
```bash
|
||||
conda create -n trafficsa python=3.12 -y
|
||||
conda create -n trafficsa python=3.8 -y
|
||||
conda activate trafficsa
|
||||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
@@ -85,10 +87,20 @@ openai>=2.0.0
|
||||
|
||||
## 配置参数
|
||||
|
||||
- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel
|
||||
- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel;事故热点分析会直接复用事故数据,无需额外上传。
|
||||
- **环境变量**(可选):
|
||||
- `LOG_LEVEL=DEBUG`:启用详细日志
|
||||
- 示例:`export LOG_LEVEL=DEBUG`(Linux/macOS)或 `set LOG_LEVEL=DEBUG`(Windows)
|
||||
- **AI 分析凭据**:应用内已预填可用的示例 API Key 与 Base URL,可直接体验;如需使用自有服务,可在侧边栏替换后即时生效。
|
||||
|
||||
## 示例数据
|
||||
|
||||
`sample/` 目录提供了脱敏示例数据,便于快速体验:
|
||||
|
||||
- `sample/事故/*.xlsx`:按年份划分的事故记录
|
||||
- `sample/交通策略/*.xlsx`:策略发布记录
|
||||
|
||||
使用前建议复制到临时位置再进行编辑。
|
||||
|
||||
## 输入输出格式
|
||||
|
||||
@@ -118,8 +130,11 @@ streamlit run app.py
|
||||
**问题**:数据加载失败
|
||||
**解决**:确保 Excel 文件格式正确,检查列名是否匹配
|
||||
|
||||
**问题**:`NameError: name 'strategy_metrics' is not defined`
|
||||
**解决**:确保 `strategy_metrics` 函数定义在 `app.py` 中,且位于 `run_streamlit_app` 函数内
|
||||
**问题**:预测模型页面点击后图表未显示
|
||||
**解决**:确认干预日期之前至少有 10 条历史记录,或缩短预测天数重新提交
|
||||
|
||||
**问题**:热点分析提示“请上传事故数据”
|
||||
**解决**:侧边栏上传事故数据后点击“应用数据与筛选”,热点模块会复用相同数据集
|
||||
|
||||
## 日志分析
|
||||
|
||||
|
||||
@@ -27,5 +27,8 @@ cryptography>=3.4.7
|
||||
# OpenAI
|
||||
openai
|
||||
|
||||
# jieba for Chinese text segmentation
|
||||
jieba
|
||||
|
||||
# Note: hashlib and json are part of Python standard library
|
||||
# Note: os and datetime are part of Python standard library
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"region": "全市",
|
||||
"date_range": [
|
||||
"2022-01-01",
|
||||
"2022-12-31"
|
||||
],
|
||||
"strategy_filter": [],
|
||||
"rows": 365,
|
||||
"min_date": "2022-01-01",
|
||||
"max_date": "2022-12-31"
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
139
services/forecast.py
Normal file
139
services/forecast.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
from statsmodels.tools.sm_exceptions import ValueWarning
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from config.settings import ARIMA_P, ARIMA_D, ARIMA_Q, MAX_PRE_DAYS
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def evaluate_arima_model(series, arima_order):
|
||||
try:
|
||||
model = ARIMA(series, order=arima_order)
|
||||
model_fit = model.fit()
|
||||
return model_fit.aic
|
||||
except Exception:
|
||||
return float("inf")
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def arima_forecast_with_grid_search(accident_series: pd.Series,
|
||||
start_date: pd.Timestamp,
|
||||
horizon: int = 30,
|
||||
p_values: list = tuple(ARIMA_P),
|
||||
d_values: list = tuple(ARIMA_D),
|
||||
q_values: list = tuple(ARIMA_Q)) -> pd.DataFrame:
|
||||
series = accident_series.asfreq('D').fillna(0)
|
||||
start_date = pd.to_datetime(start_date)
|
||||
|
||||
warnings.filterwarnings("ignore", category=ValueWarning)
|
||||
best_score, best_cfg = float("inf"), None
|
||||
for p in p_values:
|
||||
for d in d_values:
|
||||
for q in q_values:
|
||||
order = (p, d, q)
|
||||
try:
|
||||
aic = evaluate_arima_model(series, order)
|
||||
if aic < best_score:
|
||||
best_score, best_cfg = aic, order
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
model = ARIMA(series, order=best_cfg)
|
||||
fit = model.fit()
|
||||
forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D')
|
||||
res = fit.get_forecast(steps=horizon)
|
||||
df = res.summary_frame()
|
||||
df.index = forecast_index
|
||||
df.index.name = 'date'
|
||||
df.rename(columns={'mean': 'forecast'}, inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
def knn_forecast_counterfactual(accident_series: pd.Series,
|
||||
intervention_date: pd.Timestamp,
|
||||
lookback: int = 14,
|
||||
horizon: int = 30):
|
||||
series = accident_series.asfreq('D').fillna(0)
|
||||
intervention_date = pd.to_datetime(intervention_date).normalize()
|
||||
|
||||
df = pd.DataFrame({'y': series})
|
||||
for i in range(1, lookback + 1):
|
||||
df[f'lag_{i}'] = df['y'].shift(i)
|
||||
|
||||
train = df.loc[:intervention_date - pd.Timedelta(days=1)].dropna()
|
||||
if len(train) < 5:
|
||||
return None, None
|
||||
X_train = train.filter(like='lag_').values
|
||||
y_train = train['y'].values
|
||||
knn = KNeighborsRegressor(n_neighbors=5)
|
||||
knn.fit(X_train, y_train)
|
||||
|
||||
history = df.loc[:intervention_date - pd.Timedelta(days=1), 'y'].tolist()
|
||||
preds = []
|
||||
for _ in range(horizon):
|
||||
if len(history) < lookback:
|
||||
return None, None
|
||||
x = np.array(history[-lookback:][::-1]).reshape(1, -1)
|
||||
pred = knn.predict(x)[0]
|
||||
preds.append(pred)
|
||||
history.append(pred)
|
||||
|
||||
pred_index = pd.date_range(intervention_date, periods=horizon, freq='D')
|
||||
return pd.Series(preds, index=pred_index, name='knn_pred'), None
|
||||
|
||||
|
||||
def fit_and_extrapolate(series: pd.Series,
|
||||
intervention_date: pd.Timestamp,
|
||||
days: int = 30,
|
||||
max_pre_days: int = MAX_PRE_DAYS):
|
||||
series = series.asfreq('D').fillna(0)
|
||||
series.index = pd.to_datetime(series.index).tz_localize(None).normalize()
|
||||
intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize()
|
||||
|
||||
pre = series.loc[:intervention_date - pd.Timedelta(days=1)]
|
||||
if len(pre) > max_pre_days:
|
||||
pre = pre.iloc[-max_pre_days:]
|
||||
if len(pre) < 3:
|
||||
return None, None, None
|
||||
|
||||
x_pre = np.arange(len(pre))
|
||||
x_future = np.arange(len(pre), len(pre) + days)
|
||||
|
||||
try:
|
||||
X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2]))
|
||||
glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson())
|
||||
glm_res = glm.fit()
|
||||
X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2]))
|
||||
glm_pred = glm_res.predict(X_future_glm)
|
||||
except Exception:
|
||||
glm_pred = None
|
||||
|
||||
try:
|
||||
svr = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10, gamma=0.1))
|
||||
svr.fit(x_pre.reshape(-1, 1), pre.values)
|
||||
svr_pred = svr.predict(x_future.reshape(-1, 1))
|
||||
except Exception:
|
||||
svr_pred = None
|
||||
|
||||
post_index = pd.date_range(intervention_date, periods=days, freq='D')
|
||||
|
||||
glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred') if glm_pred is not None else None
|
||||
svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred') if svr_pred is not None else None
|
||||
|
||||
post = series.reindex(post_index)
|
||||
residuals = None
|
||||
if svr_pred is not None:
|
||||
residuals = pd.Series(post.values - svr_pred[:len(post)], index=post_index, name='residual')
|
||||
|
||||
return glm_pred, svr_pred, residuals
|
||||
|
||||
239
services/hotspot.py
Normal file
239
services/hotspot.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
LOCATION_KEYWORDS: tuple[str, ...] = (
|
||||
"路",
|
||||
"道",
|
||||
"街",
|
||||
"巷",
|
||||
"路口",
|
||||
"交叉口",
|
||||
"大道",
|
||||
"公路",
|
||||
"口",
|
||||
)
|
||||
AREA_KEYWORDS: tuple[str, ...] = (
|
||||
"新城",
|
||||
"临城",
|
||||
"千岛",
|
||||
"翁山",
|
||||
"海天",
|
||||
"海宇",
|
||||
"定沈",
|
||||
"滨海",
|
||||
"港岛",
|
||||
"体育",
|
||||
"长升",
|
||||
"金岛",
|
||||
"桃湾",
|
||||
)
|
||||
|
||||
LOCATION_MAPPING: dict[str, str] = {
|
||||
"新城千岛路": "千岛路",
|
||||
"千岛路海天大道": "千岛路海天大道口",
|
||||
"海天大道千岛路": "千岛路海天大道口",
|
||||
"新城翁山路": "翁山路",
|
||||
"翁山路金岛路": "翁山路金岛路口",
|
||||
"海天大道临长路": "海天大道临长路口",
|
||||
"定沈路卫生医院门口": "定沈路医院段",
|
||||
"翁山路海城路西口": "翁山路海城路口",
|
||||
"海宇道路口": "海宇道",
|
||||
"海天大道路口": "海天大道",
|
||||
"定沈路交叉路口": "定沈路",
|
||||
"千岛路路口": "千岛路",
|
||||
"体育路路口": "体育路",
|
||||
"金岛路路口": "金岛路",
|
||||
}
|
||||
|
||||
SEVERITY_MAP: dict[str, int] = {"财损": 1, "伤人": 2, "亡人": 4}
|
||||
|
||||
|
||||
def _extract_road_info(location: str | float | None) -> str:
|
||||
if pd.isna(location):
|
||||
return "未知路段"
|
||||
text = str(location)
|
||||
for keyword in LOCATION_KEYWORDS + AREA_KEYWORDS:
|
||||
if keyword in text:
|
||||
words = text.replace(",", " ").replace(",", " ").split()
|
||||
for word in words:
|
||||
if keyword in word:
|
||||
return word
|
||||
return text
|
||||
return text[:20] if len(text) > 20 else text
|
||||
|
||||
|
||||
def prepare_hotspot_dataset(accident_records: pd.DataFrame) -> pd.DataFrame:
|
||||
df = accident_records.copy()
|
||||
required_defaults: dict[str, str] = {
|
||||
"道路类型": "未知道路类型",
|
||||
"路口路段类型": "未知路段",
|
||||
"事故具体地点": "未知路段",
|
||||
"事故类型": "财损",
|
||||
"所在街道": "未知街道",
|
||||
}
|
||||
for column, default_value in required_defaults.items():
|
||||
if column not in df.columns:
|
||||
df[column] = default_value
|
||||
else:
|
||||
df[column] = df[column].fillna(default_value)
|
||||
|
||||
if "severity" not in df.columns:
|
||||
df["severity"] = df["事故类型"].map(SEVERITY_MAP).fillna(1).astype(int)
|
||||
|
||||
df["事故时间"] = pd.to_datetime(df["事故时间"], errors="coerce")
|
||||
df = df.dropna(subset=["事故时间"]).sort_values("事故时间").reset_index(drop=True)
|
||||
df["standardized_location"] = (
|
||||
df["事故具体地点"].apply(_extract_road_info).replace(LOCATION_MAPPING)
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def analyze_hotspot_frequency(df: pd.DataFrame, time_window: str = "7D") -> pd.DataFrame:
|
||||
recent_cutoff = df["事故时间"].max() - pd.Timedelta(time_window)
|
||||
|
||||
overall_stats = df.groupby("standardized_location").agg(
|
||||
accident_count=("事故时间", "count"),
|
||||
last_accident=("事故时间", "max"),
|
||||
main_accident_type=("事故类型", _mode_fallback),
|
||||
main_road_type=("道路类型", _mode_fallback),
|
||||
main_intersection_type=("路口路段类型", _mode_fallback),
|
||||
total_severity=("severity", "sum"),
|
||||
)
|
||||
|
||||
recent_stats = (
|
||||
df[df["事故时间"] >= recent_cutoff]
|
||||
.groupby("standardized_location")
|
||||
.agg(
|
||||
recent_count=("事故时间", "count"),
|
||||
recent_accident_type=("事故类型", _mode_fallback),
|
||||
recent_severity=("severity", "sum"),
|
||||
)
|
||||
)
|
||||
|
||||
result = (
|
||||
overall_stats.merge(recent_stats, left_index=True, right_index=True, how="left")
|
||||
.fillna({"recent_count": 0, "recent_severity": 0})
|
||||
.fillna("")
|
||||
)
|
||||
result["recent_count"] = result["recent_count"].astype(int)
|
||||
result["trend_ratio"] = result["recent_count"] / result["accident_count"]
|
||||
result["days_since_last"] = (
|
||||
df["事故时间"].max() - result["last_accident"]
|
||||
).dt.days.astype(int)
|
||||
result["avg_severity"] = result["total_severity"] / result["accident_count"]
|
||||
return result.sort_values(["recent_count", "accident_count"], ascending=False)
|
||||
|
||||
|
||||
def calculate_hotspot_risk_score(hotspot_df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = hotspot_df.copy()
|
||||
if df.empty:
|
||||
return df
|
||||
|
||||
df["frequency_score"] = (df["accident_count"] / df["accident_count"].max() * 40).clip(
|
||||
0, 40
|
||||
)
|
||||
df["trend_score"] = (df["trend_ratio"] * 30).clip(0, 30)
|
||||
severity_map = {"财损": 5, "伤人": 15, "亡人": 20}
|
||||
df["severity_score"] = df["main_accident_type"].map(severity_map).fillna(5)
|
||||
df["urgency_score"] = ((30 - df["days_since_last"]) / 30 * 10).clip(0, 10)
|
||||
df["risk_score"] = (
|
||||
df["frequency_score"]
|
||||
+ df["trend_score"]
|
||||
+ df["severity_score"]
|
||||
+ df["urgency_score"]
|
||||
)
|
||||
conditions = [
|
||||
df["risk_score"] >= 70,
|
||||
df["risk_score"] >= 50,
|
||||
df["risk_score"] >= 30,
|
||||
]
|
||||
choices = ["高风险", "中风险", "低风险"]
|
||||
df["risk_level"] = np.select(conditions, choices, default="一般风险")
|
||||
return df.sort_values("risk_score", ascending=False)
|
||||
|
||||
|
||||
def generate_hotspot_strategies(
|
||||
hotspot_df: pd.DataFrame, time_period: str = "本周"
|
||||
) -> list[dict[str, str | float]]:
|
||||
strategies: list[dict[str, str | float]] = []
|
||||
for location_name, location_data in hotspot_df.iterrows():
|
||||
accident_count = float(location_data["accident_count"])
|
||||
recent_count = float(location_data.get("recent_count", 0))
|
||||
accident_type = str(location_data.get("main_accident_type", "财损"))
|
||||
intersection_type = str(location_data.get("main_intersection_type", "普通路段"))
|
||||
trend_ratio = float(location_data.get("trend_ratio", 0))
|
||||
risk_level = str(location_data.get("risk_level", "一般风险"))
|
||||
|
||||
base_info = f"{time_period}对【{location_name}】"
|
||||
data_support = (
|
||||
f"(近期{int(recent_count)}起,累计{int(accident_count)}起,{accident_type}为主)"
|
||||
)
|
||||
|
||||
strategy_parts: list[str] = []
|
||||
if "信号灯" in intersection_type:
|
||||
if accident_type == "财损":
|
||||
strategy_parts.extend(["加强闯红灯查处", "优化信号配时", "整治不按规定让行"])
|
||||
else:
|
||||
strategy_parts.extend(["完善人行过街设施", "加强非机动车管理", "设置警示标志"])
|
||||
elif "普通路段" in intersection_type:
|
||||
strategy_parts.extend(["加强巡逻管控", "整治违法停车", "设置限速标志"])
|
||||
else:
|
||||
strategy_parts.extend(["分析事故成因", "制定综合整治方案"])
|
||||
|
||||
if risk_level == "高风险":
|
||||
strategy_parts.extend(["列为重点整治路段", "开展专项整治行动"])
|
||||
elif risk_level == "中风险":
|
||||
strategy_parts.append("加强日常监管")
|
||||
|
||||
if trend_ratio > 0.4:
|
||||
strategy_parts.append("近期重点监控")
|
||||
|
||||
strategy_text = (
|
||||
base_info + "," + ",".join(strategy_parts) + data_support
|
||||
if strategy_parts
|
||||
else base_info + "加强交通安全管理" + data_support
|
||||
)
|
||||
|
||||
strategies.append(
|
||||
{
|
||||
"location": location_name,
|
||||
"strategy": strategy_text,
|
||||
"risk_level": risk_level,
|
||||
"accident_count": accident_count,
|
||||
"recent_count": recent_count,
|
||||
}
|
||||
)
|
||||
return strategies
|
||||
|
||||
|
||||
def serialise_datetime_columns(df: pd.DataFrame, columns: Optional[Iterable[str]] = None) -> pd.DataFrame:
|
||||
result = df.copy()
|
||||
if columns is None:
|
||||
columns = result.columns
|
||||
for column in columns:
|
||||
if column not in result.columns:
|
||||
continue
|
||||
series = result[column]
|
||||
if pd.api.types.is_datetime64_any_dtype(series):
|
||||
result[column] = series.dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
has_timestamp = series.map(lambda value: isinstance(value, (datetime, pd.Timestamp))).any()
|
||||
if has_timestamp:
|
||||
result[column] = series.map(
|
||||
lambda value: value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if isinstance(value, (datetime, pd.Timestamp))
|
||||
else value
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _mode_fallback(series: pd.Series) -> str:
|
||||
if series.empty:
|
||||
return ""
|
||||
mode = series.mode()
|
||||
return str(mode.iloc[0]) if not mode.empty else str(series.iloc[0])
|
||||
270
services/io.py
Normal file
270
services/io.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Iterable, Mapping
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
COLUMN_ALIASES: Mapping[str, str] = {
|
||||
'事故发生时间': '事故时间',
|
||||
'发生时间': '事故时间',
|
||||
'时间': '事故时间',
|
||||
'街道': '所在街道',
|
||||
'所属街道': '所在街道',
|
||||
'所属辖区': '所在区县',
|
||||
'辖区街道': '所在街道',
|
||||
'事故发生地点': '事故地点',
|
||||
'事故地址': '事故地点',
|
||||
'事故位置': '事故地点',
|
||||
'事故具体地址': '事故具体地点',
|
||||
'案件类型': '事故类型',
|
||||
'事故类别': '事故类型',
|
||||
'事故性质': '事故类型',
|
||||
'事故类型1': '事故类型',
|
||||
}
|
||||
|
||||
ACCIDENT_TYPE_NORMALIZATION: Mapping[str, str] = {
|
||||
'财产损失': '财损',
|
||||
'财产损失事故': '财损',
|
||||
'一般程序': '伤人',
|
||||
'一般程序事故': '伤人',
|
||||
'伤人事故': '伤人',
|
||||
'造成人员受伤': '伤人',
|
||||
'造成人员死亡': '亡人',
|
||||
'死亡事故': '亡人',
|
||||
'亡人事故': '亡人',
|
||||
'亡人死亡': '亡人',
|
||||
'号': '财损',
|
||||
}
|
||||
|
||||
REGION_FROM_LOCATION_PATTERN = re.compile(r'([一-龥]{2,8}(街道|新区|开发区|镇|区))')
|
||||
|
||||
REGION_NORMALIZATION: Mapping[str, str] = {
|
||||
'临城中队': '临城街道',
|
||||
'临城新区': '临城街道',
|
||||
'临城': '临城街道',
|
||||
'新城': '临城街道',
|
||||
'千岛中队': '千岛街道',
|
||||
'千岛新区': '千岛街道',
|
||||
'千岛': '千岛街道',
|
||||
'沈家门中队': '沈家门街道',
|
||||
'沈家门': '沈家门街道',
|
||||
'普陀城区': '沈家门街道',
|
||||
'普陀': '沈家门街道',
|
||||
}
|
||||
|
||||
|
||||
def _clean_text(series: pd.Series) -> pd.Series:
|
||||
"""Strip whitespace and normalise obvious null placeholders."""
|
||||
cleaned = series.astype(str).str.strip()
|
||||
null_tokens = {'', 'nan', 'NaN', 'None', 'NULL', '<NA>', '无', '—'}
|
||||
return cleaned.mask(cleaned.isin(null_tokens))
|
||||
|
||||
|
||||
def _maybe_seek_start(file_obj) -> None:
|
||||
if hasattr(file_obj, "seek"):
|
||||
try:
|
||||
file_obj.seek(0)
|
||||
except Exception: # pragma: no cover - guard against non file-likes
|
||||
pass
|
||||
|
||||
|
||||
def _prepare_sheet(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Standardise a single sheet from the事故数据 workbook."""
|
||||
if df is None or df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
sheet = df.copy()
|
||||
# Normalise column names first
|
||||
sheet.columns = [str(col).strip() for col in sheet.columns]
|
||||
# If 栏目 still not recognised, attempt to locate header row inside the data
|
||||
if '事故时间' not in sheet.columns and '事故发生时间' not in sheet.columns:
|
||||
header_row = None
|
||||
for idx, row in sheet.iterrows():
|
||||
values = [str(cell).strip() for cell in row.tolist()]
|
||||
if '事故时间' in values or '事故发生时间' in values or '报警时间' in values:
|
||||
header_row = idx
|
||||
break
|
||||
if header_row is not None:
|
||||
sheet.columns = [str(x).strip() for x in sheet.iloc[header_row].tolist()]
|
||||
sheet = sheet.iloc[header_row + 1 :].reset_index(drop=True)
|
||||
sheet.columns = [str(col).strip() for col in sheet.columns]
|
||||
|
||||
# Apply aliases after potential header relocation
|
||||
sheet = sheet.rename(columns={src: dst for src, dst in COLUMN_ALIASES.items() if src in sheet.columns})
|
||||
|
||||
return sheet
|
||||
|
||||
|
||||
def _coalesce_columns(df: pd.DataFrame, columns: Iterable[str]) -> pd.Series:
|
||||
result = pd.Series(pd.NA, index=df.index, dtype="object")
|
||||
for col in columns:
|
||||
if col in df.columns:
|
||||
candidate = _clean_text(df[col])
|
||||
result = result.fillna(candidate)
|
||||
return result
|
||||
|
||||
|
||||
def _infer_region_from_location(location: str) -> str | None:
|
||||
if pd.isna(location):
|
||||
return None
|
||||
text = str(location).strip()
|
||||
if not text:
|
||||
return None
|
||||
match = REGION_FROM_LOCATION_PATTERN.search(text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
def _normalise_region_series(series: pd.Series) -> pd.Series:
|
||||
return series.map(lambda val: REGION_NORMALIZATION.get(val, val) if pd.notna(val) else val)
|
||||
|
||||
|
||||
def load_accident_records(accident_file, *, require_location: bool = False) -> pd.DataFrame:
|
||||
"""
|
||||
Load accident records from the updated Excel template.
|
||||
|
||||
The function supports workbooks with a single sheet (e.g. sample/事故处理/事故2021-2022.xlsx)
|
||||
as well as legacy multi-sheet formats where the header row might sit within the data.
|
||||
"""
|
||||
_maybe_seek_start(accident_file)
|
||||
sheets = pd.read_excel(accident_file, sheet_name=None)
|
||||
if isinstance(sheets, dict):
|
||||
frames = [frame for frame in ( _prepare_sheet(df) for df in sheets.values() ) if not frame.empty]
|
||||
else: # pragma: no cover - pandas only returns dict when sheet_name=None, but keep guard
|
||||
frames = [_prepare_sheet(sheets)]
|
||||
|
||||
if not frames:
|
||||
raise ValueError("未在上传的事故数据中检测到有效的事故记录,请确认文件内容。")
|
||||
|
||||
accident_df = pd.concat(frames, ignore_index=True)
|
||||
|
||||
# Normalise columns of interest
|
||||
if '事故时间' not in accident_df.columns and '报警时间' in accident_df.columns:
|
||||
accident_df['事故时间'] = accident_df['报警时间']
|
||||
|
||||
if '事故时间' not in accident_df.columns:
|
||||
raise ValueError("事故数据缺少“事故时间”字段,请确认模板是否为最新版本。")
|
||||
|
||||
accident_df['事故时间'] = pd.to_datetime(accident_df['事故时间'], errors='coerce')
|
||||
|
||||
# Location harmonisation (used for both region inference and hotspot analysis)
|
||||
location_columns_available = [col for col in ['事故具体地点', '事故地点'] if col in accident_df.columns]
|
||||
location_series = _coalesce_columns(accident_df, ['事故具体地点', '事故地点'])
|
||||
|
||||
# Region handling
|
||||
region = _coalesce_columns(accident_df, ['所在街道', '所属街道', '所在区县', '辖区中队'])
|
||||
# Infer region from location fields when still missing
|
||||
if region.isna().any():
|
||||
inferred = location_series.map(_infer_region_from_location)
|
||||
region = region.fillna(inferred)
|
||||
region = region.fillna(_clean_text(location_series))
|
||||
|
||||
region_clean = _clean_text(region)
|
||||
accident_df['所在街道'] = _normalise_region_series(region_clean)
|
||||
|
||||
# Accident type normalisation
|
||||
accident_type = _coalesce_columns(accident_df, ['事故类型', '事故类别', '事故性质'])
|
||||
accident_type = accident_type.replace(ACCIDENT_TYPE_NORMALIZATION)
|
||||
accident_type = _clean_text(accident_type).replace(ACCIDENT_TYPE_NORMALIZATION)
|
||||
accident_df['事故类型'] = accident_type.fillna('财损')
|
||||
|
||||
# Location column harmonisation
|
||||
if require_location and not location_columns_available and location_series.isna().all():
|
||||
raise ValueError("事故数据缺少“事故具体地点”字段,请确认模板是否与 sample/事故处理 中示例一致。")
|
||||
accident_df['事故具体地点'] = _clean_text(location_series)
|
||||
|
||||
# Drop records with missing core fields
|
||||
subset = ['事故时间', '所在街道', '事故类型']
|
||||
if require_location:
|
||||
subset.append('事故具体地点')
|
||||
accident_df = accident_df.dropna(subset=subset)
|
||||
|
||||
# Severity score
|
||||
severity_map = {'财损': 1, '伤人': 2, '亡人': 4}
|
||||
accident_df['severity'] = accident_df['事故类型'].map(severity_map).fillna(1).astype(int)
|
||||
|
||||
accident_df = accident_df.sort_values('事故时间').reset_index(drop=True)
|
||||
|
||||
return accident_df
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def load_and_clean_data(accident_file, strategy_file):
|
||||
accident_records = load_accident_records(accident_file)
|
||||
|
||||
accident_data = accident_records.rename(
|
||||
columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'}
|
||||
)
|
||||
|
||||
_maybe_seek_start(strategy_file)
|
||||
strategy_df = pd.read_excel(strategy_file)
|
||||
strategy_df = strategy_df.rename(columns=lambda col: str(col).strip())
|
||||
if '发布时间' not in strategy_df.columns:
|
||||
raise ValueError("策略数据缺少“发布时间”字段,请确认文件格式。")
|
||||
|
||||
strategy_df['发布时间'] = pd.to_datetime(strategy_df['发布时间'], errors='coerce')
|
||||
if '交通策略类型' not in strategy_df.columns:
|
||||
raise ValueError("策略数据缺少“交通策略类型”字段,请确认文件格式。")
|
||||
|
||||
strategy_df['交通策略类型'] = _clean_text(strategy_df['交通策略类型'])
|
||||
strategy_df = strategy_df.dropna(subset=['发布时间', '交通策略类型'])
|
||||
|
||||
accident_data = accident_data[['date_time', 'region', 'category', 'severity']]
|
||||
strategy_df = strategy_df[['发布时间', '交通策略类型']].rename(
|
||||
columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'}
|
||||
)
|
||||
|
||||
return accident_data, strategy_df
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
|
||||
accident_data = accident_data.copy()
|
||||
strategy_data = strategy_data.copy()
|
||||
|
||||
accident_data['date'] = accident_data['date_time'].dt.date
|
||||
daily_accidents = accident_data.groupby('date').agg(
|
||||
accident_count=('date_time', 'count'),
|
||||
severity=('severity', 'sum')
|
||||
)
|
||||
daily_accidents.index = pd.to_datetime(daily_accidents.index)
|
||||
|
||||
strategy_data['date'] = strategy_data['date_time'].dt.date
|
||||
daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list)
|
||||
daily_strategies.index = pd.to_datetime(daily_strategies.index)
|
||||
|
||||
combined = daily_accidents.join(daily_strategies, how='left')
|
||||
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
|
||||
combined = combined.asfreq('D')
|
||||
combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0)
|
||||
combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
|
||||
return combined
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame:
|
||||
df = accident_data.copy()
|
||||
df['date'] = df['date_time'].dt.date
|
||||
g = df.groupby(['region', 'date']).agg(
|
||||
accident_count=('date_time', 'count'),
|
||||
severity=('severity', 'sum')
|
||||
)
|
||||
g.index = g.index.set_levels([g.index.levels[0], pd.to_datetime(g.index.levels[1])])
|
||||
g = g.sort_index()
|
||||
|
||||
s = strategy_data.copy()
|
||||
s['date'] = s['date_time'].dt.date
|
||||
daily_strategies = s.groupby('date')['strategy_type'].apply(list)
|
||||
daily_strategies.index = pd.to_datetime(daily_strategies.index)
|
||||
|
||||
regions = g.index.get_level_values(0).unique()
|
||||
dates = pd.date_range(g.index.get_level_values(1).min(), g.index.get_level_values(1).max(), freq='D')
|
||||
full_index = pd.MultiIndex.from_product([regions, dates], names=['region', 'date'])
|
||||
g = g.reindex(full_index).fillna(0)
|
||||
|
||||
strat_map = daily_strategies.to_dict()
|
||||
g = g.assign(strategy_type=[strat_map.get(d, []) for d in g.index.get_level_values('date')])
|
||||
return g
|
||||
86
services/metrics.py
Normal file
86
services/metrics.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
from statsmodels.tsa.arima.model import ARIMA
|
||||
import streamlit as st
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def evaluate_models(series: pd.Series,
|
||||
horizon: int = 30,
|
||||
lookback: int = 14,
|
||||
p_values: range = range(0, 4),
|
||||
d_values: range = range(0, 2),
|
||||
q_values: range = range(0, 4)) -> pd.DataFrame:
|
||||
"""
|
||||
留出法(最后 horizon 天作为验证集)比较 ARIMA / KNN / GLM / SVR,
|
||||
输出 MAE・RMSE・MAPE,并按 RMSE 升序排序。
|
||||
"""
|
||||
series = series.asfreq('D').fillna(0)
|
||||
if len(series) <= horizon + 10:
|
||||
raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon)
|
||||
|
||||
train, test = series.iloc[:-horizon], series.iloc[-horizon:]
|
||||
|
||||
def _to_series_like(pred, a_index):
|
||||
if isinstance(pred, pd.Series):
|
||||
return pred.reindex(a_index)
|
||||
return pd.Series(pred, index=a_index)
|
||||
|
||||
def _metrics(a: pd.Series, p) -> dict:
|
||||
p = _to_series_like(p, a.index).astype(float)
|
||||
a = a.astype(float)
|
||||
mae = mean_absolute_error(a, p)
|
||||
try:
|
||||
rmse = mean_squared_error(a, p, squared=False)
|
||||
except TypeError:
|
||||
rmse = mean_squared_error(a, p) ** 0.5
|
||||
mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100
|
||||
return {"MAE": mae, "RMSE": rmse, "MAPE": mape}
|
||||
|
||||
results = {}
|
||||
|
||||
best_aic, best_order = float('inf'), (1, 0, 1)
|
||||
for p in p_values:
|
||||
for d in d_values:
|
||||
for q in q_values:
|
||||
try:
|
||||
aic = ARIMA(train, order=(p, d, q)).fit().aic
|
||||
if aic < best_aic:
|
||||
best_aic, best_order = aic, (p, d, q)
|
||||
except Exception:
|
||||
continue
|
||||
arima_train = train.asfreq('D').fillna(0)
|
||||
arima_pred = ARIMA(arima_train, order=best_order).fit().forecast(steps=horizon)
|
||||
results['ARIMA'] = _metrics(test, arima_pred)
|
||||
|
||||
# Import local utilities to avoid circular dependencies
|
||||
from services.forecast import knn_forecast_counterfactual, fit_and_extrapolate
|
||||
|
||||
try:
|
||||
knn_pred, _ = knn_forecast_counterfactual(series,
|
||||
train.index[-1] + pd.Timedelta(days=1),
|
||||
lookback=lookback,
|
||||
horizon=horizon)
|
||||
if knn_pred is not None:
|
||||
results['KNN'] = _metrics(test, knn_pred)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
glm_pred, svr_pred, _ = fit_and_extrapolate(series,
|
||||
train.index[-1] + pd.Timedelta(days=1),
|
||||
days=horizon)
|
||||
if glm_pred is not None:
|
||||
results['GLM'] = _metrics(test, glm_pred)
|
||||
if svr_pred is not None:
|
||||
results['SVR'] = _metrics(test, svr_pred)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (pd.DataFrame(results)
|
||||
.T.sort_values('RMSE')
|
||||
.round(3))
|
||||
|
||||
157
services/strategy.py
Normal file
157
services/strategy.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from services.forecast import fit_and_extrapolate, arima_forecast_with_grid_search
|
||||
from config.settings import MIN_PRE_DAYS, MAX_PRE_DAYS
|
||||
|
||||
|
||||
def evaluate_strategy_effectiveness(actual_series: pd.Series,
|
||||
counterfactual_series: pd.Series,
|
||||
severity_series: pd.Series,
|
||||
strategy_date: pd.Timestamp,
|
||||
window: int = 30):
|
||||
strategy_date = pd.to_datetime(strategy_date)
|
||||
window_end = strategy_date + pd.Timedelta(days=window - 1)
|
||||
pre_sev = severity_series.loc[strategy_date - pd.Timedelta(days=window):strategy_date - pd.Timedelta(days=1)].sum()
|
||||
post_sev = severity_series.loc[strategy_date:window_end].sum()
|
||||
actual_post = actual_series.loc[strategy_date:window_end]
|
||||
counter_post = counterfactual_series.loc[strategy_date:window_end].reindex(actual_post.index)
|
||||
window_len = len(actual_post)
|
||||
if window_len == 0:
|
||||
return False, False, (0.0, 0.0), '三级'
|
||||
effective_days = (actual_post < counter_post).sum()
|
||||
count_effective = effective_days >= (window_len / 2)
|
||||
severity_effective = post_sev < pre_sev
|
||||
cf_sum = counter_post.sum()
|
||||
F1 = (cf_sum - actual_post.sum()) / cf_sum if cf_sum > 0 else 0.0
|
||||
F2 = (pre_sev - post_sev) / pre_sev if pre_sev > 0 else 0.0
|
||||
if F1 > 0.5 and F2 > 0.5:
|
||||
safety_state = '一级'
|
||||
elif F1 > 0.3:
|
||||
safety_state = '二级'
|
||||
else:
|
||||
safety_state = '三级'
|
||||
return count_effective, severity_effective, (F1, F2), safety_state
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def generate_output_and_recommendations(combined_data: pd.DataFrame,
|
||||
strategy_types: list,
|
||||
region: str = '全市',
|
||||
horizon: int = 30):
|
||||
results = {}
|
||||
combined_data = combined_data.copy().asfreq('D')
|
||||
combined_data[['accident_count','severity']] = combined_data[['accident_count','severity']].fillna(0)
|
||||
combined_data['strategy_type'] = combined_data['strategy_type'].apply(lambda x: x if isinstance(x, list) else [])
|
||||
|
||||
acc_full = combined_data['accident_count']
|
||||
sev_full = combined_data['severity']
|
||||
|
||||
max_fit_days = max(horizon + 60, MAX_PRE_DAYS)
|
||||
|
||||
for strategy in strategy_types:
|
||||
has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x)
|
||||
if not has_strategy.any():
|
||||
continue
|
||||
candidate_dates = has_strategy[has_strategy].index
|
||||
intervention_date = None
|
||||
fit_start_dt = None
|
||||
for dt in candidate_dates:
|
||||
fit_start_dt = max(acc_full.index.min(), dt - pd.Timedelta(days=max_fit_days))
|
||||
pre_hist = acc_full.loc[fit_start_dt:dt - pd.Timedelta(days=1)]
|
||||
if len(pre_hist) >= MIN_PRE_DAYS:
|
||||
intervention_date = dt
|
||||
break
|
||||
if intervention_date is None:
|
||||
intervention_date = candidate_dates[0]
|
||||
fit_start_dt = max(acc_full.index.min(), intervention_date - pd.Timedelta(days=max_fit_days))
|
||||
|
||||
acc = acc_full.loc[fit_start_dt:]
|
||||
sev = sev_full.loc[fit_start_dt:]
|
||||
horizon_eff = max(7, min(horizon, len(acc.loc[intervention_date:]) ))
|
||||
|
||||
glm_pred, svr_pred, residuals = fit_and_extrapolate(acc, intervention_date, days=horizon_eff)
|
||||
|
||||
counter = None
|
||||
if svr_pred is not None:
|
||||
counter = svr_pred
|
||||
elif glm_pred is not None:
|
||||
counter = glm_pred
|
||||
else:
|
||||
try:
|
||||
arima_df = arima_forecast_with_grid_search(acc.loc[:intervention_date],
|
||||
start_date=intervention_date + pd.Timedelta(days=1),
|
||||
horizon=horizon_eff)
|
||||
counter = pd.Series(arima_df['forecast'].values, index=arima_df.index, name='cf_arima')
|
||||
residuals = (acc.reindex(counter.index) - counter)
|
||||
except Exception:
|
||||
counter = None
|
||||
if counter is None:
|
||||
continue
|
||||
|
||||
count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness(
|
||||
actual_series=acc,
|
||||
counterfactual_series=counter,
|
||||
severity_series=sev,
|
||||
strategy_date=intervention_date,
|
||||
window=horizon_eff
|
||||
)
|
||||
results[strategy] = {
|
||||
'effect_strength': float(residuals.dropna().mean()) if residuals is not None else 0.0,
|
||||
'adaptability': float(F1 + F2),
|
||||
'count_effective': bool(count_eff),
|
||||
'severity_effective': bool(sev_eff),
|
||||
'safety_state': state,
|
||||
'F1': float(F1),
|
||||
'F2': float(F2),
|
||||
'intervention_date': str(intervention_date.date())
|
||||
}
|
||||
|
||||
# Secondary attempt with 14-day window if no results
|
||||
if not results:
|
||||
for strategy in strategy_types:
|
||||
has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x)
|
||||
if not has_strategy.any():
|
||||
continue
|
||||
intervention_date = has_strategy[has_strategy].index[0]
|
||||
glm_pred, svr_pred, residuals = fit_and_extrapolate(acc_full, intervention_date, days=14)
|
||||
counter = None
|
||||
if svr_pred is not None:
|
||||
counter = svr_pred
|
||||
elif glm_pred is not None:
|
||||
counter = glm_pred
|
||||
else:
|
||||
try:
|
||||
arima_df = arima_forecast_with_grid_search(acc_full.loc[:intervention_date],
|
||||
start_date=intervention_date + pd.Timedelta(days=1),
|
||||
horizon=14)
|
||||
counter = pd.Series(arima_df['forecast'].values, index=arima_df.index, name='cf_arima')
|
||||
residuals = (acc_full.reindex(counter.index) - counter)
|
||||
except Exception:
|
||||
counter = None
|
||||
if counter is None:
|
||||
continue
|
||||
count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness(
|
||||
actual_series=acc_full,
|
||||
counterfactual_series=counter,
|
||||
severity_series=sev_full,
|
||||
strategy_date=intervention_date,
|
||||
window=14
|
||||
)
|
||||
results[strategy] = {
|
||||
'effect_strength': float(residuals.dropna().mean()) if residuals is not None else 0.0,
|
||||
'adaptability': float(F1 + F2),
|
||||
'count_effective': bool(count_eff),
|
||||
'severity_effective': bool(sev_eff),
|
||||
'safety_state': state,
|
||||
'F1': float(F1),
|
||||
'F2': float(F2),
|
||||
'intervention_date': str(intervention_date.date())
|
||||
}
|
||||
|
||||
best_strategy = max(results, key=lambda x: results[x]['adaptability']) if results else None
|
||||
recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}" if best_strategy else "无足够数据推荐策略"
|
||||
return results, recommendation
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,5 +0,0 @@
|
||||
,effect_strength,adaptability,count_effective,severity_effective,safety_state,F1,F2,intervention_date
|
||||
交通信息预警,-8.965321179202334,-0.7855379968058066,True,False,三级,0.2463091369521552,-1.0318471337579618,2022-01-13
|
||||
交通整治活动,-2.651006128785241,-1.667254385637472,True,False,三级,0.08411173458110731,-1.7513661202185793,2022-01-11
|
||||
交通管制措施,-10.70286313762653,0.19010392243197832,True,False,三级,0.2989387495766646,-0.1088348271446863,2022-01-20
|
||||
政策制度实施,-2.6771799687750018,-5.1316650216481605,True,False,三级,0.07856225107911223,-5.2102272727272725,2022-01-06
|
||||
|
13
ui_sections/__init__.py
Normal file
13
ui_sections/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .overview import render_overview
|
||||
from .forecast import render_forecast
|
||||
from .model_eval import render_model_eval
|
||||
from .strategy_eval import render_strategy_eval
|
||||
from .hotspot import render_hotspot
|
||||
|
||||
__all__ = [
|
||||
'render_overview',
|
||||
'render_forecast',
|
||||
'render_model_eval',
|
||||
'render_strategy_eval',
|
||||
'render_hotspot',
|
||||
]
|
||||
185
ui_sections/forecast.py
Normal file
185
ui_sections/forecast.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
from services.forecast import (
|
||||
arima_forecast_with_grid_search,
|
||||
knn_forecast_counterfactual,
|
||||
fit_and_extrapolate,
|
||||
)
|
||||
|
||||
|
||||
def render_forecast(base: pd.DataFrame):
|
||||
st.subheader("多模型预测比较")
|
||||
|
||||
if base is None or base.empty:
|
||||
st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
|
||||
st.session_state.setdefault(
|
||||
"forecast_state",
|
||||
{"results": None, "last_message": "暂无可用于预测的事故数据。"},
|
||||
)
|
||||
return
|
||||
|
||||
forecast_state = st.session_state.setdefault(
|
||||
"forecast_state",
|
||||
{
|
||||
"selected_date": None,
|
||||
"horizon": 30,
|
||||
"results": None,
|
||||
"last_message": None,
|
||||
"data_signature": None,
|
||||
},
|
||||
)
|
||||
|
||||
earliest_date = base.index.min().date()
|
||||
latest_date = base.index.max().date()
|
||||
fallback_date = max(
|
||||
(base.index.max() - pd.Timedelta(days=30)).date(),
|
||||
earliest_date,
|
||||
)
|
||||
current_signature = (
|
||||
earliest_date.isoformat(),
|
||||
latest_date.isoformat(),
|
||||
int(len(base)),
|
||||
float(base["accident_count"].sum()),
|
||||
)
|
||||
|
||||
# Reset cached results if the underlying dataset has changed
|
||||
if forecast_state.get("data_signature") != current_signature:
|
||||
forecast_state.update(
|
||||
{
|
||||
"data_signature": current_signature,
|
||||
"results": None,
|
||||
"last_message": None,
|
||||
"selected_date": fallback_date,
|
||||
}
|
||||
)
|
||||
|
||||
default_date = forecast_state.get("selected_date") or fallback_date
|
||||
if default_date < earliest_date:
|
||||
default_date = earliest_date
|
||||
if default_date > latest_date:
|
||||
default_date = latest_date
|
||||
|
||||
with st.form(key="predict_form"):
|
||||
selected_date = st.date_input(
|
||||
"选择干预日期 / 预测起点",
|
||||
value=default_date,
|
||||
min_value=earliest_date,
|
||||
max_value=latest_date,
|
||||
)
|
||||
horizon = st.number_input(
|
||||
"预测天数",
|
||||
min_value=7,
|
||||
max_value=90,
|
||||
value=int(forecast_state.get("horizon", 30)),
|
||||
step=1,
|
||||
)
|
||||
submit_predict = st.form_submit_button("应用预测参数")
|
||||
|
||||
forecast_state["selected_date"] = selected_date
|
||||
forecast_state["horizon"] = int(horizon)
|
||||
|
||||
if submit_predict:
|
||||
history = base.loc[:pd.to_datetime(selected_date)]
|
||||
if len(history) < 10:
|
||||
forecast_state.update(
|
||||
{
|
||||
"results": None,
|
||||
"last_message": "干预前数据不足(至少需要 10 个观测点)。",
|
||||
}
|
||||
)
|
||||
else:
|
||||
with st.spinner("正在生成预测结果…"):
|
||||
warnings: list[str] = []
|
||||
try:
|
||||
train_series = history["accident_count"]
|
||||
arima_df = arima_forecast_with_grid_search(
|
||||
train_series,
|
||||
start_date=pd.to_datetime(selected_date) + pd.Timedelta(days=1),
|
||||
horizon=int(horizon),
|
||||
)
|
||||
except Exception as exc:
|
||||
arima_df = None
|
||||
warnings.append(f"ARIMA 运行失败:{exc}")
|
||||
|
||||
knn_pred, _ = knn_forecast_counterfactual(
|
||||
base["accident_count"],
|
||||
pd.to_datetime(selected_date),
|
||||
horizon=int(horizon),
|
||||
)
|
||||
if knn_pred is None:
|
||||
warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")
|
||||
|
||||
glm_pred, svr_pred, _ = fit_and_extrapolate(
|
||||
base["accident_count"],
|
||||
pd.to_datetime(selected_date),
|
||||
days=int(horizon),
|
||||
)
|
||||
if glm_pred is None and svr_pred is None:
|
||||
warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")
|
||||
|
||||
forecast_state.update(
|
||||
{
|
||||
"results": {
|
||||
"selected_date": selected_date,
|
||||
"horizon": int(horizon),
|
||||
"arima_df": arima_df,
|
||||
"knn_pred": knn_pred,
|
||||
"glm_pred": glm_pred,
|
||||
"svr_pred": svr_pred,
|
||||
"warnings": warnings,
|
||||
},
|
||||
"last_message": None,
|
||||
}
|
||||
)
|
||||
|
||||
results = forecast_state.get("results")
|
||||
if not results:
|
||||
if forecast_state.get("last_message"):
|
||||
st.warning(forecast_state["last_message"])
|
||||
else:
|
||||
st.info("请设置预测参数并点击“应用预测参数”按钮。")
|
||||
return
|
||||
|
||||
first_date = pd.to_datetime(results["selected_date"])
|
||||
horizon_days = int(results["horizon"])
|
||||
arima_df = results["arima_df"]
|
||||
knn_pred = results["knn_pred"]
|
||||
glm_pred = results["glm_pred"]
|
||||
svr_pred = results["svr_pred"]
|
||||
|
||||
fig_pred = go.Figure()
|
||||
fig_pred.add_trace(
|
||||
go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
|
||||
)
|
||||
if arima_df is not None:
|
||||
fig_pred.add_trace(
|
||||
go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
|
||||
)
|
||||
if knn_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines"))
|
||||
if glm_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines"))
|
||||
if svr_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines"))
|
||||
|
||||
fig_pred.update_layout(
|
||||
title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
|
||||
xaxis_title="日期",
|
||||
yaxis_title="事故数",
|
||||
)
|
||||
st.plotly_chart(fig_pred, use_container_width=True)
|
||||
|
||||
if arima_df is not None:
|
||||
st.download_button(
|
||||
"下载 ARIMA 预测 CSV",
|
||||
data=arima_df.to_csv().encode("utf-8-sig"),
|
||||
file_name="arima_forecast.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
for warning_text in results.get("warnings", []):
|
||||
st.warning(warning_text)
|
||||
185
ui_sections/hotspot.py
Normal file
185
ui_sections/hotspot.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
|
||||
from services.hotspot import (
|
||||
analyze_hotspot_frequency,
|
||||
calculate_hotspot_risk_score,
|
||||
generate_hotspot_strategies,
|
||||
prepare_hotspot_dataset,
|
||||
serialise_datetime_columns,
|
||||
)
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def _prepare_hotspot_data(df):
|
||||
return prepare_hotspot_dataset(df)
|
||||
|
||||
|
||||
def render_hotspot(accident_records, accident_source_name: str | None) -> None:
|
||||
st.header("📍 事故多发路口分析")
|
||||
st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略。")
|
||||
|
||||
if accident_records is None:
|
||||
st.info("请在左侧上传事故数据并点击“应用数据与筛选”后再执行热点分析。")
|
||||
st.markdown(
|
||||
"""
|
||||
### 📝 支持的数据格式要求:
|
||||
- **文件格式**:Excel (.xlsx)
|
||||
- **必要字段**:
|
||||
- `事故时间`
|
||||
- `事故类型`
|
||||
- `事故具体地点`
|
||||
- `所在街道`
|
||||
- `道路类型`
|
||||
- `路口路段类型`
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
with st.spinner("正在准备事故热点数据…"):
|
||||
hotspot_data = _prepare_hotspot_data(accident_records)
|
||||
|
||||
st.success(f"✅ 成功加载数据:{len(hotspot_data)} 条事故记录")
|
||||
|
||||
metric_cols = st.columns(3)
|
||||
with metric_cols[0]:
|
||||
st.metric(
|
||||
"数据时间范围",
|
||||
f"{hotspot_data['事故时间'].min().strftime('%Y-%m-%d')} 至 {hotspot_data['事故时间'].max().strftime('%Y-%m-%d')}",
|
||||
)
|
||||
with metric_cols[1]:
|
||||
st.metric(
|
||||
"事故类型分布",
|
||||
f"财损: {len(hotspot_data[hotspot_data['事故类型'] == '财损'])}起",
|
||||
)
|
||||
with metric_cols[2]:
|
||||
st.metric("涉及区域", f"{hotspot_data['所在街道'].nunique()}个街道")
|
||||
|
||||
st.subheader("🔧 分析参数设置")
|
||||
settings_cols = st.columns(3)
|
||||
with settings_cols[0]:
|
||||
time_window = st.selectbox(
|
||||
"统计时间窗口",
|
||||
options=["7D", "15D", "30D"],
|
||||
index=0,
|
||||
key="hotspot_window",
|
||||
)
|
||||
with settings_cols[1]:
|
||||
min_accidents = st.number_input(
|
||||
"最小事故数", min_value=1, max_value=50, value=3, key="hotspot_min_accidents"
|
||||
)
|
||||
with settings_cols[2]:
|
||||
top_n = st.slider("显示热点数量", min_value=3, max_value=20, value=8, key="hotspot_top_n")
|
||||
|
||||
if not st.button("🚀 开始热点分析", type="primary"):
|
||||
return
|
||||
|
||||
with st.spinner("正在分析事故热点分布…"):
|
||||
hotspots = analyze_hotspot_frequency(hotspot_data, time_window=time_window)
|
||||
hotspots = hotspots[hotspots["accident_count"] >= min_accidents]
|
||||
|
||||
if hotspots.empty:
|
||||
st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数。")
|
||||
return
|
||||
|
||||
hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3))
|
||||
top_hotspots = hotspots_with_risk.head(top_n)
|
||||
|
||||
st.subheader("📊 事故多发路口排名(前{0}个)".format(top_n))
|
||||
display_columns = {
|
||||
"accident_count": "累计事故数",
|
||||
"recent_count": "近期事故数",
|
||||
"trend_ratio": "趋势比例",
|
||||
"main_accident_type": "主要类型",
|
||||
"main_intersection_type": "路口类型",
|
||||
"risk_score": "风险评分",
|
||||
"risk_level": "风险等级",
|
||||
}
|
||||
display_df = top_hotspots[list(display_columns.keys())].rename(columns=display_columns)
|
||||
styled_df = display_df.style.format({"趋势比例": "{:.2f}", "风险评分": "{:.1f}"}).background_gradient(
|
||||
subset=["风险评分"], cmap="Reds"
|
||||
)
|
||||
st.dataframe(styled_df, use_container_width=True)
|
||||
|
||||
st.subheader("🎯 针对性策略建议")
|
||||
strategies = generate_hotspot_strategies(top_hotspots, time_period="本周")
|
||||
for index, strategy_info in enumerate(strategies, start=1):
|
||||
message = f"**{index}. {strategy_info['strategy']}**"
|
||||
risk_level = strategy_info["risk_level"]
|
||||
if risk_level == "高风险":
|
||||
st.error(f"🚨 {message}")
|
||||
elif risk_level == "中风险":
|
||||
st.warning(f"⚠️ {message}")
|
||||
else:
|
||||
st.info(f"✅ {message}")
|
||||
|
||||
st.subheader("📈 数据分析可视化")
|
||||
chart_cols = st.columns(2)
|
||||
with chart_cols[0]:
|
||||
fig_hotspots = px.bar(
|
||||
top_hotspots.head(10),
|
||||
x=top_hotspots.head(10).index,
|
||||
y=["accident_count", "recent_count"],
|
||||
title="事故频次TOP10分布",
|
||||
labels={"value": "事故数量", "variable": "类型", "index": "路口名称"},
|
||||
barmode="group",
|
||||
)
|
||||
fig_hotspots.update_layout(xaxis_tickangle=-45)
|
||||
st.plotly_chart(fig_hotspots, use_container_width=True)
|
||||
|
||||
with chart_cols[1]:
|
||||
risk_distribution = top_hotspots["risk_level"].value_counts()
|
||||
fig_risk = px.pie(
|
||||
values=risk_distribution.values,
|
||||
names=risk_distribution.index,
|
||||
title="风险等级分布",
|
||||
color_discrete_map={"高风险": "red", "中风险": "orange", "低风险": "green"},
|
||||
)
|
||||
st.plotly_chart(fig_risk, use_container_width=True)
|
||||
|
||||
st.subheader("💾 数据导出")
|
||||
download_cols = st.columns(2)
|
||||
with download_cols[0]:
|
||||
hotspot_csv = top_hotspots.to_csv().encode("utf-8-sig")
|
||||
st.download_button(
|
||||
"📥 下载热点数据CSV",
|
||||
data=hotspot_csv,
|
||||
file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
with download_cols[1]:
|
||||
serializable = serialise_datetime_columns(top_hotspots.reset_index())
|
||||
report_payload = {
|
||||
"analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"time_window": time_window,
|
||||
"data_source": accident_source_name or "事故数据",
|
||||
"total_records": int(len(hotspot_data)),
|
||||
"analysis_parameters": {"min_accidents": int(min_accidents), "top_n": int(top_n)},
|
||||
"top_hotspots": serializable.to_dict("records"),
|
||||
"recommended_strategies": strategies,
|
||||
"summary": {
|
||||
"high_risk_count": int((top_hotspots["risk_level"] == "高风险").sum()),
|
||||
"medium_risk_count": int((top_hotspots["risk_level"] == "中风险").sum()),
|
||||
"total_analyzed_locations": int(len(hotspots)),
|
||||
"most_dangerous_location": top_hotspots.index[0]
|
||||
if len(top_hotspots)
|
||||
else "无",
|
||||
},
|
||||
}
|
||||
st.download_button(
|
||||
"📄 下载完整分析报告",
|
||||
data=json.dumps(report_payload, ensure_ascii=False, indent=2),
|
||||
file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json",
|
||||
mime="application/json",
|
||||
)
|
||||
|
||||
with st.expander("📋 查看原始数据预览"):
|
||||
preview_cols = ["事故时间", "所在街道", "事故类型", "事故具体地点", "道路类型"]
|
||||
preview_df = hotspot_data[preview_cols].copy()
|
||||
st.dataframe(preview_df.head(10), use_container_width=True)
|
||||
32
ui_sections/model_eval.py
Normal file
32
ui_sections/model_eval.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from services.metrics import evaluate_models
|
||||
|
||||
|
||||
def render_model_eval(base: pd.DataFrame):
|
||||
st.subheader("模型预测效果对比")
|
||||
with st.form(key="model_eval_form"):
|
||||
horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
|
||||
submit_eval = st.form_submit_button("应用评估参数")
|
||||
|
||||
if not submit_eval:
|
||||
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
|
||||
return
|
||||
|
||||
try:
|
||||
df_metrics = evaluate_models(base['accident_count'], horizon=int(horizon_sel))
|
||||
st.dataframe(df_metrics, use_container_width=True)
|
||||
best_model = df_metrics['RMSE'].idxmin()
|
||||
st.success(f"过去 {int(horizon_sel)} 天中,RMSE 最低的模型是:**{best_model}**")
|
||||
st.download_button(
|
||||
"下载评估结果 CSV",
|
||||
data=df_metrics.to_csv().encode('utf-8-sig'),
|
||||
file_name="model_evaluation.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
except ValueError as err:
|
||||
st.warning(str(err))
|
||||
|
||||
33
ui_sections/overview.py
Normal file
33
ui_sections/overview.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
def render_overview(base: pd.DataFrame, region_sel: str, start_dt: pd.Timestamp, end_dt: pd.Timestamp,
|
||||
strat_filter: list[str]):
|
||||
fig_line = go.Figure()
|
||||
fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines'))
|
||||
fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
|
||||
st.plotly_chart(fig_line, use_container_width=True)
|
||||
|
||||
html = fig_line.to_html(full_html=True, include_plotlyjs='cdn')
|
||||
st.download_button("下载图表 HTML", data=html.encode('utf-8'),
|
||||
file_name="overview_series.html", mime="text/html")
|
||||
|
||||
st.dataframe(base, use_container_width=True)
|
||||
csv_bytes = base.to_csv(index=True).encode('utf-8-sig')
|
||||
st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv")
|
||||
|
||||
meta = {
|
||||
"region": region_sel,
|
||||
"date_range": [str(start_dt.date()), str(end_dt.date())],
|
||||
"strategy_filter": strat_filter,
|
||||
"rows": int(len(base)),
|
||||
"min_date": str(base.index.min().date()) if len(base) else None,
|
||||
"max_date": str(base.index.max().date()) if len(base) else None,
|
||||
}
|
||||
st.download_button("下载运行参数 JSON", data=json.dumps(meta, ensure_ascii=False, indent=2).encode('utf-8'),
|
||||
file_name="run_metadata.json", mime="application/json")
|
||||
|
||||
50
ui_sections/strategy_eval.py
Normal file
50
ui_sections/strategy_eval.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from services.strategy import generate_output_and_recommendations
|
||||
|
||||
|
||||
def render_strategy_eval(base: pd.DataFrame, all_strategy_types: list[str], region_sel: str):
|
||||
st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
|
||||
if not all_strategy_types:
|
||||
st.warning("数据中没有检测到策略。")
|
||||
return
|
||||
|
||||
with st.form(key="strategy_eval_form"):
|
||||
horizon_eval = st.slider("评估窗口(天)", 7, 60, 14, step=1)
|
||||
submit_strat_eval = st.form_submit_button("应用评估参数")
|
||||
|
||||
if not submit_strat_eval:
|
||||
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
|
||||
return
|
||||
|
||||
results, recommendation = generate_output_and_recommendations(
|
||||
base,
|
||||
all_strategy_types,
|
||||
region=region_sel if region_sel != '全市' else '全市',
|
||||
horizon=horizon_eval,
|
||||
)
|
||||
|
||||
if not results:
|
||||
st.warning("⚠️ 未能完成策略评估。请尝试缩短评估窗口或扩大日期范围。")
|
||||
return
|
||||
|
||||
st.subheader("各策略指标")
|
||||
df_res = pd.DataFrame(results).T
|
||||
st.dataframe(df_res, use_container_width=True)
|
||||
st.success(f"⭐ 推荐:{recommendation}")
|
||||
|
||||
st.download_button(
|
||||
"下载策略评估 CSV",
|
||||
data=df_res.to_csv().encode('utf-8-sig'),
|
||||
file_name="strategy_evaluation_results.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
if os.path.exists('recommendation.txt'):
|
||||
with open('recommendation.txt','r',encoding='utf-8') as f:
|
||||
st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt")
|
||||
|
||||
Reference in New Issue
Block a user