186 lines
6.4 KiB
Python
186 lines
6.4 KiB
Python
from __future__ import annotations
|
|
|
|
import pandas as pd
|
|
import plotly.graph_objects as go
|
|
import streamlit as st
|
|
|
|
from services.forecast import (
|
|
arima_forecast_with_grid_search,
|
|
knn_forecast_counterfactual,
|
|
fit_and_extrapolate,
|
|
)
|
|
|
|
|
|
def render_forecast(base: pd.DataFrame) -> None:
    """Render the multi-model forecast comparison panel.

    The panel shows a form for choosing the intervention date and forecast
    horizon.  When the form is submitted, the ARIMA, KNN and GLM/SVR models
    are run and their outputs are cached in
    ``st.session_state["forecast_state"]`` so the chart survives reruns that
    do not resubmit the form.  The cache is dropped whenever the dataset
    signature (date range, row count, total accident count) changes.

    Args:
        base: Accident data.  The code reads ``base.index`` as timestamps
            (``.min().date()``, label slicing with ``.loc``) and the
            ``"accident_count"`` column.  ``None`` or an empty frame renders
            an informational message only.
    """
    st.subheader("多模型预测比较")

    if base is None or base.empty:
        st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
        st.session_state.setdefault(
            "forecast_state",
            {"results": None, "last_message": "暂无可用于预测的事故数据。"},
        )
        return

    forecast_state = st.session_state.setdefault(
        "forecast_state",
        {
            "selected_date": None,
            "horizon": 30,
            "results": None,
            "last_message": None,
            "data_signature": None,
        },
    )

    earliest_date = base.index.min().date()
    latest_date = base.index.max().date()
    # Default intervention point: 30 days before the last observation, but
    # never earlier than the first observation.
    fallback_date = max(
        (base.index.max() - pd.Timedelta(days=30)).date(),
        earliest_date,
    )
    current_signature = (
        earliest_date.isoformat(),
        latest_date.isoformat(),
        int(len(base)),
        float(base["accident_count"].sum()),
    )

    # Reset cached results if the underlying dataset has changed.
    if forecast_state.get("data_signature") != current_signature:
        forecast_state.update(
            {
                "data_signature": current_signature,
                "results": None,
                "last_message": None,
                "selected_date": fallback_date,
            }
        )

    # Clamp the remembered date into the range accepted by the date widget.
    default_date = forecast_state.get("selected_date") or fallback_date
    if default_date < earliest_date:
        default_date = earliest_date
    if default_date > latest_date:
        default_date = latest_date

    with st.form(key="predict_form"):
        selected_date = st.date_input(
            "选择干预日期 / 预测起点",
            value=default_date,
            min_value=earliest_date,
            max_value=latest_date,
        )
        horizon = st.number_input(
            "预测天数",
            min_value=7,
            max_value=90,
            value=int(forecast_state.get("horizon", 30)),
            step=1,
        )
        submit_predict = st.form_submit_button("应用预测参数")

    forecast_state["selected_date"] = selected_date
    forecast_state["horizon"] = int(horizon)

    if submit_predict:
        history = base.loc[:pd.to_datetime(selected_date)]
        if len(history) < 10:
            forecast_state.update(
                {
                    "results": None,
                    "last_message": "干预前数据不足(至少需要 10 个观测点)。",
                }
            )
        else:
            with st.spinner("正在生成预测结果…"):
                forecast_state.update(
                    {
                        "results": _run_forecast_models(
                            base, history, selected_date, int(horizon)
                        ),
                        "last_message": None,
                    }
                )

    results = forecast_state.get("results")
    if not results:
        if forecast_state.get("last_message"):
            st.warning(forecast_state["last_message"])
        else:
            st.info("请设置预测参数并点击“应用预测参数”按钮。")
        return

    first_date = pd.to_datetime(results["selected_date"])
    horizon_days = int(results["horizon"])
    arima_df = results["arima_df"]
    knn_pred = results["knn_pred"]
    glm_pred = results["glm_pred"]
    svr_pred = results["svr_pred"]

    fig_pred = go.Figure()
    fig_pred.add_trace(
        go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
    )
    if arima_df is not None:
        fig_pred.add_trace(
            go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
        )
    if knn_pred is not None:
        fig_pred.add_trace(go.Scatter(x=knn_pred.index, y=knn_pred, name="KNN", mode="lines"))
    if glm_pred is not None:
        fig_pred.add_trace(go.Scatter(x=glm_pred.index, y=glm_pred, name="GLM", mode="lines"))
    if svr_pred is not None:
        fig_pred.add_trace(go.Scatter(x=svr_pred.index, y=svr_pred, name="SVR", mode="lines"))

    fig_pred.update_layout(
        title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
        xaxis_title="日期",
        yaxis_title="事故数",
    )
    st.plotly_chart(fig_pred, use_container_width=True)

    if arima_df is not None:
        st.download_button(
            "下载 ARIMA 预测 CSV",
            data=arima_df.to_csv().encode("utf-8-sig"),
            file_name="arima_forecast.csv",
            mime="text/csv",
        )

    for warning_text in results.get("warnings", []):
        st.warning(warning_text)


def _run_forecast_models(base: pd.DataFrame, history: pd.DataFrame, selected_date, horizon: int) -> dict:
    """Run every forecast model and collect outputs plus user-facing warnings.

    Each model is guarded independently: a failure in one is recorded as a
    warning and its output set to ``None`` so the remaining models can still
    be displayed.

    Args:
        base: Full accident data; its ``"accident_count"`` column is passed
            to the KNN and GLM/SVR models.
        history: Rows of ``base`` up to and including the intervention date;
            its ``"accident_count"`` column is the ARIMA training series.
        selected_date: Intervention date chosen in the form.
        horizon: Number of days to forecast.

    Returns:
        A dict with keys ``selected_date``, ``horizon``, ``arima_df``,
        ``knn_pred``, ``glm_pred``, ``svr_pred`` and ``warnings``.
    """
    cutoff = pd.to_datetime(selected_date)
    # Named to avoid shadowing the stdlib ``warnings`` module.
    model_warnings: list[str] = []

    try:
        arima_df = arima_forecast_with_grid_search(
            history["accident_count"],
            start_date=cutoff + pd.Timedelta(days=1),
            horizon=horizon,
        )
    except Exception as exc:
        arima_df = None
        model_warnings.append(f"ARIMA 运行失败:{exc}")

    try:
        knn_pred, _ = knn_forecast_counterfactual(
            base["accident_count"],
            cutoff,
            horizon=horizon,
        )
    except Exception as exc:
        knn_pred = None
        model_warnings.append(f"KNN 运行失败:{exc}")
    else:
        if knn_pred is None:
            model_warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")

    try:
        glm_pred, svr_pred, _ = fit_and_extrapolate(
            base["accident_count"],
            cutoff,
            days=horizon,
        )
    except Exception as exc:
        glm_pred = svr_pred = None
        model_warnings.append(f"GLM/SVR 运行失败:{exc}")
    else:
        if glm_pred is None and svr_pred is None:
            model_warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")

    return {
        "selected_date": selected_date,
        "horizon": horizon,
        "arima_df": arima_df,
        "knn_pred": knn_pred,
        "glm_pred": glm_pred,
        "svr_pred": svr_pred,
        "warnings": model_warnings,
    }
|