Files

186 lines
6.4 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from services.forecast import (
arima_forecast_with_grid_search,
knn_forecast_counterfactual,
fit_and_extrapolate,
)
def render_forecast(base: pd.DataFrame | None) -> None:
    """Render the multi-model forecast comparison section of the page.

    The function keeps its working state in ``st.session_state["forecast_state"]``
    so that previously computed forecasts survive Streamlit reruns. Results are
    recomputed only when the user submits the form, and are discarded when the
    underlying dataset changes.

    Args:
        base: Accident data. The code reads the ``accident_count`` column and
            calls ``.date()`` on the index minimum/maximum, so the index must
            hold timestamp-like values. ``None`` or an empty frame shows an
            informational message and returns early.
            NOTE(review): label slicing with ``base.loc[:date]`` presumably
            relies on the index being sorted ascending - confirm with callers.
    """
    st.subheader("多模型预测比较")

    if base is None or base.empty:
        st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
        st.session_state.setdefault(
            "forecast_state",
            {"results": None, "last_message": "暂无可用于预测的事故数据。"},
        )
        return

    forecast_state = st.session_state.setdefault(
        "forecast_state",
        {
            "selected_date": None,
            "horizon": 30,
            "results": None,
            "last_message": None,
            "data_signature": None,
        },
    )

    earliest_date = base.index.min().date()
    latest_date = base.index.max().date()
    # Default intervention point: 30 days before the end of the data, but
    # never earlier than the first observation.
    fallback_date = max(
        (base.index.max() - pd.Timedelta(days=30)).date(),
        earliest_date,
    )

    # Cheap fingerprint of the dataset (date span, row count, total count)
    # used to detect that a different dataset or filter is now in effect.
    current_signature = (
        earliest_date.isoformat(),
        latest_date.isoformat(),
        int(len(base)),
        float(base["accident_count"].sum()),
    )
    # Reset cached results if the underlying dataset has changed.
    if forecast_state.get("data_signature") != current_signature:
        forecast_state.update(
            {
                "data_signature": current_signature,
                "results": None,
                "last_message": None,
                "selected_date": fallback_date,
            }
        )

    # Clamp the remembered date into the valid range; st.date_input rejects a
    # default value outside [min_value, max_value].
    default_date = forecast_state.get("selected_date") or fallback_date
    default_date = min(max(default_date, earliest_date), latest_date)

    with st.form(key="predict_form"):
        selected_date = st.date_input(
            "选择干预日期 / 预测起点",
            value=default_date,
            min_value=earliest_date,
            max_value=latest_date,
        )
        horizon = st.number_input(
            "预测天数",
            min_value=7,
            max_value=90,
            value=int(forecast_state.get("horizon", 30)),
            step=1,
        )
        submit_predict = st.form_submit_button("应用预测参数")

    # Remember the widget values so the next rerun starts from them.
    forecast_state["selected_date"] = selected_date
    forecast_state["horizon"] = int(horizon)

    if submit_predict:
        history = base.loc[: pd.to_datetime(selected_date)]
        if len(history) < 10:
            forecast_state.update(
                {
                    "results": None,
                    "last_message": "干预前数据不足(至少需要 10 个观测点)。",
                }
            )
        else:
            with st.spinner("正在生成预测结果…"):
                model_outputs = _run_forecast_models(
                    base, history, selected_date, int(horizon)
                )
            forecast_state.update(
                {
                    "results": {
                        "selected_date": selected_date,
                        "horizon": int(horizon),
                        **model_outputs,
                    },
                    "last_message": None,
                }
            )

    results = forecast_state.get("results")
    if not results:
        if forecast_state.get("last_message"):
            st.warning(forecast_state["last_message"])
        else:
            st.info("请设置预测参数并点击“应用预测参数”按钮。")
        return

    _render_forecast_results(base, results)


def _run_forecast_models(
    base: pd.DataFrame,
    history: pd.DataFrame,
    selected_date,
    horizon: int,
) -> dict:
    """Run ARIMA, KNN and GLM/SVR forecasts, isolating failures per model.

    Each model call is guarded separately so that an exception in one model is
    reported as a warning instead of discarding the other models' output.

    Args:
        base: Full dataset; its ``accident_count`` column is passed to the KNN
            and GLM/SVR routines together with the intervention date.
        history: Rows of ``base`` up to ``selected_date``; its
            ``accident_count`` column is the ARIMA training series.
        selected_date: Intervention date chosen in the form.
        horizon: Number of days to forecast.

    Returns:
        A dict with keys ``arima_df``, ``knn_pred``, ``glm_pred``, ``svr_pred``
        (each ``None`` when that model produced nothing) and ``warnings``
        (list of user-facing messages).
    """
    cutoff = pd.to_datetime(selected_date)
    warnings: list[str] = []

    # Broad catches are deliberate here: this is the UI boundary and any model
    # failure should be shown to the user rather than abort the page.
    try:
        arima_df = arima_forecast_with_grid_search(
            history["accident_count"],
            start_date=cutoff + pd.Timedelta(days=1),
            horizon=horizon,
        )
    except Exception as exc:
        arima_df = None
        warnings.append(f"ARIMA 运行失败:{exc}")

    try:
        knn_pred, _ = knn_forecast_counterfactual(
            base["accident_count"],
            cutoff,
            horizon=horizon,
        )
    except Exception as exc:
        knn_pred = None
        warnings.append(f"KNN 运行失败:{exc}")
    else:
        if knn_pred is None:
            warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")

    try:
        glm_pred, svr_pred, _ = fit_and_extrapolate(
            base["accident_count"],
            cutoff,
            days=horizon,
        )
    except Exception as exc:
        glm_pred = svr_pred = None
        warnings.append(f"GLM/SVR 运行失败:{exc}")
    else:
        if glm_pred is None and svr_pred is None:
            warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")

    return {
        "arima_df": arima_df,
        "knn_pred": knn_pred,
        "glm_pred": glm_pred,
        "svr_pred": svr_pred,
        "warnings": warnings,
    }


def _render_forecast_results(base: pd.DataFrame, results: dict) -> None:
    """Draw the comparison chart, the ARIMA CSV download and stored warnings.

    Args:
        base: Dataset whose ``accident_count`` column is plotted as the
            observed series.
        results: The ``results`` dict stored in the forecast session state.
    """
    first_date = pd.to_datetime(results["selected_date"])
    horizon_days = int(results["horizon"])
    arima_df = results["arima_df"]

    fig_pred = go.Figure()
    fig_pred.add_trace(
        go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
    )
    if arima_df is not None:
        fig_pred.add_trace(
            go.Scatter(
                x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines"
            )
        )
    # The remaining models are plotted directly against their own index;
    # models that produced nothing are skipped.
    for trace_name, result_key in (
        ("KNN", "knn_pred"),
        ("GLM", "glm_pred"),
        ("SVR", "svr_pred"),
    ):
        prediction = results[result_key]
        if prediction is not None:
            fig_pred.add_trace(
                go.Scatter(
                    x=prediction.index, y=prediction, name=trace_name, mode="lines"
                )
            )

    fig_pred.update_layout(
        title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
        xaxis_title="日期",
        yaxis_title="事故数",
    )
    st.plotly_chart(fig_pred, use_container_width=True)

    if arima_df is not None:
        st.download_button(
            "下载 ARIMA 预测 CSV",
            # utf-8-sig writes a BOM at the start of the CSV bytes.
            data=arima_df.to_csv().encode("utf-8-sig"),
            file_name="arima_forecast.csv",
            mime="text/csv",
        )

    for warning_text in results.get("warnings", []):
        st.warning(warning_text)