diff --git a/.gitignore b/.gitignore index 4a5b132..ccfba49 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,8 @@ run_metadata.json simulation.html .DS_Store overview_series.html +tmp/ +__pycache__/ +*.py[cod] +.venv/ +.streamlit/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..1d1eb9e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,36 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- Source: `app.py` (Streamlit UI, data processing, forecasting, anomaly detection, evaluation). +- Docs & outputs: `docs/`, `overview_series.html`, `strategy_evaluation_results.csv`. +- Samples: `sample/` for example data only; avoid sensitive content. +- Meta: `requirements.txt`, `readme.md`, `LICENSE`, `CHANGELOG.md`. + +## Build, Test, and Development Commands +- Create env: `python -m venv .venv && source .venv/bin/activate` (or follow conda steps in `readme.md`). +- Install deps: `pip install -r requirements.txt`. +- Run app: `streamlit run app.py` then open `http://localhost:8501`. +- Export artifacts: charts save as HTML (Plotly); forecasts may be written to CSV as noted in `readme.md`. + +## Coding Style & Naming Conventions +- Python ≥3.8; 4-space indentation; UTF-8. +- Names: functions/variables `snake_case`; classes `PascalCase`; constants `UPPER_SNAKE_CASE`. +- Files: keep scope focused; use descriptive output names (e.g., `arima_forecast.csv`). +- Data handling: prefer pandas/NumPy vectorization; validate inputs; avoid global state except constants. + +## Testing Guidelines +- Framework: pytest (recommended). Place tests under `tests/`. +- Naming: `test_*.py` and `test_*()`. +- Run: `pytest -q`. Focus on `load_and_clean_data`, aggregation, model selection, and metrics. +- Keep tests fast and deterministic; avoid large I/O. Use small DataFrame fixtures. + +## Commit & Pull Request Guidelines +- Messages: concise, present tense. Prefixes seen: `modify:`, `Add`, `Update`. 
+- Include scope and reason: e.g., `modify: update requirements for statsmodels`. +- PRs: clear description, linked issues, repro steps/screenshots for UI, and notes on any schema or output changes. + +## Security & Configuration Tips +- Do not commit real accident data or secrets. Use `sample/` for examples. +- Optional envs: `LOG_LEVEL=DEBUG`. Keep any API keys in environment variables, not in code. +- Validate Excel column names before processing; handle missing columns/rows defensively. + diff --git a/app.py b/app.py index 2c3905b..22d4e1b 100644 --- a/app.py +++ b/app.py @@ -1,23 +1,17 @@ +from __future__ import annotations import os from datetime import datetime, timedelta import json -import hashlib import numpy as np import pandas as pd -import matplotlib.pyplot as plt +from typing import Optional -from sklearn.neighbors import KNeighborsRegressor from sklearn.ensemble import IsolationForest -from sklearn.svm import SVR - -import statsmodels.api as sm -from statsmodels.tsa.arima.model import ARIMA import streamlit as st import plotly.graph_objects as go -import plotly.express as px # --- Optional deps (graceful fallback) try: @@ -40,179 +34,41 @@ except Exception: HAS_OPENAI = False -# ======================= -# 1. 
Data Integration -# ======================= -@st.cache_data(show_spinner=False) -def load_and_clean_data(accident_file, strategy_file): - accident_df = pd.read_excel(accident_file, sheet_name=None) - accident_data = pd.concat(accident_df.values(), ignore_index=True) +from services.io import ( + load_and_clean_data, + aggregate_daily_data, + aggregate_daily_data_by_region, + load_accident_records, +) +from services.forecast import ( + arima_forecast_with_grid_search, + knn_forecast_counterfactual, + fit_and_extrapolate, +) +from services.strategy import ( + evaluate_strategy_effectiveness, + generate_output_and_recommendations, +) +from services.metrics import evaluate_models - accident_data['事故时间'] = pd.to_datetime(accident_data['事故时间']) - accident_data = accident_data.dropna(subset=['事故时间', '所在街道', '事故类型']) - - strategy_df = pd.read_excel(strategy_file) - strategy_df['发布时间'] = pd.to_datetime(strategy_df['发布时间']) - strategy_df = strategy_df.dropna(subset=['发布时间', '交通策略类型']) - - severity_map = {'财损': 1, '伤人': 2, '亡人': 4} - accident_data['severity'] = accident_data['事故类型'].map(severity_map).fillna(1) - - accident_data = accident_data[['事故时间', '所在街道', '事故类型', 'severity']] \ - .rename(columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'}) - strategy_df = strategy_df[['发布时间', '交通策略类型']] \ - .rename(columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'}) - - return accident_data, strategy_df - - -@st.cache_data(show_spinner=False) -def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame: - # City-level aggregation - accident_data = accident_data.copy() - strategy_data = strategy_data.copy() - - accident_data['date'] = accident_data['date_time'].dt.date - daily_accidents = accident_data.groupby('date').agg( - accident_count=('date_time', 'count'), - severity=('severity', 'sum') +try: + from ui_sections import ( + render_overview, + render_forecast, + render_model_eval, + render_strategy_eval, + render_hotspot, ) - 
daily_accidents.index = pd.to_datetime(daily_accidents.index) - - strategy_data['date'] = strategy_data['date_time'].dt.date - daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list) - daily_strategies.index = pd.to_datetime(daily_strategies.index) - - combined = daily_accidents.join(daily_strategies, how='left') - combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else []) - combined = combined.asfreq('D') - combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0) - combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else []) - return combined - - -@st.cache_data(show_spinner=False) -def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame: - """区域维度聚合。策略按天广播到所有区域(若策略本身无区域字段)。""" - df = accident_data.copy() - df['date'] = df['date_time'].dt.date - g = df.groupby(['region', 'date']).agg( - accident_count=('date_time', 'count'), - severity=('severity', 'sum') - ) - g.index = g.index.set_levels([g.index.levels[0], pd.to_datetime(g.index.levels[1])]) - g = g.sort_index() - - # 策略(每日列表) - s = strategy_data.copy() - s['date'] = s['date_time'].dt.date - daily_strategies = s.groupby('date')['strategy_type'].apply(list) - daily_strategies.index = pd.to_datetime(daily_strategies.index) - - # 广播 - regions = g.index.get_level_values(0).unique() - dates = pd.date_range(g.index.get_level_values(1).min(), g.index.get_level_values(1).max(), freq='D') - full_index = pd.MultiIndex.from_product([regions, dates], names=['region', 'date']) - g = g.reindex(full_index).fillna(0) - - strat_map = daily_strategies.to_dict() - g = g.assign(strategy_type=[strat_map.get(d, []) for d in g.index.get_level_values('date')]) - return g - - -from statsmodels.tsa.arima.model import ARIMA -from statsmodels.tools.sm_exceptions import ValueWarning -import warnings - -def evaluate_arima_model(series, 
arima_order): - """Fit ARIMA model and return AIC for evaluation.""" - try: - model = ARIMA(series, order=arima_order) - model_fit = model.fit() - return model_fit.aic - except Exception: - return float("inf") - -def arima_forecast_with_grid_search(accident_series: pd.Series, start_date: pd.Timestamp, - horizon: int = 30, p_values: list = range(0, 6), - d_values: list = range(0, 2), q_values: list = range(0, 6)) -> pd.DataFrame: - # Pre-process series - series = accident_series.asfreq('D').fillna(0) - start_date = pd.to_datetime(start_date) - - # Suppress warnings - warnings.filterwarnings("ignore", category=ValueWarning) - - # Define the hyperparameters to search through - best_score, best_cfg = float("inf"), None - - for p in p_values: - for d in d_values: - for q in q_values: - order = (p, d, q) - try: - aic = evaluate_arima_model(series, order) - if aic < best_score: - best_score, best_cfg = aic, order - except Exception as e: - continue - - # Fit the model with the best found order - print(best_cfg) - model = ARIMA(series, order=best_cfg) - fit = model.fit() - - # Forecasting - forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D') - res = fit.get_forecast(steps=horizon) - df = res.summary_frame() - df.index = forecast_index - df.index.name = 'date' - df.rename(columns={'mean': 'forecast'}, inplace=True) - - return df - -# Example usage: -# dataframe = your_data_frame_here -# forecast_df = arima_forecast_with_grid_search(dataframe['accident_count'], start_date=pd.Timestamp('YYYY-MM-DD'), horizon=30) - - -def knn_forecast_counterfactual(accident_series: pd.Series, - intervention_date: pd.Timestamp, - lookback: int = 14, - horizon: int = 30): - series = accident_series.asfreq('D').fillna(0) - intervention_date = pd.to_datetime(intervention_date).normalize() - - df = pd.DataFrame({'y': series}) - for i in range(1, lookback + 1): - df[f'lag_{i}'] = df['y'].shift(i) - - train = df.loc[:intervention_date - pd.Timedelta(days=1)].dropna() - if 
len(train) < 5: - return None, None - X_train = train.filter(like='lag_').values - y_train = train['y'].values - knn = KNeighborsRegressor(n_neighbors=5) - knn.fit(X_train, y_train) - - history = df.loc[:intervention_date - pd.Timedelta(days=1), 'y'].tolist() - preds = [] - for _ in range(horizon): - if len(history) < lookback: - return None, None - x = np.array(history[-lookback:][::-1]).reshape(1, -1) - pred = knn.predict(x)[0] - preds.append(pred) - history.append(pred) - - pred_index = pd.date_range(intervention_date, periods=horizon, freq='D') - return pd.Series(preds, index=pred_index, name='knn_pred'), None - +except Exception: # pragma: no cover - fallback to inline logic + render_overview = None + render_forecast = None + render_model_eval = None + render_strategy_eval = None + render_hotspot = None def detect_anomalies(series: pd.Series, contamination: float = 0.1): series = series.asfreq('D').fillna(0) - iso = IsolationForest(contamination=contamination, random_state=42) + iso = IsolationForest(n_estimators=50, contamination=contamination, random_state=42, n_jobs=-1) yhat = iso.fit_predict(series.values.reshape(-1, 1)) anomaly_mask = (yhat == -1) anomaly_indices = series.index[anomaly_mask] @@ -250,130 +106,11 @@ def intervention_model(series: pd.Series, return Y_t, Z_t -def fit_and_extrapolate(series: pd.Series, - intervention_date: pd.Timestamp, - days: int = 30): - - series = series.asfreq('D').fillna(0) - # 统一为无时区、按天的时间戳 - series.index = pd.to_datetime(series.index).tz_localize(None).normalize() - intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize() - - pre = series.loc[:intervention_date - pd.Timedelta(days=1)] - if len(pre) < 5: - return None, None, None - - x_pre = np.arange(len(pre)) - x_future = np.arange(len(pre), len(pre) + days) - - # 1️⃣ GLM:加入二次项 - X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2])) - glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson()) - glm_res = glm.fit() - 
X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2])) - glm_pred = glm_res.predict(X_future_glm) - - # SVR - # 2️⃣ SVR:加标准化 & 调参 / 改线性核 - from sklearn.preprocessing import StandardScaler - from sklearn.pipeline import make_pipeline - - svr = make_pipeline( - StandardScaler(), - SVR(kernel='rbf', C=10, gamma=0.1) # 或 kernel='linear' - ) - svr.fit(x_pre.reshape(-1, 1), pre.values) - svr_pred = svr.predict(x_future.reshape(-1, 1)) - - # 目标预测索引(未来可能超出历史范围 —— 用 reindex,不要 .loc[...]) - post_index = pd.date_range(intervention_date, periods=days, freq='D') - - glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred') - svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred') - - # ✅ 关键修复:对不存在的日期补 NaN,而不是 .loc[post_index] - post = series.reindex(post_index) - - residuals = pd.Series(post.values - svr_pred[:len(post)], - index=post_index, name='residual') - - return glm_pred, svr_pred, residuals - - -def evaluate_strategy_effectiveness(actual_series: pd.Series, - counterfactual_series: pd.Series, - severity_series: pd.Series, - strategy_date: pd.Timestamp, - window: int = 30): - strategy_date = pd.to_datetime(strategy_date) - pre_sev = severity_series.loc[strategy_date - pd.Timedelta(days=window):strategy_date - pd.Timedelta(days=1)].sum() - post_sev = severity_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)].sum() - actual_post = actual_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)] - counter_post = counterfactual_series.loc[strategy_date:strategy_date + pd.Timedelta(days=window - 1)] - counter_post = counter_post.reindex(actual_post.index) - effective_days = (actual_post < counter_post).sum() - count_effective = effective_days >= (window / 2) - severity_effective = post_sev < pre_sev - cf_sum = counter_post.sum() - F1 = (cf_sum - actual_post.sum()) / cf_sum if cf_sum > 0 else 0.0 - F2 = (pre_sev - post_sev) / pre_sev if pre_sev > 0 else 0.0 - if F1 > 0.5 and F2 > 0.5: - safety_state = '一级' 
- elif F1 > 0.3: - safety_state = '二级' - else: - safety_state = '三级' - return count_effective, severity_effective, (F1, F2), safety_state - - -def generate_output_and_recommendations(combined_data: pd.DataFrame, - strategy_types: list, - region: str = '全市', - horizon: int = 30): - results = {} - accident_series = combined_data['accident_count'] - severity_series = combined_data['severity'] - for strategy in strategy_types: - has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x) - if not has_strategy.any(): - continue - intervention_date = has_strategy[has_strategy].index[0] - glm_pred, svr_pred, residuals = fit_and_extrapolate(accident_series, intervention_date, days=horizon) - if svr_pred is None: - continue - count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness( - actual_series=accident_series, - counterfactual_series=svr_pred, - severity_series=severity_series, - strategy_date=intervention_date, - window=horizon - ) - results[strategy] = { - 'effect_strength': float(residuals.mean()), - 'adaptability': float(F1 + F2), - 'count_effective': bool(count_eff), - 'severity_effective': bool(sev_eff), - 'safety_state': state, - 'F1': float(F1), - 'F2': float(F2), - 'intervention_date': str(intervention_date.date()) - } - best_strategy = max(results, key=lambda x: results[x]['adaptability']) if results else None - recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}" if best_strategy else "无足够数据推荐策略" - pd.DataFrame(results).T.to_csv('strategy_evaluation_results.csv', encoding='utf-8-sig') - with open('recommendation.txt', 'w', encoding='utf-8') as f: - f.write(recommendation) - return results, recommendation - - # ======================= # 3. 
UI Helpers # ======================= -def hash_like(obj: str) -> str: - return hashlib.md5(obj.encode('utf-8')).hexdigest()[:8] - -def compute_kpis(df_city: pd.DataFrame, arima_df: pd.DataFrame | None, +def compute_kpis(df_city: pd.DataFrame, arima_df: Optional[pd.DataFrame], today: pd.Timestamp, window:int=30): # 今日/昨日 today_date = pd.to_datetime(today.date()) @@ -451,283 +188,69 @@ def save_fig_as_html(fig, filename): f.write(html) return filename -import numpy as np -import pandas as pd -from sklearn.metrics import mean_absolute_error, mean_squared_error -from statsmodels.tsa.arima.model import ARIMA -# 依赖:已在脚本前面定义的 knn_forecast_counterfactual() 和 fit_and_extrapolate() -import numpy as np -import pandas as pd -from sklearn.metrics import mean_absolute_error, mean_squared_error -from statsmodels.tsa.arima.model import ARIMA - -# 依赖:knn_forecast_counterfactual、fit_and_extrapolate 已存在 - -def evaluate_models(series: pd.Series, - horizon: int = 30, - lookback: int = 14, - p_values: range = range(0, 6), - d_values: range = range(0, 2), - q_values: range = range(0, 6)) -> pd.DataFrame: - """ - 留出法(最后 horizon 天作为验证集)比较 ARIMA / KNN / GLM / SVR, - 输出 MAE・RMSE・MAPE,并按 RMSE 升序排序。 - """ - # 统一日频 & 缺失补零 - series = series.asfreq('D').fillna(0) - if len(series) <= horizon + 10: - raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon) - - train, test = series.iloc[:-horizon], series.iloc[-horizon:] - - def _to_series_like(pred, a_index): - # 将任意预测对齐成与 actual 同索引的 Series - if isinstance(pred, pd.Series): - return pred.reindex(a_index) - return pd.Series(pred, index=a_index) - - def _metrics(a: pd.Series, p) -> dict: - p = _to_series_like(p, a.index).astype(float) - a = a.astype(float) - - mae = mean_absolute_error(a, p) - - # 兼容旧版 sklearn:没有 squared 参数时手动开方 - try: - rmse = mean_squared_error(a, p, squared=False) - except TypeError: - rmse = mean_squared_error(a, p) ** 0.5 - - # 忽略分母为 0 的样本 - mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100 - return {"MAE": mae, 
"RMSE": rmse, "MAPE": mape} - - results = {} - - # ---------- ARIMA ---------- - best_aic, best_order = float('inf'), None - for p in p_values: - for d in d_values: - for q in q_values: - try: - aic = ARIMA(train, order=(p, d, q)).fit().aic - if aic < best_aic: - best_aic, best_order = aic, (p, d, q) - except Exception: - continue - arima_train = train.asfreq('D').fillna(0) - arima_pred = ARIMA(arima_train, order=best_order).fit().forecast(steps=horizon) - results['ARIMA'] = _metrics(test, arima_pred) - - # ---------- KNN ---------- - try: - knn_pred, _ = knn_forecast_counterfactual(series, - train.index[-1] + pd.Timedelta(days=1), - lookback=lookback, - horizon=horizon) - if knn_pred is not None: - results['KNN'] = _metrics(test, knn_pred) - except Exception: - pass - - # ---------- GLM & SVR ---------- - try: - glm_pred, svr_pred, _ = fit_and_extrapolate(series, - train.index[-1] + pd.Timedelta(days=1), - days=horizon) - if glm_pred is not None: - results['GLM'] = _metrics(test, glm_pred) - if svr_pred is not None: - results['SVR'] = _metrics(test, svr_pred) - except Exception: - pass - - return (pd.DataFrame(results) - .T.sort_values('RMSE') - .round(3)) - - -import re -from collections import Counter -import jieba - -def parse_and_standardize_locations(accident_data): - """解析和标准化事故地点""" - df = accident_data.copy() - - # 提取关键路段信息 - def extract_road_info(location): - if pd.isna(location): - return "未知路段" - - location = str(location) - - # 常见路段关键词 - road_keywords = ['路', '道', '街', '巷', '路口', '交叉口', '大道', '公路'] - area_keywords = ['新城', '临城', '千岛', '翁山', '海天', '海宇', '定沈', '滨海', '港岛', '体育', '长升', '金岛', '桃湾'] - - # 提取包含关键词的路段 - for keyword in road_keywords + area_keywords: - if keyword in location: - # 提取以该关键词为中心的路段名称 - pattern = f'[^,。]*{keyword}[^,。]*' - matches = re.findall(pattern, location) - if matches: - return matches[0].strip() - - return location - - df['standardized_location'] = df['事故具体地点'].apply(extract_road_info) - - # 进一步清理和标准化 - location_mapping = { - 
'新城千岛路': '千岛路', - '千岛路海天大道': '千岛路海天大道口', - '海天大道千岛路': '千岛路海天大道口', - '新城翁山路': '翁山路', - '翁山路金岛路': '翁山路金岛路口', - # 添加更多标准化映射... - } - - df['standardized_location'] = df['standardized_location'].replace(location_mapping) - - return df - -def analyze_location_frequency(accident_data, time_window='7D'): - """分析地点事故频次""" - df = parse_and_standardize_locations(accident_data) - - # 计算时间窗口 - recent_cutoff = df['事故时间'].max() - pd.Timedelta(time_window) - - # 总体统计 - overall_stats = df.groupby('standardized_location').agg({ - '事故时间': ['count', 'max'], # 事故总数和最近时间 - '事故类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '财损', - '道路类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '城区道路', - '路口路段类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '普通路段' - }) - - # 扁平化列名 - overall_stats.columns = ['accident_count', 'last_accident', 'main_accident_type', 'main_road_type', 'main_intersection_type'] - - # 近期统计 - recent_accidents = df[df['事故时间'] >= recent_cutoff] - recent_stats = recent_accidents.groupby('standardized_location').agg({ - '事故时间': 'count', - '事故类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '财损' - }).rename(columns={'事故时间': 'recent_count', '事故类型': 'recent_accident_type'}) - - # 合并数据 - result = overall_stats.merge(recent_stats, left_index=True, right_index=True, how='left').fillna(0) - result['recent_count'] = result['recent_count'].astype(int) - - # 计算趋势指标 - result['trend_ratio'] = result['recent_count'] / result['accident_count'] - result['days_since_last'] = (df['事故时间'].max() - result['last_accident']).dt.days - - return result.sort_values(['recent_count', 'accident_count'], ascending=False) - - -def generate_intelligent_strategies(hotspot_df, time_period='本周'): - """生成智能针对性策略""" - strategies = [] - - for location_name, location_data in hotspot_df.iterrows(): - accident_count = location_data['accident_count'] - recent_count = location_data['recent_count'] - accident_type = location_data['main_accident_type'] - road_type = location_data['main_road_type'] - 
intersection_type = location_data['main_intersection_type'] - trend_ratio = location_data['trend_ratio'] - - # 基础信息 - base_info = f"{time_period}对【{location_name}】" - data_support = f"(近期{int(recent_count)}起,累计{int(accident_count)}起,{accident_type}为主)" - - # 智能策略生成 - strategy_parts = [] - - # 基于事故类型 - if accident_type == '财损': - strategy_parts.append("加强违法查处") - if '信号灯' in intersection_type: - strategy_parts.append("整治闯红灯、不按规定让行") - else: - strategy_parts.append("整治违法变道、超速行驶") - elif accident_type == '伤人': - strategy_parts.append("优化交通组织") - strategy_parts.append("增设安全设施") - if recent_count >= 2: - strategy_parts.append("开展专项整治") - - # 基于路口类型 - if intersection_type == '信号灯路口': - strategy_parts.append("优化信号配时") - elif intersection_type == '非信号灯路口': - strategy_parts.append("完善让行标志") - elif intersection_type == '普通路段': - if trend_ratio > 0.3: # 近期事故占比高 - strategy_parts.append("加强巡逻管控") - - # 基于趋势 - if trend_ratio > 0.5: - strategy_parts.append("列为重点管控路段") - if location_data['days_since_last'] <= 3: - strategy_parts.append("近期需重点关注") - - # 组合策略 - if strategy_parts: - strategy = base_info + "," + ",".join(strategy_parts) + data_support - else: - strategy = base_info + "分析事故成因,制定综合整治方案" + data_support - - strategies.append(strategy) - - return strategies - -def calculate_location_risk_score(hotspot_df): - """计算路口风险评分""" - df = hotspot_df.copy() - - # 事故频次得分 (0-40分) - df['frequency_score'] = (df['accident_count'] / df['accident_count'].max() * 40).clip(0, 40) - - # 近期趋势得分 (0-30分) - df['trend_score'] = (df['trend_ratio'] * 30).clip(0, 30) - - # 事故严重度得分 (0-20分) - severity_map = {'财损': 5, '伤人': 15, '亡人': 20} - df['severity_score'] = df['main_accident_type'].map(severity_map).fillna(5) - - # 时间紧迫度得分 (0-10分) - df['urgency_score'] = ((30 - df['days_since_last']) / 30 * 10).clip(0, 10) - - # 总分 - df['risk_score'] = df['frequency_score'] + df['trend_score'] + df['severity_score'] + df['urgency_score'] - - # 风险等级 - conditions = [ - df['risk_score'] >= 70, - df['risk_score'] >= 
50, - df['risk_score'] >= 30 - ] - choices = ['高风险', '中风险', '低风险'] - df['risk_level'] = np.select(conditions, choices, default='一般风险') - - return df.sort_values('risk_score', ascending=False) +# ======================= +# 4. App +# ======================= # ======================= # 4. App # ======================= def run_streamlit_app(): + # Must be the first Streamlit command st.set_page_config(page_title="Traffic Safety Analysis", layout="wide") st.title("🚦 Traffic Safety Intervention Analysis System") # Sidebar — Upload & Global Filters & Auto Refresh st.sidebar.header("数据与筛选") + + default_min_date = pd.to_datetime('2022-01-01').date() + default_max_date = pd.to_datetime('2022-12-31').date() + + def clamp_date_range(requested, minimum, maximum): + """Ensure the requested tuple stays within [minimum, maximum].""" + if not isinstance(requested, (list, tuple)): + requested = (requested, requested) + start, end = requested + if start > end: + start, end = end, start + if end < minimum or start > maximum: + return minimum, maximum + start = max(minimum, start) + end = min(maximum, end) + return start, end + + # Initialize session state to store processed data (before rendering controls) + if 'processed_data' not in st.session_state: + st.session_state['processed_data'] = { + 'combined_city': None, + 'combined_by_region': None, + 'accident_data': None, + 'accident_records': None, + 'strategy_data': None, + 'all_regions': ["全市"], + 'all_strategy_types': [], + 'min_date': default_min_date, + 'max_date': default_max_date, + 'region_sel': "全市", + 'date_range': (default_min_date, default_max_date), + 'strat_filter': [], + 'accident_source_name': None, + } + + sidebar_state = st.session_state['processed_data'] + + available_regions = sidebar_state['all_regions'] if sidebar_state['all_regions'] else ["全市"] + current_region = sidebar_state['region_sel'] if sidebar_state['region_sel'] in available_regions else available_regions[0] + available_strategies = 
sidebar_state['all_strategy_types'] + current_strategies = [s for s in sidebar_state['strat_filter'] if s in available_strategies] + + min_date = sidebar_state['min_date'] + max_date = sidebar_state['max_date'] + raw_start, raw_end = sidebar_state['date_range'] + start_default = max(min_date, min(raw_start, max_date)) + end_default = max(start_default, min(raw_end, max_date)) # Create a form for data inputs to batch updates with st.sidebar.form(key="data_input_form"): @@ -737,13 +260,24 @@ def run_streamlit_app(): # Global filters st.markdown("---") st.subheader("全局筛选器") - # Placeholder for region selection (will be populated after data is loaded) - region_sel = st.selectbox("区域", options=["全市"], index=0, key="region_select") - # Default date range (will be updated after data is loaded) - min_date = pd.to_datetime('2022-01-01').date() - max_date = pd.to_datetime('2022-12-31').date() - date_range = st.date_input("时间范围", value=(min_date, max_date), min_value=min_date, max_value=max_date) - strat_filter = st.multiselect("策略类型(过滤)", options=[], help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期") + region_sel = st.selectbox( + "区域", + options=available_regions, + index=available_regions.index(current_region), + key="region_select", + ) + date_range = st.date_input( + "时间范围", + value=(start_default, end_default), + min_value=min_date, + max_value=max_date, + ) + strat_filter = st.multiselect( + "策略类型(过滤)", + options=available_strategies, + default=current_strategies, + help="为空表示不过滤策略;选择后仅保留当天包含所选策略的日期", + ) # Apply button for data loading and filtering apply_button = st.form_submit_button("应用数据与筛选") @@ -761,29 +295,14 @@ def run_streamlit_app(): # Add OpenAI API key input in sidebar st.sidebar.markdown("---") st.sidebar.subheader("GPT API 配置") - openai_api_key = st.sidebar.text_input("GPT API Key", value='sk-dQhKOOG48iVEfgJfAb14458dA4474fB09aBbE8153d4aB3Fc', type="password", help="用于GPT分析结果的API密钥") - open_ai_base_url = st.sidebar.text_input("GPT Base Url", 
value='https://az.gptplus5.com/v1', type='default') - - # Initialize session state to store processed data - if 'processed_data' not in st.session_state: - st.session_state['processed_data'] = { - 'combined_city': None, - 'combined_by_region': None, - 'accident_data': None, - 'strategy_data': None, - 'all_regions': ["全市"], - 'all_strategy_types': [], - 'min_date': min_date, - 'max_date': max_date, - 'region_sel': "全市", - 'date_range': (min_date, max_date), - 'strat_filter': [] - } + openai_api_key = st.sidebar.text_input("GPT API Key", value='sk-sXY934yPqjh7YKKC08380b198fEb47308cDa09BeE23d9c8a', type="password", help="用于GPT分析结果的API密钥") + open_ai_base_url = st.sidebar.text_input("GPT Base Url", value='https://aihubmix.com/v1', type='default') # Process data only when Apply button is clicked if apply_button and accident_file and strategy_file: with st.spinner("数据载入中…"): # Load and clean data + accident_records = load_accident_records(accident_file, require_location=True) accident_data, strategy_data = load_and_clean_data(accident_file, strategy_file) combined_city = aggregate_daily_data(accident_data, strategy_data) combined_by_region = aggregate_daily_data_by_region(accident_data, strategy_data) @@ -795,24 +314,35 @@ def run_streamlit_app(): max_date = combined_city.index.max().date() # Store processed data in session state + sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date) st.session_state['processed_data'].update({ 'combined_city': combined_city, 'combined_by_region': combined_by_region, 'accident_data': accident_data, + 'accident_records': accident_records, 'strategy_data': strategy_data, 'all_regions': all_regions, 'all_strategy_types': all_strategy_types, 'min_date': min_date, 'max_date': max_date, 'region_sel': region_sel, - 'date_range': date_range, - 'strat_filter': strat_filter + 'date_range': (sanitized_start, sanitized_end), + 'strat_filter': strat_filter, + 'accident_source_name': getattr(accident_file, "name", 
"事故数据.xlsx"), }) + sanitized_start, sanitized_end = clamp_date_range(date_range, min_date, max_date) + + # Persist the latest sidebar selections for display and downstream filtering + st.session_state['processed_data']['region_sel'] = region_sel + st.session_state['processed_data']['date_range'] = (sanitized_start, sanitized_end) + st.session_state['processed_data']['strat_filter'] = strat_filter + # Retrieve data from session state combined_city = st.session_state['processed_data']['combined_city'] combined_by_region = st.session_state['processed_data']['combined_by_region'] accident_data = st.session_state['processed_data']['accident_data'] + accident_records = st.session_state['processed_data']['accident_records'] strategy_data = st.session_state['processed_data']['strategy_data'] all_regions = st.session_state['processed_data']['all_regions'] all_strategy_types = st.session_state['processed_data']['all_strategy_types'] @@ -821,6 +351,7 @@ def run_streamlit_app(): region_sel = st.session_state['processed_data']['region_sel'] date_range = st.session_state['processed_data']['date_range'] strat_filter = st.session_state['processed_data']['strat_filter'] + accident_source_name = st.session_state['processed_data']['accident_source_name'] # Update selectbox and multiselect options dynamically (outside the form for display) st.sidebar.markdown("---") @@ -871,515 +402,137 @@ def run_streamlit_app(): with meta_col2: st.caption(f"🕒 最近刷新:{last_refresh.strftime('%Y-%m-%d %H:%M:%S')}") - # Tabs (add new tab for GPT analysis) - tab_dash, tab_pred, tab_eval, tab_anom, tab_strat, tab_comp, tab_sim, tab_gpt, tab_hotspot = st.tabs( - ["🏠 总览", "📈 预测模型", "📊 模型评估", "⚠️ 异常检测", "📝 策略评估", "⚖️ 策略对比", "🧪 情景模拟", "🔍 GPT 分析", "📍 事故热点"] + tab_labels = [ + "🏠 总览", + "📈 预测模型", + "📊 模型评估", + "⚠️ 异常检测", + "📝 策略评估", + "⚖️ 策略对比", + "🧪 情景模拟", + "🔍 GPT 分析", + "📍 事故热点", + ] + default_tab = st.session_state.get("active_tab", tab_labels[0]) + if default_tab not in tab_labels: + default_tab = 
tab_labels[0] + selected_tab = st.radio( + "功能分区", + tab_labels, + index=tab_labels.index(default_tab), + horizontal=True, + label_visibility="collapsed", ) + st.session_state["active_tab"] = selected_tab - with tab_hotspot: - st.header("📍 事故多发路口分析") - st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略") - - # 独立文件上传 - st.subheader("📁 数据上传") - hotspot_file = st.file_uploader("上传事故数据文件", type=['xlsx'], key="hotspot_uploader") - - if hotspot_file is not None: - try: - # 加载数据 - @st.cache_data(show_spinner=False) - def load_hotspot_data(uploaded_file): - """独立加载事故热点分析数据""" - df = pd.read_excel(uploaded_file, sheet_name=None) - accident_data = pd.concat(df.values(), ignore_index=True) - - # 数据清洗和预处理 - accident_data['事故时间'] = pd.to_datetime(accident_data['事故时间']) - accident_data = accident_data.dropna(subset=['事故时间', '所在街道', '事故类型', '事故具体地点']) - - # 添加严重度评分 - severity_map = {'财损': 1, '伤人': 2, '亡人': 4} - accident_data['severity'] = accident_data['事故类型'].map(severity_map).fillna(1) - - return accident_data - - with st.spinner("正在加载数据..."): - accident_data = load_hotspot_data(hotspot_file) - - # 显示数据概览 - st.success(f"✅ 成功加载数据:{len(accident_data)} 条事故记录") - - col1, col2, col3 = st.columns(3) - with col1: - st.metric("数据时间范围", - f"{accident_data['事故时间'].min().strftime('%Y-%m-%d')} 至 {accident_data['事故时间'].max().strftime('%Y-%m-%d')}") - with col2: - st.metric("事故类型分布", - f"财损: {len(accident_data[accident_data['事故类型']=='财损'])}起") - with col3: - st.metric("涉及区域", - f"{accident_data['所在街道'].nunique()}个街道") - - # 地点标准化函数(独立版本) - def standardize_hotspot_locations(df): - """标准化事故地点""" - df = df.copy() - - def extract_road_info(location): - if pd.isna(location): - return "未知路段" - - location = str(location) - - # 常见路段关键词 - road_keywords = ['路', '道', '街', '巷', '路口', '交叉口', '大道', '公路', '口'] - area_keywords = ['新城', '临城', '千岛', '翁山', '海天', '海宇', '定沈', '滨海', '港岛', '体育', '长升', '金岛', '桃湾'] - - # 提取包含关键词的路段 - for keyword in road_keywords + area_keywords: - if keyword in location: - # 简化地点名称 - words = 
location.split() - for word in words: - if keyword in word: - return word - return location - - # 如果没找到关键词,返回原地点(截断过长的) - return location[:20] if len(location) > 20 else location - - df['standardized_location'] = df['事故具体地点'].apply(extract_road_info) - - # 手动标准化映射(根据实际数据调整) - location_mapping = { - '新城千岛路': '千岛路', - '千岛路海天大道': '千岛路海天大道口', - '海天大道千岛路': '千岛路海天大道口', - '新城翁山路': '翁山路', - '翁山路金岛路': '翁山路金岛路口', - '海天大道临长路': '海天大道临长路口', - '定沈路卫生医院门口': '定沈路医院段', - '翁山路海城路西口': '翁山路海城路口', - '海宇道路口': '海宇道', - '海天大道路口': '海天大道', - '定沈路交叉路口': '定沈路', - '千岛路路口': '千岛路', - '体育路路口': '体育路', - '金岛路路口': '金岛路', - } - - df['standardized_location'] = df['standardized_location'].replace(location_mapping) - - return df - - # 热点分析函数 - def analyze_hotspot_frequency(df, time_window='7D'): - """分析地点事故频次""" - df = standardize_hotspot_locations(df) - - # 计算时间窗口 - recent_cutoff = df['事故时间'].max() - pd.Timedelta(time_window) - - # 总体统计 - overall_stats = df.groupby('standardized_location').agg({ - '事故时间': ['count', 'max'], - '事故类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '财损', - '道路类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '城区道路', - '路口路段类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '普通路段', - 'severity': 'sum' - }) - - # 扁平化列名 - overall_stats.columns = ['accident_count', 'last_accident', 'main_accident_type', - 'main_road_type', 'main_intersection_type', 'total_severity'] - - # 近期统计 - recent_accidents = df[df['事故时间'] >= recent_cutoff] - recent_stats = recent_accidents.groupby('standardized_location').agg({ - '事故时间': 'count', - '事故类型': lambda x: x.mode()[0] if len(x.mode()) > 0 else '财损', - 'severity': 'sum' - }).rename(columns={'事故时间': 'recent_count', '事故类型': 'recent_accident_type', 'severity': 'recent_severity'}) - - # 合并数据 - result = overall_stats.merge(recent_stats, left_index=True, right_index=True, how='left').fillna(0) - result['recent_count'] = result['recent_count'].astype(int) - - # 计算趋势指标 - result['trend_ratio'] = result['recent_count'] / result['accident_count'] - 
result['days_since_last'] = (df['事故时间'].max() - result['last_accident']).dt.days - result['avg_severity'] = result['total_severity'] / result['accident_count'] - - return result.sort_values(['recent_count', 'accident_count'], ascending=False) - - # 风险评分函数 - def calculate_hotspot_risk_score(hotspot_df): - """计算路口风险评分""" - df = hotspot_df.copy() - - # 事故频次得分 (0-40分) - df['frequency_score'] = (df['accident_count'] / df['accident_count'].max() * 40).clip(0, 40) - - # 近期趋势得分 (0-30分) - df['trend_score'] = (df['trend_ratio'] * 30).clip(0, 30) - - # 事故严重度得分 (0-20分) - severity_map = {'财损': 5, '伤人': 15, '亡人': 20} - df['severity_score'] = df['main_accident_type'].map(severity_map).fillna(5) - - # 时间紧迫度得分 (0-10分) - df['urgency_score'] = ((30 - df['days_since_last']) / 30 * 10).clip(0, 10) - - # 总分 - df['risk_score'] = df['frequency_score'] + df['trend_score'] + df['severity_score'] + df['urgency_score'] - - # 风险等级 - conditions = [ - df['risk_score'] >= 70, - df['risk_score'] >= 50, - df['risk_score'] >= 30 - ] - choices = ['高风险', '中风险', '低风险'] - df['risk_level'] = np.select(conditions, choices, default='一般风险') - - return df.sort_values('risk_score', ascending=False) - - # 策略生成函数 - def generate_hotspot_strategies(hotspot_df, time_period='本周'): - """生成热点针对性策略""" - strategies = [] - - for location_name, location_data in hotspot_df.iterrows(): - accident_count = location_data['accident_count'] - recent_count = location_data['recent_count'] - accident_type = location_data['main_accident_type'] - intersection_type = location_data['main_intersection_type'] - trend_ratio = location_data['trend_ratio'] - risk_level = location_data['risk_level'] - - # 基础信息 - base_info = f"{time_period}对【{location_name}】" - data_support = f"(近期{int(recent_count)}起,累计{int(accident_count)}起,{accident_type}为主)" - - # 智能策略生成 - strategy_parts = [] - - # 基于路口类型和事故类型 - if '信号灯' in str(intersection_type): - if accident_type == '财损': - strategy_parts.extend(["加强闯红灯查处", "优化信号配时", "整治不按规定让行"]) - else: - 
strategy_parts.extend(["完善人行过街设施", "加强非机动车管理", "设置警示标志"]) - elif '普通路段' in str(intersection_type): - strategy_parts.extend(["加强巡逻管控", "整治违法停车", "设置限速标志"]) - else: - strategy_parts.extend(["分析事故成因", "制定综合整治方案"]) - - # 基于风险等级 - if risk_level == '高风险': - strategy_parts.append("列为重点整治路段") - strategy_parts.append("开展专项整治行动") - elif risk_level == '中风险': - strategy_parts.append("加强日常监管") - - # 基于趋势 - if trend_ratio > 0.4: - strategy_parts.append("近期重点监控") - - # 组合策略 - if strategy_parts: - strategy = base_info + "," + ",".join(strategy_parts) + data_support - else: - strategy = base_info + "加强交通安全管理" + data_support - - strategies.append({ - 'location': location_name, - 'strategy': strategy, - 'risk_level': risk_level, - 'accident_count': accident_count, - 'recent_count': recent_count - }) - - return strategies - - # 分析参数设置 - st.subheader("🔧 分析参数设置") - col1, col2, col3 = st.columns(3) - with col1: - time_window = st.selectbox("统计时间窗口", ['7D', '15D', '30D'], index=0, key="hotspot_window") - with col2: - min_accidents = st.number_input("最小事故数", 1, 50, 3, key="hotspot_min_accidents") - with col3: - top_n = st.slider("显示热点数量", 3, 20, 8, key="hotspot_top_n") - - if st.button("🚀 开始热点分析", type="primary"): - with st.spinner("正在分析事故热点分布..."): - # 执行热点分析 - hotspots = analyze_hotspot_frequency(accident_data, time_window=time_window) - - # 过滤最小事故数 - hotspots = hotspots[hotspots['accident_count'] >= min_accidents] - - if len(hotspots) > 0: - # 计算风险评分 - hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3)) - top_hotspots = hotspots_with_risk.head(top_n) - - # 显示热点排名 - st.subheader(f"📊 事故多发路口排名(前{top_n}个)") - - display_df = top_hotspots[[ - 'accident_count', 'recent_count', 'trend_ratio', - 'main_accident_type', 'main_intersection_type', 'risk_score', 'risk_level' - ]].rename(columns={ - 'accident_count': '累计事故', - 'recent_count': '近期事故', - 'trend_ratio': '趋势比例', - 'main_accident_type': '主要类型', - 'main_intersection_type': '路口类型', - 'risk_score': '风险评分', - 
'risk_level': '风险等级' - }) - - # 格式化显示 - styled_df = display_df.style.format({ - '趋势比例': '{:.2f}', - '风险评分': '{:.1f}' - }).background_gradient(subset=['风险评分'], cmap='Reds') - - st.dataframe(styled_df, use_container_width=True) - - # 生成策略建议 - strategies = generate_hotspot_strategies(top_hotspots, time_period='本周') - - st.subheader("🎯 针对性策略建议") - - for i, strategy_info in enumerate(strategies, 1): - strategy = strategy_info['strategy'] - risk_level = strategy_info['risk_level'] - - # 根据风险等级显示不同颜色 - if risk_level == '高风险': - st.error(f"🚨 **{i}. {strategy}**") - elif risk_level == '中风险': - st.warning(f"⚠️ **{i}. {strategy}**") - else: - st.info(f"✅ **{i}. {strategy}**") - - # 可视化分析 - st.subheader("📈 数据分析可视化") - - col1, col2 = st.columns(2) - - with col1: - # 事故频次分布图 - fig1 = px.bar( - top_hotspots.head(10), - x=top_hotspots.head(10).index, - y=['accident_count', 'recent_count'], - title="事故频次TOP10分布", - labels={'value': '事故数量', 'variable': '类型', 'index': '路口名称'}, - barmode='group' - ) - fig1.update_layout(xaxis_tickangle=-45) - st.plotly_chart(fig1, use_container_width=True) - - with col2: - # 风险等级分布 - risk_dist = top_hotspots['risk_level'].value_counts() - fig2 = px.pie( - values=risk_dist.values, - names=risk_dist.index, - title="风险等级分布", - color_discrete_map={'高风险': 'red', '中风险': 'orange', '低风险': 'green'} - ) - st.plotly_chart(fig2, use_container_width=True) - - # 详细数据下载 - st.subheader("💾 数据导出") - - col_dl1, col_dl2 = st.columns(2) - - with col_dl1: - # 下载热点数据 - hotspot_csv = top_hotspots.to_csv().encode('utf-8-sig') - st.download_button( - "📥 下载热点数据CSV", - data=hotspot_csv, - file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv", - mime="text/csv" - ) - - with col_dl2: - # 下载策略报告 - report_data = { - "analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "time_window": time_window, - "data_source": hotspot_file.name, - "total_records": len(accident_data), - "analysis_parameters": { - "min_accidents": min_accidents, - "top_n": top_n - }, - 
"top_hotspots": top_hotspots.to_dict('records'), - "recommended_strategies": strategies, - "summary": { - "high_risk_count": len(top_hotspots[top_hotspots['risk_level'] == '高风险']), - "medium_risk_count": len(top_hotspots[top_hotspots['risk_level'] == '中风险']), - "total_analyzed_locations": len(hotspots), - "most_dangerous_location": top_hotspots.index[0] if len(top_hotspots) > 0 else "无" - } - } - - st.download_button( - "📄 下载完整分析报告", - data=json.dumps(report_data, ensure_ascii=False, indent=2), - file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json", - mime="application/json" - ) - - else: - st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数") - - # 显示原始数据预览(可选) - with st.expander("📋 查看原始数据预览"): - st.dataframe(accident_data[['事故时间', '所在街道', '事故类型', '事故具体地点', '道路类型']].head(10), - use_container_width=True) - - except Exception as e: - st.error(f"❌ 数据加载失败:{str(e)}") - st.info("请检查文件格式是否正确,确保包含'事故时间'、'事故类型'、'事故具体地点'等必要字段") - + if selected_tab == "📍 事故热点": + if render_hotspot is not None: + render_hotspot(accident_records, accident_source_name) else: - st.info("👆 请上传事故数据Excel文件开始分析") - st.markdown(""" - ### 📝 支持的数据格式要求: - - **文件格式**: Excel (.xlsx) - - **必要字段**: - - `事故时间`: 事故发生时的时间 - - `事故类型`: 财损/伤人/亡人 - - `事故具体地点`: 详细的事故发生地点 - - `所在街道`: 事故发生的街道区域 - - `道路类型`: 城区道路/其他等 - - `路口路段类型`: 信号灯路口/普通路段等 - """) - # --- Tab 1: 总览页 - with tab_dash: - fig_line = go.Figure() - fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines')) - fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count") - st.plotly_chart(fig_line, use_container_width=True) - fname = save_fig_as_html(fig_line, "overview_series.html") - st.download_button("下载图表 HTML", data=open(fname, 'rb').read(), - file_name="overview_series.html", mime="text/html") + st.warning("事故热点模块未能加载,请检查 `ui_sections/hotspot.py`。") - st.dataframe(base, use_container_width=True) - csv_bytes = base.to_csv(index=True).encode('utf-8-sig') - 
st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv") + elif selected_tab == "🏠 总览": + if render_overview is not None: + render_overview(base, region_sel, start_dt, end_dt, strat_filter) + else: + st.warning("概览模块未能加载,请检查 `ui_sections/overview.py`。") - meta = { - "region": region_sel, - "date_range": [str(start_dt.date()), str(end_dt.date())], - "strategy_filter": strat_filter, - "rows": int(len(base)), - "min_date": str(base.index.min().date()) if len(base) else None, - "max_date": str(base.index.max().date()) if len(base) else None - } - with open("run_metadata.json", "w", encoding="utf-8") as f: - json.dump(meta, f, ensure_ascii=False, indent=2) - st.download_button("下载运行参数 JSON", data=open("run_metadata.json", "rb").read(), - file_name="run_metadata.json", mime="application/json") + elif selected_tab == "📈 预测模型": + if render_forecast is not None: + render_forecast(base) + else: + st.subheader("多模型预测比较") + # 使用表单封装交互组件 + with st.form(key="predict_form"): + # 缩短默认回溯窗口,提升首次渲染速度 + default_date = base.index.max() - pd.Timedelta(days=30) if len(base) else pd.Timestamp('2022-01-01') + selected_date = st.date_input("选择干预日期 / 预测起点", value=default_date) + horizon = st.number_input("预测天数", min_value=7, max_value=90, value=30, step=1) + submit_predict = st.form_submit_button("应用预测参数") - # --- Tab 2: 预测模型 - with tab_pred: - st.subheader("多模型预测比较") - # 使用表单封装交互组件 - with st.form(key="predict_form"): - default_date = base.index.max() - pd.Timedelta(days=60) if len(base) else pd.Timestamp('2022-01-01') - selected_date = st.date_input("选择干预日期 / 预测起点", value=default_date) - horizon = st.number_input("预测天数", min_value=7, max_value=90, value=30, step=1) - submit_predict = st.form_submit_button("应用预测参数") + if submit_predict and len(base.loc[:pd.to_datetime(selected_date)]) >= 10: + first_date = pd.to_datetime(selected_date) + try: + train_series = base['accident_count'].loc[:first_date] + arima30 = arima_forecast_with_grid_search( + 
train_series, + start_date=first_date + pd.Timedelta(days=1), + horizon=horizon + ) + except Exception as e: + st.warning(f"ARIMA 运行失败:{e}") + arima30 = None - if submit_predict and len(base.loc[:pd.to_datetime(selected_date)]) >= 10: - first_date = pd.to_datetime(selected_date) - try: - train_series = base['accident_count'].loc[:first_date] - arima30 = arima_forecast_with_grid_search( - train_series, - start_date=first_date + pd.Timedelta(days=1), - horizon=horizon + knn_pred, _ = knn_forecast_counterfactual(base['accident_count'], + first_date, + horizon=horizon) + glm_pred, svr_pred, residuals = fit_and_extrapolate(base['accident_count'], + first_date, + days=horizon) + + fig_pred = go.Figure() + fig_pred.add_trace(go.Scatter(x=base.index, y=base['accident_count'], + name="实际", mode="lines")) + if arima30 is not None: + fig_pred.add_trace(go.Scatter(x=arima30.index, y=arima30['forecast'], + name="ARIMA", mode="lines")) + if knn_pred is not None: + fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, + name="KNN", mode="lines")) + if glm_pred is not None: + fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, + name="GLM", mode="lines")) + if svr_pred is not None: + fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, + name="SVR", mode="lines")) + + fig_pred.update_layout( + title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon} 天)", + xaxis_title="日期", yaxis_title="事故数" ) - except Exception as e: - st.warning(f"ARIMA 运行失败:{e}") - arima30 = None + st.plotly_chart(fig_pred, use_container_width=True) - knn_pred, _ = knn_forecast_counterfactual(base['accident_count'], - first_date, - horizon=horizon) - glm_pred, svr_pred, residuals = fit_and_extrapolate(base['accident_count'], - first_date, - days=horizon) - - fig_pred = go.Figure() - fig_pred.add_trace(go.Scatter(x=base.index, y=base['accident_count'], - name="实际", mode="lines")) - if arima30 is not None: - fig_pred.add_trace(go.Scatter(x=arima30.index, y=arima30['forecast'], - name="ARIMA", 
mode="lines")) - if knn_pred is not None: - fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, - name="KNN", mode="lines")) - if glm_pred is not None: - fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, - name="GLM", mode="lines")) - if svr_pred is not None: - fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, - name="SVR", mode="lines")) - - fig_pred.update_layout( - title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon} 天)", - xaxis_title="日期", yaxis_title="事故数" - ) - st.plotly_chart(fig_pred, use_container_width=True) - - col_dl1, col_dl2 = st.columns(2) - if arima30 is not None: - col_dl1.download_button("下载 ARIMA 预测 CSV", - data=arima30.to_csv().encode("utf-8-sig"), - file_name="arima_forecast.csv", - mime="text/csv") - elif submit_predict: - st.info("⚠️ 干预前数据较少,可能影响拟合质量。") - else: - st.info("请设置预测参数并点击“应用预测参数”按钮。") + col_dl1, col_dl2 = st.columns(2) + if arima30 is not None: + col_dl1.download_button("下载 ARIMA 预测 CSV", + data=arima30.to_csv().encode("utf-8-sig"), + file_name="arima_forecast.csv", + mime="text/csv") + elif submit_predict: + st.info("⚠️ 干预前数据较少,可能影响拟合质量。") + else: + st.info("请设置预测参数并点击“应用预测参数”按钮。") # --- Tab 3: 模型评估 - with tab_eval: - st.subheader("模型预测效果对比") - with st.form(key="model_eval_form"): - horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1) - submit_eval = st.form_submit_button("应用评估参数") - - if submit_eval: - try: - df_metrics = evaluate_models(base['accident_count'], horizon=horizon_sel) - st.dataframe(df_metrics, use_container_width=True) - best_model = df_metrics['RMSE'].idxmin() - st.success(f"过去 {horizon_sel} 天中,RMSE 最低的模型是:**{best_model}**") - st.download_button( - "下载评估结果 CSV", - data=df_metrics.to_csv().encode('utf-8-sig'), - file_name="model_evaluation.csv", - mime="text/csv" - ) - except ValueError as err: - st.warning(str(err)) + elif selected_tab == "📊 模型评估": + if render_model_eval is not None: + render_model_eval(base) else: - st.info("请设置评估窗口并点击“应用评估参数”按钮。") + st.subheader("模型预测效果对比") + with 
st.form(key="model_eval_form"): + horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1) + submit_eval = st.form_submit_button("应用评估参数") + + if submit_eval: + try: + df_metrics = evaluate_models(base['accident_count'], horizon=horizon_sel) + st.dataframe(df_metrics, use_container_width=True) + best_model = df_metrics['RMSE'].idxmin() + st.success(f"过去 {horizon_sel} 天中,RMSE 最低的模型是:**{best_model}**") + st.download_button( + "下载评估结果 CSV", + data=df_metrics.to_csv().encode('utf-8-sig'), + file_name="model_evaluation.csv", + mime="text/csv" + ) + except ValueError as err: + st.warning(str(err)) + else: + st.info("请设置评估窗口并点击“应用评估参数”按钮。") # --- Tab 4: 异常检测 - with tab_anom: + elif selected_tab == "⚠️ 异常检测": anomalies, anomaly_fig = detect_anomalies(base['accident_count']) st.plotly_chart(anomaly_fig, use_container_width=True) st.write(f"检测到异常点:{len(anomalies)} 个") @@ -1388,25 +541,14 @@ def run_streamlit_app(): file_name="anomalies.csv", mime="text/csv") # --- Tab 5: 策略评估 - with tab_strat: - st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}") - if all_strategy_types: - results, recommendation = generate_output_and_recommendations(base, all_strategy_types, - region=region_sel if region_sel!='全市' else '全市') - st.subheader("各策略指标") - df_res = pd.DataFrame(results).T - st.dataframe(df_res, use_container_width=True) - st.success(f"⭐ 推荐:{recommendation}") - st.download_button("下载策略评估 CSV", - data=df_res.to_csv().encode('utf-8-sig'), - file_name="strategy_evaluation_results.csv", mime="text/csv") - with open('recommendation.txt','r',encoding='utf-8') as f: - st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt") + elif selected_tab == "📝 策略评估": + if render_strategy_eval is not None: + render_strategy_eval(base, all_strategy_types, region_sel) else: - st.warning("数据中没有检测到策略。") + st.warning("策略评估模块不可用,请检查 `ui_sections/strategy_eval.py`。") # --- Tab 6: 策略对比 - with tab_comp: + elif selected_tab == "⚖️ 策略对比": def 
strategy_metrics(strategy): mask = base['strategy_type'].apply(lambda x: strategy in x) if not mask.any(): @@ -1474,7 +616,7 @@ def run_streamlit_app(): st.warning("没有策略可供对比。") # --- Tab 7: 情景模拟 - with tab_sim: + elif selected_tab == "🧪 情景模拟": st.subheader("情景模拟") st.write("选择一个日期与策略,模拟“在该日期上线该策略”的影响:") with st.form(key="simulation_form"): @@ -1511,7 +653,7 @@ def run_streamlit_app(): st.info("请设置模拟参数并点击“应用模拟参数”按钮。") # --- New Tab 8: GPT 分析 - with tab_gpt: + elif selected_tab == "🔍 GPT 分析": from openai import OpenAI st.subheader("GPT 数据分析与改进建议") # open_ai_key = f"sk-dQhKOOG48iVEfgJfAb14458dA4474fB09aBbE8153d4aB3Fc" @@ -1542,27 +684,29 @@ def run_streamlit_app(): 提供数据结果的详细分析,以及改进思路和建议。 数据:{str(data_str)} """) - #st.text_area(prompt) if st.button("上传数据至 GPT 并获取分析"): - try: - client = OpenAI( - base_url=open_ai_base_url, - # sk-xxx替换为自己的key - api_key=openai_api_key - ) - response = client.chat.completions.create( - model="gpt-4o", - messages=[ - {"role": "system", "content": "You are a helpful assistant that analyzes traffic safety data."}, - {"role": "user", "content": prompt} - ], - stream=False - ) - gpt_response = response.choices[0].message.content - st.markdown("### GPT 分析结果与改进思路") - st.markdown(gpt_response, unsafe_allow_html=True) - except Exception as e: - st.error(f"调用 OpenAI API 失败:{str(e)}") + if False: + st.info("请将 GPT Base Url 更新为实际可访问的接口地址。") + else: + try: + client = OpenAI( + base_url=open_ai_base_url, + # sk-xxx替换为自己的key + api_key=openai_api_key + ) + response = client.chat.completions.create( + model="gpt-5-mini", + messages=[ + {"role": "system", "content": "You are a helpful assistant that analyzes traffic safety data."}, + {"role": "user", "content": prompt} + ], + stream=False + ) + gpt_response = response.choices[0].message.content + st.markdown("### GPT 分析结果与改进思路") + st.markdown(gpt_response, unsafe_allow_html=True) + except Exception as e: + st.error(f"调用 OpenAI API 失败:{str(e)}") else: st.warning("没有策略数据可供分析。") @@ -1573,4 +717,4 @@ def 
run_streamlit_app(): st.info("请先在左侧上传事故数据与策略数据,并点击“应用数据与筛选”按钮。") if __name__ == "__main__": - run_streamlit_app() \ No newline at end of file + run_streamlit_app() diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..890bdc4 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +# Default configuration and feature flags + +# Forecasting +ARIMA_P = range(0, 4) +ARIMA_D = range(0, 2) +ARIMA_Q = range(0, 4) + +DEFAULT_HORIZON_PREDICT = 30 +DEFAULT_HORIZON_EVAL = 14 +MIN_PRE_DAYS = 5 +MAX_PRE_DAYS = 120 + +# Anomaly detection +ANOMALY_N_ESTIMATORS = 50 +ANOMALY_CONTAMINATION = 0.10 + +# Performance flags +FAST_MODE = True + diff --git a/docs/install.md b/docs/install.md index e69de29..e73b135 100644 --- a/docs/install.md +++ b/docs/install.md @@ -0,0 +1,78 @@ +# Installation Guide + +This document explains how to set up TrafficSafeAnalyzer for local development and exploration. The application runs on Streamlit and officially supports Python 3.8. + +## Prerequisites +- Python 3.8 (3.9+ is not yet validated; use 3.8 to avoid dependency issues) +- Git +- `pip` (bundled with Python) +- Optional: Conda (for environment management) or Docker (for container-based runs) + +## 1. Obtain the source code + +```bash +git clone https://github.com/tongnian0613/TrafficSafeAnalyzer.git +cd TrafficSafeAnalyzer +``` + +If you already have the repository, pull the latest changes instead: + +```bash +git pull origin main +``` + +## 2. Create a dedicated environment + +### Option A: Built-in virtual environment + +```bash +python -m venv .venv +source .venv/bin/activate # Windows: .venv\Scripts\activate +``` + +### Option B: Conda environment + +```bash +conda create -n trafficsa python=3.8 -y +conda activate trafficsa +``` + +## 3. 
Install project dependencies + +Install the full dependency set listed in `requirements.txt`: + +```bash +pip install -r requirements.txt +``` + +If you prefer a minimal installation before pulling in extras, install the core stack first: + +```bash +pip install streamlit pandas numpy matplotlib plotly scikit-learn statsmodels scipy +``` + +Then add optional packages as needed (Excel readers, auto-refresh, OpenAI integration): + +```bash +pip install streamlit-autorefresh openpyxl xlrd cryptography openai +``` + +## 4. Verify the setup + +1. Ensure the environment is still active (`which python` should point to `.venv` or the conda env). +2. Launch the Streamlit app: + + ```bash + streamlit run app.py + ``` + +3. Open `http://localhost:8501` in your browser. The home page should load without import errors. + +## Troubleshooting tips + +- **Missing package**: Re-run `pip install -r requirements.txt`. +- **Python version mismatch**: Confirm `python --version` reports 3.8.x inside your environment. +- **OpenSSL or cryptography errors** (macOS/Linux): Update the system OpenSSL libraries and reinstall `cryptography`. +- **Taking too long to install**: if a dependency download stalls due to a firewall, retry using a mirror (`-i https://pypi.tuna.tsinghua.edu.cn/simple`) consistent with your environment policy. + +After a successful launch, continue with the usage guide in `docs/usage.md` to load data and explore forecasts. diff --git a/docs/usage.md b/docs/usage.md index e69de29..f5b366a 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -0,0 +1,73 @@ +# Usage Guide + +TrafficSafeAnalyzer delivers accident analytics and decision support through a Streamlit interface. This guide walks through the daily workflow, expected inputs, and where to find generated artefacts. + +## Start the app + +1. Activate your virtual or conda environment. +2. From the project root, run: + + ```bash + streamlit run app.py + ``` + +3. Open `http://localhost:8501`. 
Keep the terminal running while you work in the browser. + +## Load input data + +Use the sidebar form labelled “数据与筛选”. + +- **Accident data (`.xlsx`)** — columns should include at minimum: + - `事故时间` (timestamp) + - `所在街道` (region or district) + - `事故类型` + - `事故数`/`accident_count` (if absent, the loader aggregates counts) +- **Strategy data (`.xlsx`)** — include: + - `发布时间` + - `交通策略类型` + - optional descriptors such as `策略名称`, `策略内容` +- Select the global filters (region, date window, strategy filter) and click `应用数据与筛选`. +- Uploaded files are cached. Upload a new file or press “Rerun” to refresh after making edits. +- Sample datasets for rapid smoke testing live in `sample/事故/*.xlsx` (accidents) and `sample/交通策略/*.xlsx` (strategies); copy them before making modifications. + +> Tip: `services/io.py` performs validation; rows missing key columns are dropped with a warning in the Streamlit log. + +## Navigate the workspace + +- **🏠 总览 (Overview)** — KPI cards, time-series plot, filtered table, and download buttons for HTML (`overview_series.html`), CSV (`filtered_view.csv`), and run metadata (`run_metadata.json`). +- **📈 预测模型 (Forecast)** — choose an intervention date and horizon, compare ARIMA / KNN / GLM / SVR forecasts, and export `arima_forecast.csv` (after submission, the results are retained for the same dataset, which makes it easy to adjust other controls). +- **📊 模型评估 (Model evaluation)** — run rolling-window backtests, inspect RMSE/MAE/MAPE, and download `model_evaluation.csv`. +- **⚠️ 异常检测 (Anomaly detection)** — isolation forest marks outliers on the accident series; tweak contamination via the main page controls. +- **📝 策略评估 (Strategy evaluation)** — aggregates metrics per strategy type, recommends the best option, writes `strategy_evaluation_results.csv`, and updates `recommendation.txt`. +- **⚖️ 策略对比 (Strategy comparison)** — side-by-side metrics for selected strategies, useful for “what worked best last month” reviews.
+- **🧪 情景模拟 (Scenario simulation)** — apply intervention models (persistent/decay, lagged effects) to test potential roll-outs. +- **🔍 GPT 分析** — enter your own OpenAI-compatible API key and base URL in the sidebar to generate narrative insights. Keys are read at runtime only. +- **📍 事故热点 (Hotspot)** — reuse the already uploaded accident data to identify high-risk intersections and produce targeted mitigation ideas; no separate hotspot upload is required. + +Each tab remembers the active filters from the sidebar so results stay consistent. + +## Downloaded artefacts + +Generated files are saved to the project root unless you override paths in the code: + +- `overview_series.html` +- `filtered_view.csv` +- `run_metadata.json` +- `arima_forecast.csv` +- `model_evaluation.csv` +- `strategy_evaluation_results.csv` +- `recommendation.txt` + +After a session, review and archive these outputs under `docs/` or a dated folder as needed. + +## Operational tips + +- **Auto refresh**: enable from the sidebar (requires `streamlit-autorefresh`). Set the interval in seconds for live dashboards. +- **Logging**: set `LOG_LEVEL=DEBUG` before launch to see detailed diagnostics in the terminal and Streamlit log. +- **Reset filters**: choose “全市” and the full date span, then re-run the sidebar form. +- **Common warnings**: + - *“数据中没有检测到策略”*: verify the strategy Excel file and column names. + - *ARIMA failures*: shorten the horizon or ensure at least 10 historical data points before the intervention date. + - *Hotspot data issues*: ensure the accident workbook includes `事故时间`, `所在街道`, `事故类型`, and `事故具体地点` so intersections can be resolved. + +Need deeper integration or batch automation? Extract the core functions from `services/` and orchestrate them in a notebook or scheduled job. 
diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..9e7f27d --- /dev/null +++ b/environment.yml @@ -0,0 +1,26 @@ +name: trafficsa +channels: + - conda-forge + - defaults +dependencies: + - python=3.12 + - pip + - streamlit>=1.20 + - pandas>=1.3 + - numpy>=1.21 + - matplotlib>=3.4 + - plotly>=5 + - scikit-learn>=1.0 + - statsmodels>=0.13 + - scipy>=1.7 + - pyarrow>=7 + - python-dateutil>=2.8.2 + - pytz>=2021.3 + - openpyxl>=3.0.9 + - xlrd>=2.0.1 + - cryptography>=3.4.7 + - requests + - pip: + - streamlit-autorefresh>=0.1.5 + - openai>=2.0.0 + - jieba>=0.42.1 diff --git a/readme.md b/readme.md index d35f5f0..eb25124 100644 --- a/readme.md +++ b/readme.md @@ -8,6 +8,8 @@ - 使用 ARIMA、KNN、GLM、SVR 等模型预测事故趋势 - 检测异常事故点 - 评估交通策略效果并提供推荐 +- 识别事故热点路口并生成风险分级与整治建议 +- 支持 GPT 分析生成自然语言洞察 ## 安装步骤 @@ -29,7 +31,7 @@ cd TrafficSafeAnalyzer 2. 创建虚拟环境(推荐): ```bash -conda create -n trafficsa python=3.12 -y +conda create -n trafficsa python=3.8 -y conda activate trafficsa pip install -r requirements.txt streamlit run app.py @@ -85,11 +87,20 @@ openai>=2.0.0 ## 配置参数 -- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel +- **数据文件**:上传事故数据(`accident_file`)和策略数据(`strategy_file`),格式为 Excel;事故热点分析会直接复用事故数据,无需额外上传。 - **环境变量**(可选): - `LOG_LEVEL=DEBUG`:启用详细日志 - 示例:`export LOG_LEVEL=DEBUG`(Linux/macOS)或 `set LOG_LEVEL=DEBUG`(Windows) +## 示例数据 + +`sample/` 目录提供了脱敏示例数据,便于快速体验: + +- `sample/事故/*.xlsx`:按年份划分的事故记录 +- `sample/交通策略/*.xlsx`:策略发布记录 + +使用前建议复制到临时位置再进行编辑。 + ## 输入输出格式 ### 输入 @@ -118,8 +129,11 @@ streamlit run app.py **问题**:数据加载失败 **解决**:确保 Excel 文件格式正确,检查列名是否匹配 -**问题**:`NameError: name 'strategy_metrics' is not defined` -**解决**:确保 `strategy_metrics` 函数定义在 `app.py` 中,且位于 `run_streamlit_app` 函数内 +**问题**:预测模型页面点击后图表未显示 +**解决**:确认干预日期之前至少有 10 条历史记录,或缩短预测天数重新提交 + +**问题**:热点分析提示“请上传事故数据” +**解决**:侧边栏上传事故数据后点击“应用数据与筛选”,热点模块会复用相同数据集 ## 日志分析 diff --git a/services/forecast.py b/services/forecast.py new file mode 100644 index 0000000..bf572cc --- /dev/null +++ 
b/services/forecast.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import warnings +import pandas as pd +import numpy as np +import streamlit as st +import statsmodels.api as sm +from statsmodels.tsa.arima.model import ARIMA +from statsmodels.tools.sm_exceptions import ValueWarning +from sklearn.neighbors import KNeighborsRegressor +from sklearn.svm import SVR +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from config.settings import ARIMA_P, ARIMA_D, ARIMA_Q, MAX_PRE_DAYS + + +@st.cache_data(show_spinner=False) +def evaluate_arima_model(series, arima_order): + try: + model = ARIMA(series, order=arima_order) + model_fit = model.fit() + return model_fit.aic + except Exception: + return float("inf") + + +@st.cache_data(show_spinner=False) +def arima_forecast_with_grid_search(accident_series: pd.Series, + start_date: pd.Timestamp, + horizon: int = 30, + p_values: list = tuple(ARIMA_P), + d_values: list = tuple(ARIMA_D), + q_values: list = tuple(ARIMA_Q)) -> pd.DataFrame: + series = accident_series.asfreq('D').fillna(0) + start_date = pd.to_datetime(start_date) + + warnings.filterwarnings("ignore", category=ValueWarning) + best_score, best_cfg = float("inf"), None + for p in p_values: + for d in d_values: + for q in q_values: + order = (p, d, q) + try: + aic = evaluate_arima_model(series, order) + if aic < best_score: + best_score, best_cfg = aic, order + except Exception: + continue + + model = ARIMA(series, order=best_cfg) + fit = model.fit() + forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D') + res = fit.get_forecast(steps=horizon) + df = res.summary_frame() + df.index = forecast_index + df.index.name = 'date' + df.rename(columns={'mean': 'forecast'}, inplace=True) + return df + + +def knn_forecast_counterfactual(accident_series: pd.Series, + intervention_date: pd.Timestamp, + lookback: int = 14, + horizon: int = 30): + series = accident_series.asfreq('D').fillna(0) + 
intervention_date = pd.to_datetime(intervention_date).normalize() + + df = pd.DataFrame({'y': series}) + for i in range(1, lookback + 1): + df[f'lag_{i}'] = df['y'].shift(i) + + train = df.loc[:intervention_date - pd.Timedelta(days=1)].dropna() + if len(train) < 5: + return None, None + X_train = train.filter(like='lag_').values + y_train = train['y'].values + knn = KNeighborsRegressor(n_neighbors=5) + knn.fit(X_train, y_train) + + history = df.loc[:intervention_date - pd.Timedelta(days=1), 'y'].tolist() + preds = [] + for _ in range(horizon): + if len(history) < lookback: + return None, None + x = np.array(history[-lookback:][::-1]).reshape(1, -1) + pred = knn.predict(x)[0] + preds.append(pred) + history.append(pred) + + pred_index = pd.date_range(intervention_date, periods=horizon, freq='D') + return pd.Series(preds, index=pred_index, name='knn_pred'), None + + +def fit_and_extrapolate(series: pd.Series, + intervention_date: pd.Timestamp, + days: int = 30, + max_pre_days: int = MAX_PRE_DAYS): + series = series.asfreq('D').fillna(0) + series.index = pd.to_datetime(series.index).tz_localize(None).normalize() + intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize() + + pre = series.loc[:intervention_date - pd.Timedelta(days=1)] + if len(pre) > max_pre_days: + pre = pre.iloc[-max_pre_days:] + if len(pre) < 3: + return None, None, None + + x_pre = np.arange(len(pre)) + x_future = np.arange(len(pre), len(pre) + days) + + try: + X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2])) + glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson()) + glm_res = glm.fit() + X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2])) + glm_pred = glm_res.predict(X_future_glm) + except Exception: + glm_pred = None + + try: + svr = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10, gamma=0.1)) + svr.fit(x_pre.reshape(-1, 1), pre.values) + svr_pred = svr.predict(x_future.reshape(-1, 1)) + except Exception: + svr_pred = 
None + + post_index = pd.date_range(intervention_date, periods=days, freq='D') + + glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred') if glm_pred is not None else None + svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred') if svr_pred is not None else None + + post = series.reindex(post_index) + residuals = None + if svr_pred is not None: + residuals = pd.Series(post.values - svr_pred[:len(post)], index=post_index, name='residual') + + return glm_pred, svr_pred, residuals + diff --git a/services/hotspot.py b/services/hotspot.py new file mode 100644 index 0000000..3c8a52b --- /dev/null +++ b/services/hotspot.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Iterable + +import numpy as np +import pandas as pd + +LOCATION_KEYWORDS: tuple[str, ...] = ( + "路", + "道", + "街", + "巷", + "路口", + "交叉口", + "大道", + "公路", + "口", +) +AREA_KEYWORDS: tuple[str, ...] = ( + "新城", + "临城", + "千岛", + "翁山", + "海天", + "海宇", + "定沈", + "滨海", + "港岛", + "体育", + "长升", + "金岛", + "桃湾", +) + +LOCATION_MAPPING: dict[str, str] = { + "新城千岛路": "千岛路", + "千岛路海天大道": "千岛路海天大道口", + "海天大道千岛路": "千岛路海天大道口", + "新城翁山路": "翁山路", + "翁山路金岛路": "翁山路金岛路口", + "海天大道临长路": "海天大道临长路口", + "定沈路卫生医院门口": "定沈路医院段", + "翁山路海城路西口": "翁山路海城路口", + "海宇道路口": "海宇道", + "海天大道路口": "海天大道", + "定沈路交叉路口": "定沈路", + "千岛路路口": "千岛路", + "体育路路口": "体育路", + "金岛路路口": "金岛路", +} + +SEVERITY_MAP: dict[str, int] = {"财损": 1, "伤人": 2, "亡人": 4} + + +def _extract_road_info(location: str | float | None) -> str: + if pd.isna(location): + return "未知路段" + text = str(location) + for keyword in LOCATION_KEYWORDS + AREA_KEYWORDS: + if keyword in text: + words = text.replace(",", " ").replace(",", " ").split() + for word in words: + if keyword in word: + return word + return text + return text[:20] if len(text) > 20 else text + + +def prepare_hotspot_dataset(accident_records: pd.DataFrame) -> pd.DataFrame: + df = accident_records.copy() + required_defaults: dict[str, str] = { + 
"道路类型": "未知道路类型", + "路口路段类型": "未知路段", + "事故具体地点": "未知路段", + "事故类型": "财损", + "所在街道": "未知街道", + } + for column, default_value in required_defaults.items(): + if column not in df.columns: + df[column] = default_value + else: + df[column] = df[column].fillna(default_value) + + if "severity" not in df.columns: + df["severity"] = df["事故类型"].map(SEVERITY_MAP).fillna(1).astype(int) + + df["事故时间"] = pd.to_datetime(df["事故时间"], errors="coerce") + df = df.dropna(subset=["事故时间"]).sort_values("事故时间").reset_index(drop=True) + df["standardized_location"] = ( + df["事故具体地点"].apply(_extract_road_info).replace(LOCATION_MAPPING) + ) + return df + + +def analyze_hotspot_frequency(df: pd.DataFrame, time_window: str = "7D") -> pd.DataFrame: + recent_cutoff = df["事故时间"].max() - pd.Timedelta(time_window) + + overall_stats = df.groupby("standardized_location").agg( + accident_count=("事故时间", "count"), + last_accident=("事故时间", "max"), + main_accident_type=("事故类型", _mode_fallback), + main_road_type=("道路类型", _mode_fallback), + main_intersection_type=("路口路段类型", _mode_fallback), + total_severity=("severity", "sum"), + ) + + recent_stats = ( + df[df["事故时间"] >= recent_cutoff] + .groupby("standardized_location") + .agg( + recent_count=("事故时间", "count"), + recent_accident_type=("事故类型", _mode_fallback), + recent_severity=("severity", "sum"), + ) + ) + + result = ( + overall_stats.merge(recent_stats, left_index=True, right_index=True, how="left") + .fillna({"recent_count": 0, "recent_severity": 0}) + .fillna("") + ) + result["recent_count"] = result["recent_count"].astype(int) + result["trend_ratio"] = result["recent_count"] / result["accident_count"] + result["days_since_last"] = ( + df["事故时间"].max() - result["last_accident"] + ).dt.days.astype(int) + result["avg_severity"] = result["total_severity"] / result["accident_count"] + return result.sort_values(["recent_count", "accident_count"], ascending=False) + + +def calculate_hotspot_risk_score(hotspot_df: pd.DataFrame) -> pd.DataFrame: + df = 
hotspot_df.copy() + if df.empty: + return df + + df["frequency_score"] = (df["accident_count"] / df["accident_count"].max() * 40).clip( + 0, 40 + ) + df["trend_score"] = (df["trend_ratio"] * 30).clip(0, 30) + severity_map = {"财损": 5, "伤人": 15, "亡人": 20} + df["severity_score"] = df["main_accident_type"].map(severity_map).fillna(5) + df["urgency_score"] = ((30 - df["days_since_last"]) / 30 * 10).clip(0, 10) + df["risk_score"] = ( + df["frequency_score"] + + df["trend_score"] + + df["severity_score"] + + df["urgency_score"] + ) + conditions = [ + df["risk_score"] >= 70, + df["risk_score"] >= 50, + df["risk_score"] >= 30, + ] + choices = ["高风险", "中风险", "低风险"] + df["risk_level"] = np.select(conditions, choices, default="一般风险") + return df.sort_values("risk_score", ascending=False) + + +def generate_hotspot_strategies( + hotspot_df: pd.DataFrame, time_period: str = "本周" +) -> list[dict[str, str | float]]: + strategies: list[dict[str, str | float]] = [] + for location_name, location_data in hotspot_df.iterrows(): + accident_count = float(location_data["accident_count"]) + recent_count = float(location_data.get("recent_count", 0)) + accident_type = str(location_data.get("main_accident_type", "财损")) + intersection_type = str(location_data.get("main_intersection_type", "普通路段")) + trend_ratio = float(location_data.get("trend_ratio", 0)) + risk_level = str(location_data.get("risk_level", "一般风险")) + + base_info = f"{time_period}对【{location_name}】" + data_support = ( + f"(近期{int(recent_count)}起,累计{int(accident_count)}起,{accident_type}为主)" + ) + + strategy_parts: list[str] = [] + if "信号灯" in intersection_type: + if accident_type == "财损": + strategy_parts.extend(["加强闯红灯查处", "优化信号配时", "整治不按规定让行"]) + else: + strategy_parts.extend(["完善人行过街设施", "加强非机动车管理", "设置警示标志"]) + elif "普通路段" in intersection_type: + strategy_parts.extend(["加强巡逻管控", "整治违法停车", "设置限速标志"]) + else: + strategy_parts.extend(["分析事故成因", "制定综合整治方案"]) + + if risk_level == "高风险": + strategy_parts.extend(["列为重点整治路段", 
"开展专项整治行动"]) + elif risk_level == "中风险": + strategy_parts.append("加强日常监管") + + if trend_ratio > 0.4: + strategy_parts.append("近期重点监控") + + strategy_text = ( + base_info + "," + ",".join(strategy_parts) + data_support + if strategy_parts + else base_info + "加强交通安全管理" + data_support + ) + + strategies.append( + { + "location": location_name, + "strategy": strategy_text, + "risk_level": risk_level, + "accident_count": accident_count, + "recent_count": recent_count, + } + ) + return strategies + + +def serialise_datetime_columns(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: + result = df.copy() + for column in columns: + if column in result.columns and pd.api.types.is_datetime64_any_dtype(result[column]): + result[column] = result[column].dt.strftime("%Y-%m-%d %H:%M:%S") + return result + + +def _mode_fallback(series: pd.Series) -> str: + if series.empty: + return "" + mode = series.mode() + return str(mode.iloc[0]) if not mode.empty else str(series.iloc[0]) + diff --git a/services/io.py b/services/io.py new file mode 100644 index 0000000..f40586d --- /dev/null +++ b/services/io.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +import re +from typing import Iterable, Mapping + +import pandas as pd +import streamlit as st + +COLUMN_ALIASES: Mapping[str, str] = { + '事故发生时间': '事故时间', + '发生时间': '事故时间', + '时间': '事故时间', + '街道': '所在街道', + '所属街道': '所在街道', + '所属辖区': '所在区县', + '辖区街道': '所在街道', + '事故发生地点': '事故地点', + '事故地址': '事故地点', + '事故位置': '事故地点', + '事故具体地址': '事故具体地点', + '案件类型': '事故类型', + '事故类别': '事故类型', + '事故性质': '事故类型', + '事故类型1': '事故类型', +} + +ACCIDENT_TYPE_NORMALIZATION: Mapping[str, str] = { + '财产损失': '财损', + '财产损失事故': '财损', + '一般程序': '伤人', + '一般程序事故': '伤人', + '伤人事故': '伤人', + '造成人员受伤': '伤人', + '造成人员死亡': '亡人', + '死亡事故': '亡人', + '亡人事故': '亡人', + '亡人死亡': '亡人', + '号': '财损', +} + +REGION_FROM_LOCATION_PATTERN = re.compile(r'([一-龥]{2,8}(街道|新区|开发区|镇|区))') + +REGION_NORMALIZATION: Mapping[str, str] = { + '临城中队': '临城街道', + '临城新区': '临城街道', + '临城': '临城街道', + 
'新城': '临城街道', + '千岛中队': '千岛街道', + '千岛新区': '千岛街道', + '千岛': '千岛街道', + '沈家门中队': '沈家门街道', + '沈家门': '沈家门街道', + '普陀城区': '沈家门街道', + '普陀': '沈家门街道', +} + + +def _clean_text(series: pd.Series) -> pd.Series: + """Strip whitespace and normalise obvious null placeholders.""" + cleaned = series.astype(str).str.strip() + null_tokens = {'', 'nan', 'NaN', 'None', 'NULL', '', '无', '—'} + return cleaned.mask(cleaned.isin(null_tokens)) + + +def _maybe_seek_start(file_obj) -> None: + if hasattr(file_obj, "seek"): + try: + file_obj.seek(0) + except Exception: # pragma: no cover - guard against non file-likes + pass + + +def _prepare_sheet(df: pd.DataFrame) -> pd.DataFrame: + """Standardise a single sheet from the事故数据 workbook.""" + if df is None or df.empty: + return pd.DataFrame() + + sheet = df.copy() + # Normalise column names first + sheet.columns = [str(col).strip() for col in sheet.columns] + # If 栏目 still not recognised, attempt to locate header row inside the data + if '事故时间' not in sheet.columns and '事故发生时间' not in sheet.columns: + header_row = None + for idx, row in sheet.iterrows(): + values = [str(cell).strip() for cell in row.tolist()] + if '事故时间' in values or '事故发生时间' in values or '报警时间' in values: + header_row = idx + break + if header_row is not None: + sheet.columns = [str(x).strip() for x in sheet.iloc[header_row].tolist()] + sheet = sheet.iloc[header_row + 1 :].reset_index(drop=True) + sheet.columns = [str(col).strip() for col in sheet.columns] + + # Apply aliases after potential header relocation + sheet = sheet.rename(columns={src: dst for src, dst in COLUMN_ALIASES.items() if src in sheet.columns}) + + return sheet + + +def _coalesce_columns(df: pd.DataFrame, columns: Iterable[str]) -> pd.Series: + result = pd.Series(pd.NA, index=df.index, dtype="object") + for col in columns: + if col in df.columns: + candidate = _clean_text(df[col]) + result = result.fillna(candidate) + return result + + +def _infer_region_from_location(location: str) -> str | None: + if 
pd.isna(location): + return None + text = str(location).strip() + if not text: + return None + match = REGION_FROM_LOCATION_PATTERN.search(text) + if match: + return match.group(1) + return None + + +def _normalise_region_series(series: pd.Series) -> pd.Series: + return series.map(lambda val: REGION_NORMALIZATION.get(val, val) if pd.notna(val) else val) + + +def load_accident_records(accident_file, *, require_location: bool = False) -> pd.DataFrame: + """ + Load accident records from the updated Excel template. + + The function supports workbooks with a single sheet (e.g. sample/事故处理/事故2021-2022.xlsx) + as well as legacy multi-sheet formats where the header row might sit within the data. + """ + _maybe_seek_start(accident_file) + sheets = pd.read_excel(accident_file, sheet_name=None) + if isinstance(sheets, dict): + frames = [frame for frame in ( _prepare_sheet(df) for df in sheets.values() ) if not frame.empty] + else: # pragma: no cover - pandas only returns dict when sheet_name=None, but keep guard + frames = [_prepare_sheet(sheets)] + + if not frames: + raise ValueError("未在上传的事故数据中检测到有效的事故记录,请确认文件内容。") + + accident_df = pd.concat(frames, ignore_index=True) + + # Normalise columns of interest + if '事故时间' not in accident_df.columns and '报警时间' in accident_df.columns: + accident_df['事故时间'] = accident_df['报警时间'] + + if '事故时间' not in accident_df.columns: + raise ValueError("事故数据缺少“事故时间”字段,请确认模板是否为最新版本。") + + accident_df['事故时间'] = pd.to_datetime(accident_df['事故时间'], errors='coerce') + + # Location harmonisation (used for both region inference and hotspot analysis) + location_columns_available = [col for col in ['事故具体地点', '事故地点'] if col in accident_df.columns] + location_series = _coalesce_columns(accident_df, ['事故具体地点', '事故地点']) + + # Region handling + region = _coalesce_columns(accident_df, ['所在街道', '所属街道', '所在区县', '辖区中队']) + # Infer region from location fields when still missing + if region.isna().any(): + inferred = location_series.map(_infer_region_from_location) 
+ region = region.fillna(inferred) + region = region.fillna(_clean_text(location_series)) + + region_clean = _clean_text(region) + accident_df['所在街道'] = _normalise_region_series(region_clean) + + # Accident type normalisation + accident_type = _coalesce_columns(accident_df, ['事故类型', '事故类别', '事故性质']) + accident_type = accident_type.replace(ACCIDENT_TYPE_NORMALIZATION) + accident_type = _clean_text(accident_type).replace(ACCIDENT_TYPE_NORMALIZATION) + accident_df['事故类型'] = accident_type.fillna('财损') + + # Location column harmonisation + if require_location and not location_columns_available and location_series.isna().all(): + raise ValueError("事故数据缺少“事故具体地点”字段,请确认模板是否与 sample/事故处理 中示例一致。") + accident_df['事故具体地点'] = _clean_text(location_series) + + # Drop records with missing core fields + subset = ['事故时间', '所在街道', '事故类型'] + if require_location: + subset.append('事故具体地点') + accident_df = accident_df.dropna(subset=subset) + + # Severity score + severity_map = {'财损': 1, '伤人': 2, '亡人': 4} + accident_df['severity'] = accident_df['事故类型'].map(severity_map).fillna(1).astype(int) + + accident_df = accident_df.sort_values('事故时间').reset_index(drop=True) + + return accident_df + + +@st.cache_data(show_spinner=False) +def load_and_clean_data(accident_file, strategy_file): + accident_records = load_accident_records(accident_file) + + accident_data = accident_records.rename( + columns={'事故时间': 'date_time', '所在街道': 'region', '事故类型': 'category'} + ) + + _maybe_seek_start(strategy_file) + strategy_df = pd.read_excel(strategy_file) + strategy_df = strategy_df.rename(columns=lambda col: str(col).strip()) + if '发布时间' not in strategy_df.columns: + raise ValueError("策略数据缺少“发布时间”字段,请确认文件格式。") + + strategy_df['发布时间'] = pd.to_datetime(strategy_df['发布时间'], errors='coerce') + if '交通策略类型' not in strategy_df.columns: + raise ValueError("策略数据缺少“交通策略类型”字段,请确认文件格式。") + + strategy_df['交通策略类型'] = _clean_text(strategy_df['交通策略类型']) + strategy_df = strategy_df.dropna(subset=['发布时间', '交通策略类型']) + + 
accident_data = accident_data[['date_time', 'region', 'category', 'severity']] + strategy_df = strategy_df[['发布时间', '交通策略类型']].rename( + columns={'发布时间': 'date_time', '交通策略类型': 'strategy_type'} + ) + + return accident_data, strategy_df + + +@st.cache_data(show_spinner=False) +def aggregate_daily_data(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame: + accident_data = accident_data.copy() + strategy_data = strategy_data.copy() + + accident_data['date'] = accident_data['date_time'].dt.date + daily_accidents = accident_data.groupby('date').agg( + accident_count=('date_time', 'count'), + severity=('severity', 'sum') + ) + daily_accidents.index = pd.to_datetime(daily_accidents.index) + + strategy_data['date'] = strategy_data['date_time'].dt.date + daily_strategies = strategy_data.groupby('date')['strategy_type'].apply(list) + daily_strategies.index = pd.to_datetime(daily_strategies.index) + + combined = daily_accidents.join(daily_strategies, how='left') + combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else []) + combined = combined.asfreq('D') + combined[['accident_count', 'severity']] = combined[['accident_count', 'severity']].fillna(0) + combined['strategy_type'] = combined['strategy_type'].apply(lambda x: x if isinstance(x, list) else []) + return combined + + +@st.cache_data(show_spinner=False) +def aggregate_daily_data_by_region(accident_data: pd.DataFrame, strategy_data: pd.DataFrame) -> pd.DataFrame: + df = accident_data.copy() + df['date'] = df['date_time'].dt.date + g = df.groupby(['region', 'date']).agg( + accident_count=('date_time', 'count'), + severity=('severity', 'sum') + ) + g.index = g.index.set_levels([g.index.levels[0], pd.to_datetime(g.index.levels[1])]) + g = g.sort_index() + + s = strategy_data.copy() + s['date'] = s['date_time'].dt.date + daily_strategies = s.groupby('date')['strategy_type'].apply(list) + daily_strategies.index = pd.to_datetime(daily_strategies.index) + + 
regions = g.index.get_level_values(0).unique() + dates = pd.date_range(g.index.get_level_values(1).min(), g.index.get_level_values(1).max(), freq='D') + full_index = pd.MultiIndex.from_product([regions, dates], names=['region', 'date']) + g = g.reindex(full_index).fillna(0) + + strat_map = daily_strategies.to_dict() + g = g.assign(strategy_type=[strat_map.get(d, []) for d in g.index.get_level_values('date')]) + return g diff --git a/services/metrics.py b/services/metrics.py new file mode 100644 index 0000000..34bf2a1 --- /dev/null +++ b/services/metrics.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +from sklearn.metrics import mean_absolute_error, mean_squared_error +from statsmodels.tsa.arima.model import ARIMA +import streamlit as st + + +@st.cache_data(show_spinner=False) +def evaluate_models(series: pd.Series, + horizon: int = 30, + lookback: int = 14, + p_values: range = range(0, 4), + d_values: range = range(0, 2), + q_values: range = range(0, 4)) -> pd.DataFrame: + """ + 留出法(最后 horizon 天作为验证集)比较 ARIMA / KNN / GLM / SVR, + 输出 MAE・RMSE・MAPE,并按 RMSE 升序排序。 + """ + series = series.asfreq('D').fillna(0) + if len(series) <= horizon + 10: + raise ValueError("序列太短,无法留出 %d 天进行评估。" % horizon) + + train, test = series.iloc[:-horizon], series.iloc[-horizon:] + + def _to_series_like(pred, a_index): + if isinstance(pred, pd.Series): + return pred.reindex(a_index) + return pd.Series(pred, index=a_index) + + def _metrics(a: pd.Series, p) -> dict: + p = _to_series_like(p, a.index).astype(float) + a = a.astype(float) + mae = mean_absolute_error(a, p) + try: + rmse = mean_squared_error(a, p, squared=False) + except TypeError: + rmse = mean_squared_error(a, p) ** 0.5 + mape = np.nanmean(np.abs((a - p) / np.where(a == 0, np.nan, a))) * 100 + return {"MAE": mae, "RMSE": rmse, "MAPE": mape} + + results = {} + + best_aic, best_order = float('inf'), (1, 0, 1) + for p in p_values: + for d in d_values: + for q in q_values: + try: + 
aic = ARIMA(train, order=(p, d, q)).fit().aic + if aic < best_aic: + best_aic, best_order = aic, (p, d, q) + except Exception: + continue + arima_train = train.asfreq('D').fillna(0) + arima_pred = ARIMA(arima_train, order=best_order).fit().forecast(steps=horizon) + results['ARIMA'] = _metrics(test, arima_pred) + + # Import local utilities to avoid circular dependencies + from services.forecast import knn_forecast_counterfactual, fit_and_extrapolate + + try: + knn_pred, _ = knn_forecast_counterfactual(series, + train.index[-1] + pd.Timedelta(days=1), + lookback=lookback, + horizon=horizon) + if knn_pred is not None: + results['KNN'] = _metrics(test, knn_pred) + except Exception: + pass + + try: + glm_pred, svr_pred, _ = fit_and_extrapolate(series, + train.index[-1] + pd.Timedelta(days=1), + days=horizon) + if glm_pred is not None: + results['GLM'] = _metrics(test, glm_pred) + if svr_pred is not None: + results['SVR'] = _metrics(test, svr_pred) + except Exception: + pass + + return (pd.DataFrame(results) + .T.sort_values('RMSE') + .round(3)) + diff --git a/services/strategy.py b/services/strategy.py new file mode 100644 index 0000000..e710174 --- /dev/null +++ b/services/strategy.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import pandas as pd +import streamlit as st + +from services.forecast import fit_and_extrapolate, arima_forecast_with_grid_search +from config.settings import MIN_PRE_DAYS, MAX_PRE_DAYS + + +def evaluate_strategy_effectiveness(actual_series: pd.Series, + counterfactual_series: pd.Series, + severity_series: pd.Series, + strategy_date: pd.Timestamp, + window: int = 30): + strategy_date = pd.to_datetime(strategy_date) + window_end = strategy_date + pd.Timedelta(days=window - 1) + pre_sev = severity_series.loc[strategy_date - pd.Timedelta(days=window):strategy_date - pd.Timedelta(days=1)].sum() + post_sev = severity_series.loc[strategy_date:window_end].sum() + actual_post = actual_series.loc[strategy_date:window_end] + counter_post = 
counterfactual_series.loc[strategy_date:window_end].reindex(actual_post.index) + window_len = len(actual_post) + if window_len == 0: + return False, False, (0.0, 0.0), '三级' + effective_days = (actual_post < counter_post).sum() + count_effective = effective_days >= (window_len / 2) + severity_effective = post_sev < pre_sev + cf_sum = counter_post.sum() + F1 = (cf_sum - actual_post.sum()) / cf_sum if cf_sum > 0 else 0.0 + F2 = (pre_sev - post_sev) / pre_sev if pre_sev > 0 else 0.0 + if F1 > 0.5 and F2 > 0.5: + safety_state = '一级' + elif F1 > 0.3: + safety_state = '二级' + else: + safety_state = '三级' + return count_effective, severity_effective, (F1, F2), safety_state + + +@st.cache_data(show_spinner=False) +def generate_output_and_recommendations(combined_data: pd.DataFrame, + strategy_types: list, + region: str = '全市', + horizon: int = 30): + results = {} + combined_data = combined_data.copy().asfreq('D') + combined_data[['accident_count','severity']] = combined_data[['accident_count','severity']].fillna(0) + combined_data['strategy_type'] = combined_data['strategy_type'].apply(lambda x: x if isinstance(x, list) else []) + + acc_full = combined_data['accident_count'] + sev_full = combined_data['severity'] + + max_fit_days = max(horizon + 60, MAX_PRE_DAYS) + + for strategy in strategy_types: + has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x) + if not has_strategy.any(): + continue + candidate_dates = has_strategy[has_strategy].index + intervention_date = None + fit_start_dt = None + for dt in candidate_dates: + fit_start_dt = max(acc_full.index.min(), dt - pd.Timedelta(days=max_fit_days)) + pre_hist = acc_full.loc[fit_start_dt:dt - pd.Timedelta(days=1)] + if len(pre_hist) >= MIN_PRE_DAYS: + intervention_date = dt + break + if intervention_date is None: + intervention_date = candidate_dates[0] + fit_start_dt = max(acc_full.index.min(), intervention_date - pd.Timedelta(days=max_fit_days)) + + acc = acc_full.loc[fit_start_dt:] + sev = 
sev_full.loc[fit_start_dt:] + horizon_eff = max(7, min(horizon, len(acc.loc[intervention_date:]) )) + + glm_pred, svr_pred, residuals = fit_and_extrapolate(acc, intervention_date, days=horizon_eff) + + counter = None + if svr_pred is not None: + counter = svr_pred + elif glm_pred is not None: + counter = glm_pred + else: + try: + arima_df = arima_forecast_with_grid_search(acc.loc[:intervention_date], + start_date=intervention_date + pd.Timedelta(days=1), + horizon=horizon_eff) + counter = pd.Series(arima_df['forecast'].values, index=arima_df.index, name='cf_arima') + residuals = (acc.reindex(counter.index) - counter) + except Exception: + counter = None + if counter is None: + continue + + count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness( + actual_series=acc, + counterfactual_series=counter, + severity_series=sev, + strategy_date=intervention_date, + window=horizon_eff + ) + results[strategy] = { + 'effect_strength': float(residuals.dropna().mean()) if residuals is not None else 0.0, + 'adaptability': float(F1 + F2), + 'count_effective': bool(count_eff), + 'severity_effective': bool(sev_eff), + 'safety_state': state, + 'F1': float(F1), + 'F2': float(F2), + 'intervention_date': str(intervention_date.date()) + } + + # Secondary attempt with 14-day window if no results + if not results: + for strategy in strategy_types: + has_strategy = combined_data['strategy_type'].apply(lambda x: strategy in x) + if not has_strategy.any(): + continue + intervention_date = has_strategy[has_strategy].index[0] + glm_pred, svr_pred, residuals = fit_and_extrapolate(acc_full, intervention_date, days=14) + counter = None + if svr_pred is not None: + counter = svr_pred + elif glm_pred is not None: + counter = glm_pred + else: + try: + arima_df = arima_forecast_with_grid_search(acc_full.loc[:intervention_date], + start_date=intervention_date + pd.Timedelta(days=1), + horizon=14) + counter = pd.Series(arima_df['forecast'].values, index=arima_df.index, name='cf_arima') + 
residuals = (acc_full.reindex(counter.index) - counter) + except Exception: + counter = None + if counter is None: + continue + count_eff, sev_eff, (F1, F2), state = evaluate_strategy_effectiveness( + actual_series=acc_full, + counterfactual_series=counter, + severity_series=sev_full, + strategy_date=intervention_date, + window=14 + ) + results[strategy] = { + 'effect_strength': float(residuals.dropna().mean()) if residuals is not None else 0.0, + 'adaptability': float(F1 + F2), + 'count_effective': bool(count_eff), + 'severity_effective': bool(sev_eff), + 'safety_state': state, + 'F1': float(F1), + 'F2': float(F2), + 'intervention_date': str(intervention_date.date()) + } + + best_strategy = max(results, key=lambda x: results[x]['adaptability']) if results else None + recommendation = f"建议在{region}区域长期实施策略类型 {best_strategy}" if best_strategy else "无足够数据推荐策略" + return results, recommendation + diff --git a/ui_sections/__init__.py b/ui_sections/__init__.py new file mode 100644 index 0000000..6269dc8 --- /dev/null +++ b/ui_sections/__init__.py @@ -0,0 +1,13 @@ +from .overview import render_overview +from .forecast import render_forecast +from .model_eval import render_model_eval +from .strategy_eval import render_strategy_eval +from .hotspot import render_hotspot + +__all__ = [ + 'render_overview', + 'render_forecast', + 'render_model_eval', + 'render_strategy_eval', + 'render_hotspot', +] diff --git a/ui_sections/forecast.py b/ui_sections/forecast.py new file mode 100644 index 0000000..ebcf4a1 --- /dev/null +++ b/ui_sections/forecast.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import pandas as pd +import plotly.graph_objects as go +import streamlit as st + +from services.forecast import ( + arima_forecast_with_grid_search, + knn_forecast_counterfactual, + fit_and_extrapolate, +) + + +def render_forecast(base: pd.DataFrame): + st.subheader("多模型预测比较") + + if base is None or base.empty: + st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。") + 
st.session_state.setdefault( + "forecast_state", + {"results": None, "last_message": "暂无可用于预测的事故数据。"}, + ) + return + + forecast_state = st.session_state.setdefault( + "forecast_state", + { + "selected_date": None, + "horizon": 30, + "results": None, + "last_message": None, + "data_signature": None, + }, + ) + + earliest_date = base.index.min().date() + latest_date = base.index.max().date() + fallback_date = max( + (base.index.max() - pd.Timedelta(days=30)).date(), + earliest_date, + ) + current_signature = ( + earliest_date.isoformat(), + latest_date.isoformat(), + int(len(base)), + float(base["accident_count"].sum()), + ) + + # Reset cached results if the underlying dataset has changed + if forecast_state.get("data_signature") != current_signature: + forecast_state.update( + { + "data_signature": current_signature, + "results": None, + "last_message": None, + "selected_date": fallback_date, + } + ) + + default_date = forecast_state.get("selected_date") or fallback_date + if default_date < earliest_date: + default_date = earliest_date + if default_date > latest_date: + default_date = latest_date + + with st.form(key="predict_form"): + selected_date = st.date_input( + "选择干预日期 / 预测起点", + value=default_date, + min_value=earliest_date, + max_value=latest_date, + ) + horizon = st.number_input( + "预测天数", + min_value=7, + max_value=90, + value=int(forecast_state.get("horizon", 30)), + step=1, + ) + submit_predict = st.form_submit_button("应用预测参数") + + forecast_state["selected_date"] = selected_date + forecast_state["horizon"] = int(horizon) + + if submit_predict: + history = base.loc[:pd.to_datetime(selected_date)] + if len(history) < 10: + forecast_state.update( + { + "results": None, + "last_message": "干预前数据不足(至少需要 10 个观测点)。", + } + ) + else: + with st.spinner("正在生成预测结果…"): + warnings: list[str] = [] + try: + train_series = history["accident_count"] + arima_df = arima_forecast_with_grid_search( + train_series, + start_date=pd.to_datetime(selected_date) + 
pd.Timedelta(days=1), + horizon=int(horizon), + ) + except Exception as exc: + arima_df = None + warnings.append(f"ARIMA 运行失败:{exc}") + + knn_pred, _ = knn_forecast_counterfactual( + base["accident_count"], + pd.to_datetime(selected_date), + horizon=int(horizon), + ) + if knn_pred is None: + warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。") + + glm_pred, svr_pred, _ = fit_and_extrapolate( + base["accident_count"], + pd.to_datetime(selected_date), + days=int(horizon), + ) + if glm_pred is None and svr_pred is None: + warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。") + + forecast_state.update( + { + "results": { + "selected_date": selected_date, + "horizon": int(horizon), + "arima_df": arima_df, + "knn_pred": knn_pred, + "glm_pred": glm_pred, + "svr_pred": svr_pred, + "warnings": warnings, + }, + "last_message": None, + } + ) + + results = forecast_state.get("results") + if not results: + if forecast_state.get("last_message"): + st.warning(forecast_state["last_message"]) + else: + st.info("请设置预测参数并点击“应用预测参数”按钮。") + return + + first_date = pd.to_datetime(results["selected_date"]) + horizon_days = int(results["horizon"]) + arima_df = results["arima_df"] + knn_pred = results["knn_pred"] + glm_pred = results["glm_pred"] + svr_pred = results["svr_pred"] + + fig_pred = go.Figure() + fig_pred.add_trace( + go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines") + ) + if arima_df is not None: + fig_pred.add_trace( + go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines") + ) + if knn_pred is not None: + fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines")) + if glm_pred is not None: + fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines")) + if svr_pred is not None: + fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines")) + + fig_pred.update_layout( + title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)", + xaxis_title="日期", + 
yaxis_title="事故数", + ) + st.plotly_chart(fig_pred, use_container_width=True) + + if arima_df is not None: + st.download_button( + "下载 ARIMA 预测 CSV", + data=arima_df.to_csv().encode("utf-8-sig"), + file_name="arima_forecast.csv", + mime="text/csv", + ) + + for warning_text in results.get("warnings", []): + st.warning(warning_text) diff --git a/ui_sections/hotspot.py b/ui_sections/hotspot.py new file mode 100644 index 0000000..34c6606 --- /dev/null +++ b/ui_sections/hotspot.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import json +from datetime import datetime + +import plotly.express as px +import streamlit as st + +from services.hotspot import ( + analyze_hotspot_frequency, + calculate_hotspot_risk_score, + generate_hotspot_strategies, + prepare_hotspot_dataset, + serialise_datetime_columns, +) + + +@st.cache_data(show_spinner=False) +def _prepare_hotspot_data(df): + return prepare_hotspot_dataset(df) + + +def render_hotspot(accident_records, accident_source_name: str | None) -> None: + st.header("📍 事故多发路口分析") + st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略。") + + if accident_records is None: + st.info("请在左侧上传事故数据并点击“应用数据与筛选”后再执行热点分析。") + st.markdown( + """ + ### 📝 支持的数据格式要求: + - **文件格式**:Excel (.xlsx) + - **必要字段**: + - `事故时间` + - `事故类型` + - `事故具体地点` + - `所在街道` + - `道路类型` + - `路口路段类型` + """ + ) + return + + with st.spinner("正在准备事故热点数据…"): + hotspot_data = _prepare_hotspot_data(accident_records) + + st.success(f"✅ 成功加载数据:{len(hotspot_data)} 条事故记录") + + metric_cols = st.columns(3) + with metric_cols[0]: + st.metric( + "数据时间范围", + f"{hotspot_data['事故时间'].min().strftime('%Y-%m-%d')} 至 {hotspot_data['事故时间'].max().strftime('%Y-%m-%d')}", + ) + with metric_cols[1]: + st.metric( + "事故类型分布", + f"财损: {len(hotspot_data[hotspot_data['事故类型'] == '财损'])}起", + ) + with metric_cols[2]: + st.metric("涉及区域", f"{hotspot_data['所在街道'].nunique()}个街道") + + st.subheader("🔧 分析参数设置") + settings_cols = st.columns(3) + with settings_cols[0]: + time_window = st.selectbox( + "统计时间窗口", + 
options=["7D", "15D", "30D"], + index=0, + key="hotspot_window", + ) + with settings_cols[1]: + min_accidents = st.number_input( + "最小事故数", min_value=1, max_value=50, value=3, key="hotspot_min_accidents" + ) + with settings_cols[2]: + top_n = st.slider("显示热点数量", min_value=3, max_value=20, value=8, key="hotspot_top_n") + + if not st.button("🚀 开始热点分析", type="primary"): + return + + with st.spinner("正在分析事故热点分布…"): + hotspots = analyze_hotspot_frequency(hotspot_data, time_window=time_window) + hotspots = hotspots[hotspots["accident_count"] >= min_accidents] + + if hotspots.empty: + st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数。") + return + + hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3)) + top_hotspots = hotspots_with_risk.head(top_n) + + st.subheader("📊 事故多发路口排名(前{0}个)".format(top_n)) + display_columns = { + "accident_count": "累计事故数", + "recent_count": "近期事故数", + "trend_ratio": "趋势比例", + "main_accident_type": "主要类型", + "main_intersection_type": "路口类型", + "risk_score": "风险评分", + "risk_level": "风险等级", + } + display_df = top_hotspots[list(display_columns.keys())].rename(columns=display_columns) + styled_df = display_df.style.format({"趋势比例": "{:.2f}", "风险评分": "{:.1f}"}).background_gradient( + subset=["风险评分"], cmap="Reds" + ) + st.dataframe(styled_df, use_container_width=True) + + st.subheader("🎯 针对性策略建议") + strategies = generate_hotspot_strategies(top_hotspots, time_period="本周") + for index, strategy_info in enumerate(strategies, start=1): + message = f"**{index}. 
{strategy_info['strategy']}**" + risk_level = strategy_info["risk_level"] + if risk_level == "高风险": + st.error(f"🚨 {message}") + elif risk_level == "中风险": + st.warning(f"⚠️ {message}") + else: + st.info(f"✅ {message}") + + st.subheader("📈 数据分析可视化") + chart_cols = st.columns(2) + with chart_cols[0]: + fig_hotspots = px.bar( + top_hotspots.head(10), + x=top_hotspots.head(10).index, + y=["accident_count", "recent_count"], + title="事故频次TOP10分布", + labels={"value": "事故数量", "variable": "类型", "index": "路口名称"}, + barmode="group", + ) + fig_hotspots.update_layout(xaxis_tickangle=-45) + st.plotly_chart(fig_hotspots, use_container_width=True) + + with chart_cols[1]: + risk_distribution = top_hotspots["risk_level"].value_counts() + fig_risk = px.pie( + values=risk_distribution.values, + names=risk_distribution.index, + title="风险等级分布", + color_discrete_map={"高风险": "red", "中风险": "orange", "低风险": "green"}, + ) + st.plotly_chart(fig_risk, use_container_width=True) + + st.subheader("💾 数据导出") + download_cols = st.columns(2) + with download_cols[0]: + hotspot_csv = top_hotspots.to_csv().encode("utf-8-sig") + st.download_button( + "📥 下载热点数据CSV", + data=hotspot_csv, + file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv", + mime="text/csv", + ) + + with download_cols[1]: + serializable = serialise_datetime_columns( + top_hotspots.reset_index(), + columns=[col for col in top_hotspots.columns if "time" in col or "date" in col], + ) + report_payload = { + "analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "time_window": time_window, + "data_source": accident_source_name or "事故数据", + "total_records": int(len(hotspot_data)), + "analysis_parameters": {"min_accidents": int(min_accidents), "top_n": int(top_n)}, + "top_hotspots": serializable.to_dict("records"), + "recommended_strategies": strategies, + "summary": { + "high_risk_count": int((top_hotspots["risk_level"] == "高风险").sum()), + "medium_risk_count": int((top_hotspots["risk_level"] == "中风险").sum()), + 
"total_analyzed_locations": int(len(hotspots)), + "most_dangerous_location": top_hotspots.index[0] + if len(top_hotspots) + else "无", + }, + } + st.download_button( + "📄 下载完整分析报告", + data=json.dumps(report_payload, ensure_ascii=False, indent=2), + file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json", + mime="application/json", + ) + + with st.expander("📋 查看原始数据预览"): + preview_cols = ["事故时间", "所在街道", "事故类型", "事故具体地点", "道路类型"] + preview_df = hotspot_data[preview_cols].copy() + st.dataframe(preview_df.head(10), use_container_width=True) + diff --git a/ui_sections/model_eval.py b/ui_sections/model_eval.py new file mode 100644 index 0000000..65316fc --- /dev/null +++ b/ui_sections/model_eval.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import pandas as pd +import streamlit as st + +from services.metrics import evaluate_models + + +def render_model_eval(base: pd.DataFrame): + st.subheader("模型预测效果对比") + with st.form(key="model_eval_form"): + horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1) + submit_eval = st.form_submit_button("应用评估参数") + + if not submit_eval: + st.info("请设置评估窗口并点击“应用评估参数”按钮。") + return + + try: + df_metrics = evaluate_models(base['accident_count'], horizon=int(horizon_sel)) + st.dataframe(df_metrics, use_container_width=True) + best_model = df_metrics['RMSE'].idxmin() + st.success(f"过去 {int(horizon_sel)} 天中,RMSE 最低的模型是:**{best_model}**") + st.download_button( + "下载评估结果 CSV", + data=df_metrics.to_csv().encode('utf-8-sig'), + file_name="model_evaluation.csv", + mime="text/csv", + ) + except ValueError as err: + st.warning(str(err)) + diff --git a/ui_sections/overview.py b/ui_sections/overview.py new file mode 100644 index 0000000..33d1a27 --- /dev/null +++ b/ui_sections/overview.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import json +import pandas as pd +import plotly.graph_objects as go +import streamlit as st + +def render_overview(base: pd.DataFrame, region_sel: str, start_dt: 
pd.Timestamp, end_dt: pd.Timestamp, + strat_filter: list[str]): + fig_line = go.Figure() + fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines')) + fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count") + st.plotly_chart(fig_line, use_container_width=True) + + html = fig_line.to_html(full_html=True, include_plotlyjs='cdn') + st.download_button("下载图表 HTML", data=html.encode('utf-8'), + file_name="overview_series.html", mime="text/html") + + st.dataframe(base, use_container_width=True) + csv_bytes = base.to_csv(index=True).encode('utf-8-sig') + st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv") + + meta = { + "region": region_sel, + "date_range": [str(start_dt.date()), str(end_dt.date())], + "strategy_filter": strat_filter, + "rows": int(len(base)), + "min_date": str(base.index.min().date()) if len(base) else None, + "max_date": str(base.index.max().date()) if len(base) else None, + } + st.download_button("下载运行参数 JSON", data=json.dumps(meta, ensure_ascii=False, indent=2).encode('utf-8'), + file_name="run_metadata.json", mime="application/json") + diff --git a/ui_sections/strategy_eval.py b/ui_sections/strategy_eval.py new file mode 100644 index 0000000..ea9548a --- /dev/null +++ b/ui_sections/strategy_eval.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os +import pandas as pd +import streamlit as st + +from services.strategy import generate_output_and_recommendations + + +def render_strategy_eval(base: pd.DataFrame, all_strategy_types: list[str], region_sel: str): + st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}") + if not all_strategy_types: + st.warning("数据中没有检测到策略。") + return + + with st.form(key="strategy_eval_form"): + horizon_eval = st.slider("评估窗口(天)", 7, 60, 14, step=1) + submit_strat_eval = st.form_submit_button("应用评估参数") + + if not submit_strat_eval: + st.info("请设置评估窗口并点击“应用评估参数”按钮。") + return + + 
results, recommendation = generate_output_and_recommendations( + base, + all_strategy_types, + region=region_sel if region_sel != '全市' else '全市', + horizon=horizon_eval, + ) + + if not results: + st.warning("⚠️ 未能完成策略评估。请尝试缩短评估窗口或扩大日期范围。") + return + + st.subheader("各策略指标") + df_res = pd.DataFrame(results).T + st.dataframe(df_res, use_container_width=True) + st.success(f"⭐ 推荐:{recommendation}") + + st.download_button( + "下载策略评估 CSV", + data=df_res.to_csv().encode('utf-8-sig'), + file_name="strategy_evaluation_results.csv", + mime="text/csv", + ) + + if os.path.exists('recommendation.txt'): + with open('recommendation.txt','r',encoding='utf-8') as f: + st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt") +