modify: cleanup project structure and docs

This commit is contained in:
2025-11-02 08:40:28 +08:00
parent a5e3c4c1da
commit 5825cf81b7
19 changed files with 1903 additions and 1125 deletions

13
ui_sections/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from .overview import render_overview
from .forecast import render_forecast
from .model_eval import render_model_eval
from .strategy_eval import render_strategy_eval
from .hotspot import render_hotspot
# Public API of the ui_sections package: one render_* entry point per
# dashboard section, re-exported here so callers can simply do
# `from ui_sections import render_overview`, etc.
__all__ = [
    'render_overview',
    'render_forecast',
    'render_model_eval',
    'render_strategy_eval',
    'render_hotspot',
]

185
ui_sections/forecast.py Normal file
View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from services.forecast import (
arima_forecast_with_grid_search,
knn_forecast_counterfactual,
fit_and_extrapolate,
)
def render_forecast(base: pd.DataFrame):
    """Render the multi-model accident forecast comparison section.

    Args:
        base: Daily accident series indexed by date with an
            ``accident_count`` column; ``None``/empty means no data
            has been uploaded/filtered yet.

    Side effects:
        Reads and writes ``st.session_state["forecast_state"]`` to keep
        the chosen parameters and model outputs across Streamlit reruns.
    """
    st.subheader("多模型预测比较")
    if base is None or base.empty:
        st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
        st.session_state.setdefault(
            "forecast_state",
            {"results": None, "last_message": "暂无可用于预测的事故数据。"},
        )
        return
    forecast_state = st.session_state.setdefault(
        "forecast_state",
        {
            "selected_date": None,
            "horizon": 30,
            "results": None,
            "last_message": None,
            "data_signature": None,
        },
    )
    earliest_date = base.index.min().date()
    latest_date = base.index.max().date()
    # Default forecast start: 30 days before the end of the series,
    # clamped so it never precedes the first observation.
    fallback_date = max(
        (base.index.max() - pd.Timedelta(days=30)).date(),
        earliest_date,
    )
    # Lightweight fingerprint of the filtered dataset; cached results
    # computed from a previous dataset must be invalidated on change.
    current_signature = (
        earliest_date.isoformat(),
        latest_date.isoformat(),
        int(len(base)),
        float(base["accident_count"].sum()),
    )
    # Reset cached results if the underlying dataset has changed
    if forecast_state.get("data_signature") != current_signature:
        forecast_state.update(
            {
                "data_signature": current_signature,
                "results": None,
                "last_message": None,
                "selected_date": fallback_date,
            }
        )
    default_date = forecast_state.get("selected_date") or fallback_date
    # Clamp the remembered date into the currently valid range.
    if default_date < earliest_date:
        default_date = earliest_date
    if default_date > latest_date:
        default_date = latest_date
    with st.form(key="predict_form"):
        selected_date = st.date_input(
            "选择干预日期 / 预测起点",
            value=default_date,
            min_value=earliest_date,
            max_value=latest_date,
        )
        horizon = st.number_input(
            "预测天数",
            min_value=7,
            max_value=90,
            value=int(forecast_state.get("horizon", 30)),
            step=1,
        )
        submit_predict = st.form_submit_button("应用预测参数")
    forecast_state["selected_date"] = selected_date
    forecast_state["horizon"] = int(horizon)
    if submit_predict:
        history = base.loc[:pd.to_datetime(selected_date)]
        if len(history) < 10:
            forecast_state.update(
                {
                    "results": None,
                    "last_message": "干预前数据不足(至少需要 10 个观测点)。",
                }
            )
        else:
            with st.spinner("正在生成预测结果…"):
                warnings: list[str] = []
                try:
                    train_series = history["accident_count"]
                    arima_df = arima_forecast_with_grid_search(
                        train_series,
                        start_date=pd.to_datetime(selected_date) + pd.Timedelta(days=1),
                        horizon=int(horizon),
                    )
                except Exception as exc:
                    arima_df = None
                    warnings.append(f"ARIMA 运行失败:{exc}")
                # Guard the remaining models the same way as ARIMA so a
                # failure in one model degrades to a warning instead of
                # crashing the whole tab.
                try:
                    knn_pred, _ = knn_forecast_counterfactual(
                        base["accident_count"],
                        pd.to_datetime(selected_date),
                        horizon=int(horizon),
                    )
                except Exception as exc:
                    knn_pred = None
                    warnings.append(f"KNN 运行失败:{exc}")
                else:
                    if knn_pred is None:
                        warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")
                try:
                    glm_pred, svr_pred, _ = fit_and_extrapolate(
                        base["accident_count"],
                        pd.to_datetime(selected_date),
                        days=int(horizon),
                    )
                except Exception as exc:
                    glm_pred = svr_pred = None
                    warnings.append(f"GLM/SVR 运行失败:{exc}")
                else:
                    if glm_pred is None and svr_pred is None:
                        warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")
                forecast_state.update(
                    {
                        "results": {
                            "selected_date": selected_date,
                            "horizon": int(horizon),
                            "arima_df": arima_df,
                            "knn_pred": knn_pred,
                            "glm_pred": glm_pred,
                            "svr_pred": svr_pred,
                            "warnings": warnings,
                        },
                        "last_message": None,
                    }
                )
    results = forecast_state.get("results")
    if not results:
        # No cached results yet: show the last failure message, or the
        # initial instructions when nothing has been submitted.
        if forecast_state.get("last_message"):
            st.warning(forecast_state["last_message"])
        else:
            st.info("请设置预测参数并点击“应用预测参数”按钮。")
        return
    first_date = pd.to_datetime(results["selected_date"])
    horizon_days = int(results["horizon"])
    arima_df = results["arima_df"]
    knn_pred = results["knn_pred"]
    glm_pred = results["glm_pred"]
    svr_pred = results["svr_pred"]
    # One trace for the observed series plus one per model that actually
    # produced a forecast.
    fig_pred = go.Figure()
    fig_pred.add_trace(
        go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
    )
    if arima_df is not None:
        fig_pred.add_trace(
            go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
        )
    if knn_pred is not None:
        fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines"))
    if glm_pred is not None:
        fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines"))
    if svr_pred is not None:
        fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines"))
    fig_pred.update_layout(
        title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
        xaxis_title="日期",
        yaxis_title="事故数",
    )
    st.plotly_chart(fig_pred, use_container_width=True)
    if arima_df is not None:
        # utf-8-sig adds a BOM so Excel opens the CSV headers correctly.
        st.download_button(
            "下载 ARIMA 预测 CSV",
            data=arima_df.to_csv().encode("utf-8-sig"),
            file_name="arima_forecast.csv",
            mime="text/csv",
        )
    for warning_text in results.get("warnings", []):
        st.warning(warning_text)

