modify: cleanup project structure and docs
This commit is contained in:
13
ui_sections/__init__.py
Normal file
13
ui_sections/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .overview import render_overview
|
||||
from .forecast import render_forecast
|
||||
from .model_eval import render_model_eval
|
||||
from .strategy_eval import render_strategy_eval
|
||||
from .hotspot import render_hotspot
|
||||
|
||||
__all__ = [
|
||||
'render_overview',
|
||||
'render_forecast',
|
||||
'render_model_eval',
|
||||
'render_strategy_eval',
|
||||
'render_hotspot',
|
||||
]
|
||||
185
ui_sections/forecast.py
Normal file
185
ui_sections/forecast.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
from services.forecast import (
|
||||
arima_forecast_with_grid_search,
|
||||
knn_forecast_counterfactual,
|
||||
fit_and_extrapolate,
|
||||
)
|
||||
|
||||
|
||||
def render_forecast(base: pd.DataFrame):
|
||||
st.subheader("多模型预测比较")
|
||||
|
||||
if base is None or base.empty:
|
||||
st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
|
||||
st.session_state.setdefault(
|
||||
"forecast_state",
|
||||
{"results": None, "last_message": "暂无可用于预测的事故数据。"},
|
||||
)
|
||||
return
|
||||
|
||||
forecast_state = st.session_state.setdefault(
|
||||
"forecast_state",
|
||||
{
|
||||
"selected_date": None,
|
||||
"horizon": 30,
|
||||
"results": None,
|
||||
"last_message": None,
|
||||
"data_signature": None,
|
||||
},
|
||||
)
|
||||
|
||||
earliest_date = base.index.min().date()
|
||||
latest_date = base.index.max().date()
|
||||
fallback_date = max(
|
||||
(base.index.max() - pd.Timedelta(days=30)).date(),
|
||||
earliest_date,
|
||||
)
|
||||
current_signature = (
|
||||
earliest_date.isoformat(),
|
||||
latest_date.isoformat(),
|
||||
int(len(base)),
|
||||
float(base["accident_count"].sum()),
|
||||
)
|
||||
|
||||
# Reset cached results if the underlying dataset has changed
|
||||
if forecast_state.get("data_signature") != current_signature:
|
||||
forecast_state.update(
|
||||
{
|
||||
"data_signature": current_signature,
|
||||
"results": None,
|
||||
"last_message": None,
|
||||
"selected_date": fallback_date,
|
||||
}
|
||||
)
|
||||
|
||||
default_date = forecast_state.get("selected_date") or fallback_date
|
||||
if default_date < earliest_date:
|
||||
default_date = earliest_date
|
||||
if default_date > latest_date:
|
||||
default_date = latest_date
|
||||
|
||||
with st.form(key="predict_form"):
|
||||
selected_date = st.date_input(
|
||||
"选择干预日期 / 预测起点",
|
||||
value=default_date,
|
||||
min_value=earliest_date,
|
||||
max_value=latest_date,
|
||||
)
|
||||
horizon = st.number_input(
|
||||
"预测天数",
|
||||
min_value=7,
|
||||
max_value=90,
|
||||
value=int(forecast_state.get("horizon", 30)),
|
||||
step=1,
|
||||
)
|
||||
submit_predict = st.form_submit_button("应用预测参数")
|
||||
|
||||
forecast_state["selected_date"] = selected_date
|
||||
forecast_state["horizon"] = int(horizon)
|
||||
|
||||
if submit_predict:
|
||||
history = base.loc[:pd.to_datetime(selected_date)]
|
||||
if len(history) < 10:
|
||||
forecast_state.update(
|
||||
{
|
||||
"results": None,
|
||||
"last_message": "干预前数据不足(至少需要 10 个观测点)。",
|
||||
}
|
||||
)
|
||||
else:
|
||||
with st.spinner("正在生成预测结果…"):
|
||||
warnings: list[str] = []
|
||||
try:
|
||||
train_series = history["accident_count"]
|
||||
arima_df = arima_forecast_with_grid_search(
|
||||
train_series,
|
||||
start_date=pd.to_datetime(selected_date) + pd.Timedelta(days=1),
|
||||
horizon=int(horizon),
|
||||
)
|
||||
except Exception as exc:
|
||||
arima_df = None
|
||||
warnings.append(f"ARIMA 运行失败:{exc}")
|
||||
|
||||
knn_pred, _ = knn_forecast_counterfactual(
|
||||
base["accident_count"],
|
||||
pd.to_datetime(selected_date),
|
||||
horizon=int(horizon),
|
||||
)
|
||||
if knn_pred is None:
|
||||
warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")
|
||||
|
||||
glm_pred, svr_pred, _ = fit_and_extrapolate(
|
||||
base["accident_count"],
|
||||
pd.to_datetime(selected_date),
|
||||
days=int(horizon),
|
||||
)
|
||||
if glm_pred is None and svr_pred is None:
|
||||
warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")
|
||||
|
||||
forecast_state.update(
|
||||
{
|
||||
"results": {
|
||||
"selected_date": selected_date,
|
||||
"horizon": int(horizon),
|
||||
"arima_df": arima_df,
|
||||
"knn_pred": knn_pred,
|
||||
"glm_pred": glm_pred,
|
||||
"svr_pred": svr_pred,
|
||||
"warnings": warnings,
|
||||
},
|
||||
"last_message": None,
|
||||
}
|
||||
)
|
||||
|
||||
results = forecast_state.get("results")
|
||||
if not results:
|
||||
if forecast_state.get("last_message"):
|
||||
st.warning(forecast_state["last_message"])
|
||||
else:
|
||||
st.info("请设置预测参数并点击“应用预测参数”按钮。")
|
||||
return
|
||||
|
||||
first_date = pd.to_datetime(results["selected_date"])
|
||||
horizon_days = int(results["horizon"])
|
||||
arima_df = results["arima_df"]
|
||||
knn_pred = results["knn_pred"]
|
||||
glm_pred = results["glm_pred"]
|
||||
svr_pred = results["svr_pred"]
|
||||
|
||||
fig_pred = go.Figure()
|
||||
fig_pred.add_trace(
|
||||
go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
|
||||
)
|
||||
if arima_df is not None:
|
||||
fig_pred.add_trace(
|
||||
go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
|
||||
)
|
||||
if knn_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines"))
|
||||
if glm_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines"))
|
||||
if svr_pred is not None:
|
||||
fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines"))
|
||||
|
||||
fig_pred.update_layout(
|
||||
title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
|
||||
xaxis_title="日期",
|
||||
yaxis_title="事故数",
|
||||
)
|
||||
st.plotly_chart(fig_pred, use_container_width=True)
|
||||
|
||||
if arima_df is not None:
|
||||
st.download_button(
|
||||
"下载 ARIMA 预测 CSV",
|
||||
data=arima_df.to_csv().encode("utf-8-sig"),
|
||||
file_name="arima_forecast.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
for warning_text in results.get("warnings", []):
|
||||
st.warning(warning_text)
|
||||
189
ui_sections/hotspot.py
Normal file
189
ui_sections/hotspot.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
|
||||
from services.hotspot import (
|
||||
analyze_hotspot_frequency,
|
||||
calculate_hotspot_risk_score,
|
||||
generate_hotspot_strategies,
|
||||
prepare_hotspot_dataset,
|
||||
serialise_datetime_columns,
|
||||
)
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def _prepare_hotspot_data(df):
|
||||
return prepare_hotspot_dataset(df)
|
||||
|
||||
|
||||
def render_hotspot(accident_records, accident_source_name: str | None) -> None:
|
||||
st.header("📍 事故多发路口分析")
|
||||
st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略。")
|
||||
|
||||
if accident_records is None:
|
||||
st.info("请在左侧上传事故数据并点击“应用数据与筛选”后再执行热点分析。")
|
||||
st.markdown(
|
||||
"""
|
||||
### 📝 支持的数据格式要求:
|
||||
- **文件格式**:Excel (.xlsx)
|
||||
- **必要字段**:
|
||||
- `事故时间`
|
||||
- `事故类型`
|
||||
- `事故具体地点`
|
||||
- `所在街道`
|
||||
- `道路类型`
|
||||
- `路口路段类型`
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
with st.spinner("正在准备事故热点数据…"):
|
||||
hotspot_data = _prepare_hotspot_data(accident_records)
|
||||
|
||||
st.success(f"✅ 成功加载数据:{len(hotspot_data)} 条事故记录")
|
||||
|
||||
metric_cols = st.columns(3)
|
||||
with metric_cols[0]:
|
||||
st.metric(
|
||||
"数据时间范围",
|
||||
f"{hotspot_data['事故时间'].min().strftime('%Y-%m-%d')} 至 {hotspot_data['事故时间'].max().strftime('%Y-%m-%d')}",
|
||||
)
|
||||
with metric_cols[1]:
|
||||
st.metric(
|
||||
"事故类型分布",
|
||||
f"财损: {len(hotspot_data[hotspot_data['事故类型'] == '财损'])}起",
|
||||
)
|
||||
with metric_cols[2]:
|
||||
st.metric("涉及区域", f"{hotspot_data['所在街道'].nunique()}个街道")
|
||||
|
||||
st.subheader("🔧 分析参数设置")
|
||||
settings_cols = st.columns(3)
|
||||
with settings_cols[0]:
|
||||
time_window = st.selectbox(
|
||||
"统计时间窗口",
|
||||
options=["7D", "15D", "30D"],
|
||||
index=0,
|
||||
key="hotspot_window",
|
||||
)
|
||||
with settings_cols[1]:
|
||||
min_accidents = st.number_input(
|
||||
"最小事故数", min_value=1, max_value=50, value=3, key="hotspot_min_accidents"
|
||||
)
|
||||
with settings_cols[2]:
|
||||
top_n = st.slider("显示热点数量", min_value=3, max_value=20, value=8, key="hotspot_top_n")
|
||||
|
||||
if not st.button("🚀 开始热点分析", type="primary"):
|
||||
return
|
||||
|
||||
with st.spinner("正在分析事故热点分布…"):
|
||||
hotspots = analyze_hotspot_frequency(hotspot_data, time_window=time_window)
|
||||
hotspots = hotspots[hotspots["accident_count"] >= min_accidents]
|
||||
|
||||
if hotspots.empty:
|
||||
st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数。")
|
||||
return
|
||||
|
||||
hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3))
|
||||
top_hotspots = hotspots_with_risk.head(top_n)
|
||||
|
||||
st.subheader("📊 事故多发路口排名(前{0}个)".format(top_n))
|
||||
display_columns = {
|
||||
"accident_count": "累计事故数",
|
||||
"recent_count": "近期事故数",
|
||||
"trend_ratio": "趋势比例",
|
||||
"main_accident_type": "主要类型",
|
||||
"main_intersection_type": "路口类型",
|
||||
"risk_score": "风险评分",
|
||||
"risk_level": "风险等级",
|
||||
}
|
||||
display_df = top_hotspots[list(display_columns.keys())].rename(columns=display_columns)
|
||||
styled_df = display_df.style.format({"趋势比例": "{:.2f}", "风险评分": "{:.1f}"}).background_gradient(
|
||||
subset=["风险评分"], cmap="Reds"
|
||||
)
|
||||
st.dataframe(styled_df, use_container_width=True)
|
||||
|
||||
st.subheader("🎯 针对性策略建议")
|
||||
strategies = generate_hotspot_strategies(top_hotspots, time_period="本周")
|
||||
for index, strategy_info in enumerate(strategies, start=1):
|
||||
message = f"**{index}. {strategy_info['strategy']}**"
|
||||
risk_level = strategy_info["risk_level"]
|
||||
if risk_level == "高风险":
|
||||
st.error(f"🚨 {message}")
|
||||
elif risk_level == "中风险":
|
||||
st.warning(f"⚠️ {message}")
|
||||
else:
|
||||
st.info(f"✅ {message}")
|
||||
|
||||
st.subheader("📈 数据分析可视化")
|
||||
chart_cols = st.columns(2)
|
||||
with chart_cols[0]:
|
||||
fig_hotspots = px.bar(
|
||||
top_hotspots.head(10),
|
||||
x=top_hotspots.head(10).index,
|
||||
y=["accident_count", "recent_count"],
|
||||
title="事故频次TOP10分布",
|
||||
labels={"value": "事故数量", "variable": "类型", "index": "路口名称"},
|
||||
barmode="group",
|
||||
)
|
||||
fig_hotspots.update_layout(xaxis_tickangle=-45)
|
||||
st.plotly_chart(fig_hotspots, use_container_width=True)
|
||||
|
||||
with chart_cols[1]:
|
||||
risk_distribution = top_hotspots["risk_level"].value_counts()
|
||||
fig_risk = px.pie(
|
||||
values=risk_distribution.values,
|
||||
names=risk_distribution.index,
|
||||
title="风险等级分布",
|
||||
color_discrete_map={"高风险": "red", "中风险": "orange", "低风险": "green"},
|
||||
)
|
||||
st.plotly_chart(fig_risk, use_container_width=True)
|
||||
|
||||
st.subheader("💾 数据导出")
|
||||
download_cols = st.columns(2)
|
||||
with download_cols[0]:
|
||||
hotspot_csv = top_hotspots.to_csv().encode("utf-8-sig")
|
||||
st.download_button(
|
||||
"📥 下载热点数据CSV",
|
||||
data=hotspot_csv,
|
||||
file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
with download_cols[1]:
|
||||
serializable = serialise_datetime_columns(
|
||||
top_hotspots.reset_index(),
|
||||
columns=[col for col in top_hotspots.columns if "time" in col or "date" in col],
|
||||
)
|
||||
report_payload = {
|
||||
"analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"time_window": time_window,
|
||||
"data_source": accident_source_name or "事故数据",
|
||||
"total_records": int(len(hotspot_data)),
|
||||
"analysis_parameters": {"min_accidents": int(min_accidents), "top_n": int(top_n)},
|
||||
"top_hotspots": serializable.to_dict("records"),
|
||||
"recommended_strategies": strategies,
|
||||
"summary": {
|
||||
"high_risk_count": int((top_hotspots["risk_level"] == "高风险").sum()),
|
||||
"medium_risk_count": int((top_hotspots["risk_level"] == "中风险").sum()),
|
||||
"total_analyzed_locations": int(len(hotspots)),
|
||||
"most_dangerous_location": top_hotspots.index[0]
|
||||
if len(top_hotspots)
|
||||
else "无",
|
||||
},
|
||||
}
|
||||
st.download_button(
|
||||
"📄 下载完整分析报告",
|
||||
data=json.dumps(report_payload, ensure_ascii=False, indent=2),
|
||||
file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json",
|
||||
mime="application/json",
|
||||
)
|
||||
|
||||
with st.expander("📋 查看原始数据预览"):
|
||||
preview_cols = ["事故时间", "所在街道", "事故类型", "事故具体地点", "道路类型"]
|
||||
preview_df = hotspot_data[preview_cols].copy()
|
||||
st.dataframe(preview_df.head(10), use_container_width=True)
|
||||
|
||||
32
ui_sections/model_eval.py
Normal file
32
ui_sections/model_eval.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from services.metrics import evaluate_models
|
||||
|
||||
|
||||
def render_model_eval(base: pd.DataFrame):
|
||||
st.subheader("模型预测效果对比")
|
||||
with st.form(key="model_eval_form"):
|
||||
horizon_sel = st.slider("评估窗口(天)", 7, 60, 30, step=1)
|
||||
submit_eval = st.form_submit_button("应用评估参数")
|
||||
|
||||
if not submit_eval:
|
||||
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
|
||||
return
|
||||
|
||||
try:
|
||||
df_metrics = evaluate_models(base['accident_count'], horizon=int(horizon_sel))
|
||||
st.dataframe(df_metrics, use_container_width=True)
|
||||
best_model = df_metrics['RMSE'].idxmin()
|
||||
st.success(f"过去 {int(horizon_sel)} 天中,RMSE 最低的模型是:**{best_model}**")
|
||||
st.download_button(
|
||||
"下载评估结果 CSV",
|
||||
data=df_metrics.to_csv().encode('utf-8-sig'),
|
||||
file_name="model_evaluation.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
except ValueError as err:
|
||||
st.warning(str(err))
|
||||
|
||||
33
ui_sections/overview.py
Normal file
33
ui_sections/overview.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
def render_overview(base: pd.DataFrame, region_sel: str, start_dt: pd.Timestamp, end_dt: pd.Timestamp,
|
||||
strat_filter: list[str]):
|
||||
fig_line = go.Figure()
|
||||
fig_line.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines'))
|
||||
fig_line.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
|
||||
st.plotly_chart(fig_line, use_container_width=True)
|
||||
|
||||
html = fig_line.to_html(full_html=True, include_plotlyjs='cdn')
|
||||
st.download_button("下载图表 HTML", data=html.encode('utf-8'),
|
||||
file_name="overview_series.html", mime="text/html")
|
||||
|
||||
st.dataframe(base, use_container_width=True)
|
||||
csv_bytes = base.to_csv(index=True).encode('utf-8-sig')
|
||||
st.download_button("下载当前视图 CSV", data=csv_bytes, file_name="filtered_view.csv", mime="text/csv")
|
||||
|
||||
meta = {
|
||||
"region": region_sel,
|
||||
"date_range": [str(start_dt.date()), str(end_dt.date())],
|
||||
"strategy_filter": strat_filter,
|
||||
"rows": int(len(base)),
|
||||
"min_date": str(base.index.min().date()) if len(base) else None,
|
||||
"max_date": str(base.index.max().date()) if len(base) else None,
|
||||
}
|
||||
st.download_button("下载运行参数 JSON", data=json.dumps(meta, ensure_ascii=False, indent=2).encode('utf-8'),
|
||||
file_name="run_metadata.json", mime="application/json")
|
||||
|
||||
50
ui_sections/strategy_eval.py
Normal file
50
ui_sections/strategy_eval.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from services.strategy import generate_output_and_recommendations
|
||||
|
||||
|
||||
def render_strategy_eval(base: pd.DataFrame, all_strategy_types: list[str], region_sel: str):
|
||||
st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
|
||||
if not all_strategy_types:
|
||||
st.warning("数据中没有检测到策略。")
|
||||
return
|
||||
|
||||
with st.form(key="strategy_eval_form"):
|
||||
horizon_eval = st.slider("评估窗口(天)", 7, 60, 14, step=1)
|
||||
submit_strat_eval = st.form_submit_button("应用评估参数")
|
||||
|
||||
if not submit_strat_eval:
|
||||
st.info("请设置评估窗口并点击“应用评估参数”按钮。")
|
||||
return
|
||||
|
||||
results, recommendation = generate_output_and_recommendations(
|
||||
base,
|
||||
all_strategy_types,
|
||||
region=region_sel if region_sel != '全市' else '全市',
|
||||
horizon=horizon_eval,
|
||||
)
|
||||
|
||||
if not results:
|
||||
st.warning("⚠️ 未能完成策略评估。请尝试缩短评估窗口或扩大日期范围。")
|
||||
return
|
||||
|
||||
st.subheader("各策略指标")
|
||||
df_res = pd.DataFrame(results).T
|
||||
st.dataframe(df_res, use_container_width=True)
|
||||
st.success(f"⭐ 推荐:{recommendation}")
|
||||
|
||||
st.download_button(
|
||||
"下载策略评估 CSV",
|
||||
data=df_res.to_csv().encode('utf-8-sig'),
|
||||
file_name="strategy_evaluation_results.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
if os.path.exists('recommendation.txt'):
|
||||
with open('recommendation.txt','r',encoding='utf-8') as f:
|
||||
st.download_button("下载推荐文本", data=f.read().encode('utf-8'), file_name="recommendation.txt")
|
||||
|
||||
Reference in New Issue
Block a user