189
ui_sections/hotspot.py Normal file
View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import json
from datetime import datetime
import plotly.express as px
import streamlit as st
from services.hotspot import (
analyze_hotspot_frequency,
calculate_hotspot_risk_score,
generate_hotspot_strategies,
prepare_hotspot_dataset,
serialise_datetime_columns,
)
@st.cache_data(show_spinner=False)
def _prepare_hotspot_data(df):
    """Cached wrapper around ``prepare_hotspot_dataset``.

    Streamlit memoises the prepared hotspot DataFrame keyed on the
    uploaded records, so re-renders skip the preparation cost.
    """
    return prepare_hotspot_dataset(df)
def render_hotspot(accident_records, accident_source_name: str | None) -> None:
    """Render the standalone accident-hotspot analysis section.

    Args:
        accident_records: Raw accident records uploaded in the sidebar
            (presumably a DataFrame — confirm against the caller), or
            ``None`` when nothing has been applied yet.
        accident_source_name: Display name of the uploaded data source,
            embedded in the exported JSON report.
    """
    st.header("📍 事故多发路口分析")
    st.markdown("独立分析事故数据,识别高风险路口并生成针对性策略。")
    if accident_records is None:
        # No data yet: describe the expected input format and bail out.
        st.info("请在左侧上传事故数据并点击“应用数据与筛选”后再执行热点分析。")
        st.markdown(
            """
            ### 📝 支持的数据格式要求:
            - **文件格式**Excel (.xlsx)
            - **必要字段**
            - `事故时间`
            - `事故类型`
            - `事故具体地点`
            - `所在街道`
            - `道路类型`
            - `路口路段类型`
            """
        )
        return
    with st.spinner("正在准备事故热点数据…"):
        hotspot_data = _prepare_hotspot_data(accident_records)
    st.success(f"✅ 成功加载数据:{len(hotspot_data)} 条事故记录")
    # Headline metrics: time span, property-damage count, street count.
    metric_cols = st.columns(3)
    with metric_cols[0]:
        # NOTE(review): no separator between the two dates in this
        # f-string — a "至"/"~" may have been lost in transit; confirm.
        st.metric(
            "数据时间范围",
            f"{hotspot_data['事故时间'].min().strftime('%Y-%m-%d')}{hotspot_data['事故时间'].max().strftime('%Y-%m-%d')}",
        )
    with metric_cols[1]:
        st.metric(
            "事故类型分布",
            f"财损: {len(hotspot_data[hotspot_data['事故类型'] == '财损'])}",
        )
    with metric_cols[2]:
        st.metric("涉及区域", f"{hotspot_data['所在街道'].nunique()}个街道")
    # Analysis parameters: rolling window, inclusion threshold, display count.
    st.subheader("🔧 分析参数设置")
    settings_cols = st.columns(3)
    with settings_cols[0]:
        time_window = st.selectbox(
            "统计时间窗口",
            options=["7D", "15D", "30D"],
            index=0,
            key="hotspot_window",
        )
    with settings_cols[1]:
        min_accidents = st.number_input(
            "最小事故数", min_value=1, max_value=50, value=3, key="hotspot_min_accidents"
        )
    with settings_cols[2]:
        top_n = st.slider("显示热点数量", min_value=3, max_value=20, value=8, key="hotspot_top_n")
    # Everything below only runs after the user explicitly starts analysis.
    if not st.button("🚀 开始热点分析", type="primary"):
        return
    with st.spinner("正在分析事故热点分布…"):
        hotspots = analyze_hotspot_frequency(hotspot_data, time_window=time_window)
        # Drop locations below the user-selected accident-count threshold.
        hotspots = hotspots[hotspots["accident_count"] >= min_accidents]
    if hotspots.empty:
        st.warning("⚠️ 未找到符合条件的事故热点数据,请调整筛选参数。")
        return
    # Score a 3x-deep candidate pool, then keep the top_n for display —
    # presumably so risk ranking can reorder beyond the raw frequency
    # ranking; confirm against calculate_hotspot_risk_score.
    hotspots_with_risk = calculate_hotspot_risk_score(hotspots.head(top_n * 3))
    top_hotspots = hotspots_with_risk.head(top_n)
    st.subheader("📊 事故多发路口排名(前{0}个)".format(top_n))
    # Map internal column names to the Chinese display headers.
    display_columns = {
        "accident_count": "累计事故数",
        "recent_count": "近期事故数",
        "trend_ratio": "趋势比例",
        "main_accident_type": "主要类型",
        "main_intersection_type": "路口类型",
        "risk_score": "风险评分",
        "risk_level": "风险等级",
    }
    display_df = top_hotspots[list(display_columns.keys())].rename(columns=display_columns)
    # Red gradient over the risk-score column for quick visual scanning.
    styled_df = display_df.style.format({"趋势比例": "{:.2f}", "风险评分": "{:.1f}"}).background_gradient(
        subset=["风险评分"], cmap="Reds"
    )
    st.dataframe(styled_df, use_container_width=True)
    st.subheader("🎯 针对性策略建议")
    strategies = generate_hotspot_strategies(top_hotspots, time_period="本周")
    # Color-code each strategy by its risk level (error/warning/info).
    for index, strategy_info in enumerate(strategies, start=1):
        message = f"**{index}. {strategy_info['strategy']}**"
        risk_level = strategy_info["risk_level"]
        if risk_level == "高风险":
            st.error(f"🚨 {message}")
        elif risk_level == "中风险":
            st.warning(f"⚠️ {message}")
        else:
            st.info(f"{message}")
    st.subheader("📈 数据分析可视化")
    chart_cols = st.columns(2)
    with chart_cols[0]:
        # Grouped bars: cumulative vs recent accident counts per location.
        fig_hotspots = px.bar(
            top_hotspots.head(10),
            x=top_hotspots.head(10).index,
            y=["accident_count", "recent_count"],
            title="事故频次TOP10分布",
            labels={"value": "事故数量", "variable": "类型", "index": "路口名称"},
            barmode="group",
        )
        fig_hotspots.update_layout(xaxis_tickangle=-45)
        st.plotly_chart(fig_hotspots, use_container_width=True)
    with chart_cols[1]:
        # Pie chart of how many displayed hotspots fall into each risk tier.
        risk_distribution = top_hotspots["risk_level"].value_counts()
        fig_risk = px.pie(
            values=risk_distribution.values,
            names=risk_distribution.index,
            title="风险等级分布",
            color_discrete_map={"高风险": "red", "中风险": "orange", "低风险": "green"},
        )
        st.plotly_chart(fig_risk, use_container_width=True)
    st.subheader("💾 数据导出")
    download_cols = st.columns(2)
    with download_cols[0]:
        # utf-8-sig adds a BOM so Excel renders the Chinese headers correctly.
        hotspot_csv = top_hotspots.to_csv().encode("utf-8-sig")
        st.download_button(
            "📥 下载热点数据CSV",
            data=hotspot_csv,
            file_name=f"accident_hotspots_{datetime.now().strftime('%Y%m%d')}.csv",
            mime="text/csv",
        )
    with download_cols[1]:
        # Convert datetime-like columns to strings so json.dumps succeeds.
        # NOTE(review): the filter matches English substrings "time"/"date"
        # while the visible columns are Chinese — verify it ever matches.
        serializable = serialise_datetime_columns(
            top_hotspots.reset_index(),
            columns=[col for col in top_hotspots.columns if "time" in col or "date" in col],
        )
        # Full machine-readable analysis report: parameters, rankings,
        # strategies, and summary counts.
        report_payload = {
            "analysis_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "time_window": time_window,
            "data_source": accident_source_name or "事故数据",
            "total_records": int(len(hotspot_data)),
            "analysis_parameters": {"min_accidents": int(min_accidents), "top_n": int(top_n)},
            "top_hotspots": serializable.to_dict("records"),
            "recommended_strategies": strategies,
            "summary": {
                "high_risk_count": int((top_hotspots["risk_level"] == "高风险").sum()),
                "medium_risk_count": int((top_hotspots["risk_level"] == "中风险").sum()),
                "total_analyzed_locations": int(len(hotspots)),
                "most_dangerous_location": top_hotspots.index[0]
                if len(top_hotspots)
                else "",
            },
        }
        st.download_button(
            "📄 下载完整分析报告",
            data=json.dumps(report_payload, ensure_ascii=False, indent=2),
            file_name=f"hotspot_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M')}.json",
            mime="application/json",
        )
    with st.expander("📋 查看原始数据预览"):
        preview_cols = ["事故时间", "所在街道", "事故类型", "事故具体地点", "道路类型"]
        preview_df = hotspot_data[preview_cols].copy()
        st.dataframe(preview_df.head(10), use_container_width=True)

32
ui_sections/model_eval.py Normal file
View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import pandas as pd
import streamlit as st
from services.metrics import evaluate_models
def render_model_eval(base: pd.DataFrame):
    """Render the model-evaluation section.

    Scores the forecast models over a user-selected trailing window and
    highlights the one with the lowest RMSE.
    """
    st.subheader("模型预测效果对比")
    with st.form(key="model_eval_form"):
        window_days = st.slider("评估窗口(天)", 7, 60, 30, step=1)
        apply_clicked = st.form_submit_button("应用评估参数")
    if not apply_clicked:
        st.info("请设置评估窗口并点击“应用评估参数”按钮。")
        return
    window = int(window_days)
    try:
        metrics_df = evaluate_models(base['accident_count'], horizon=window)
        st.dataframe(metrics_df, use_container_width=True)
        top_model = metrics_df['RMSE'].idxmin()
        st.success(f"过去 {window} 天中RMSE 最低的模型是:**{top_model}**")
        # utf-8-sig adds a BOM so Excel shows the headers correctly.
        st.download_button(
            "下载评估结果 CSV",
            data=metrics_df.to_csv().encode('utf-8-sig'),
            file_name="model_evaluation.csv",
            mime="text/csv",
        )
    except ValueError as err:
        # Insufficient data (and similar) is reported via ValueError;
        # surface it as a warning instead of crashing the tab.
        st.warning(str(err))

33
ui_sections/overview.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import json
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
def render_overview(base: pd.DataFrame, region_sel: str, start_dt: pd.Timestamp, end_dt: pd.Timestamp,
                    strat_filter: list[str]):
    """Render the overview section.

    Shows the filtered daily accident series, the underlying table, and
    download buttons for the chart HTML, the CSV view, and the run
    metadata JSON.
    """
    # Time-series chart of the filtered daily accident counts.
    trend_fig = go.Figure()
    trend_fig.add_trace(go.Scatter(x=base.index, y=base['accident_count'], name='事故数', mode='lines'))
    trend_fig.update_layout(title="事故数(过滤后)", xaxis_title="Date", yaxis_title="Count")
    st.plotly_chart(trend_fig, use_container_width=True)
    # Standalone HTML export; plotly.js is loaded from the CDN to keep it small.
    chart_html = trend_fig.to_html(full_html=True, include_plotlyjs='cdn')
    st.download_button("下载图表 HTML", data=chart_html.encode('utf-8'),
                       file_name="overview_series.html", mime="text/html")
    st.dataframe(base, use_container_width=True)
    # utf-8-sig adds a BOM so Excel renders the headers correctly.
    table_csv = base.to_csv(index=True).encode('utf-8-sig')
    st.download_button("下载当前视图 CSV", data=table_csv, file_name="filtered_view.csv", mime="text/csv")
    # Record the active filter parameters for reproducibility.
    has_rows = len(base) > 0
    run_meta = {
        "region": region_sel,
        "date_range": [str(start_dt.date()), str(end_dt.date())],
        "strategy_filter": strat_filter,
        "rows": int(len(base)),
        "min_date": str(base.index.min().date()) if has_rows else None,
        "max_date": str(base.index.max().date()) if has_rows else None,
    }
    meta_json = json.dumps(run_meta, ensure_ascii=False, indent=2)
    st.download_button("下载运行参数 JSON", data=meta_json.encode('utf-8'),
                       file_name="run_metadata.json", mime="application/json")

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import os
import pandas as pd
import streamlit as st
from services.strategy import generate_output_and_recommendations
def render_strategy_eval(base: pd.DataFrame, all_strategy_types: list[str], region_sel: str):
    """Render the strategy-evaluation section.

    Scores every detected strategy type over a user-chosen window and
    surfaces the recommended one, with CSV/text downloads.
    """
    st.info(f"📌 检测到的策略类型:{', '.join(all_strategy_types) or '(数据中没有策略)'}")
    if not all_strategy_types:
        st.warning("数据中没有检测到策略。")
        return
    with st.form(key="strategy_eval_form"):
        eval_window = st.slider("评估窗口(天)", 7, 60, 14, step=1)
        apply_clicked = st.form_submit_button("应用评估参数")
    if not apply_clicked:
        st.info("请设置评估窗口并点击“应用评估参数”按钮。")
        return
    # The original conditional (region_sel if region_sel != '全市' else
    # '全市') reduces to region_sel in both branches.
    eval_results, recommendation = generate_output_and_recommendations(
        base,
        all_strategy_types,
        region=region_sel,
        horizon=eval_window,
    )
    if not eval_results:
        st.warning("⚠️ 未能完成策略评估。请尝试缩短评估窗口或扩大日期范围。")
        return
    st.subheader("各策略指标")
    # Transpose so each strategy becomes one row of metrics.
    summary_df = pd.DataFrame(eval_results).T
    st.dataframe(summary_df, use_container_width=True)
    st.success(f"⭐ 推荐:{recommendation}")
    st.download_button(
        "下载策略评估 CSV",
        data=summary_df.to_csv().encode('utf-8-sig'),
        file_name="strategy_evaluation_results.csv",
        mime="text/csv",
    )
    # generate_output_and_recommendations is expected to have written this
    # file as a side effect; offer it for download when present.
    if os.path.exists('recommendation.txt'):
        with open('recommendation.txt', 'r', encoding='utf-8') as fh:
            st.download_button("下载推荐文本", data=fh.read().encode('utf-8'), file_name="recommendation.txt")