modify: cleanup project structure and docs
This commit is contained in:
185
ui_sections/forecast.py
Normal file
185
ui_sections/forecast.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
from services.forecast import (
|
||||
arima_forecast_with_grid_search,
|
||||
knn_forecast_counterfactual,
|
||||
fit_and_extrapolate,
|
||||
)
|
||||
|
||||
|
||||
def _run_forecast_models(base, history, selected_date, horizon):
    """Run every forecasting model and collect their outputs.

    Each model is executed independently: a failure in one model is turned
    into a user-facing warning instead of aborting the others.

    Args:
        base: Full accident frame; its ``accident_count`` column is passed
            to the KNN and GLM/SVR models.
        history: Slice of ``base`` up to and including the selected date;
            its ``accident_count`` column is the ARIMA training series.
        selected_date: Intervention date / forecast origin chosen in the UI.
        horizon: Number of days to forecast.

    Returns:
        The results dict cached in ``forecast_state["results"]`` with keys
        ``selected_date``, ``horizon``, ``arima_df``, ``knn_pred``,
        ``glm_pred``, ``svr_pred`` and ``warnings``.
    """
    cutoff = pd.to_datetime(selected_date)
    model_warnings: list[str] = []

    # Broad excepts are deliberate here: this is the UI boundary, and every
    # failure is surfaced to the user through the warnings list below.
    try:
        arima_df = arima_forecast_with_grid_search(
            history["accident_count"],
            start_date=cutoff + pd.Timedelta(days=1),
            horizon=horizon,
        )
    except Exception as exc:
        arima_df = None
        model_warnings.append(f"ARIMA 运行失败:{exc}")

    try:
        knn_pred, _ = knn_forecast_counterfactual(
            base["accident_count"],
            cutoff,
            horizon=horizon,
        )
    except Exception as exc:
        knn_pred = None
        model_warnings.append(f"KNN 运行失败:{exc}")
    else:
        if knn_pred is None:
            model_warnings.append("KNN 预测未生成结果(历史数据不足或维度不满足要求)。")

    try:
        glm_pred, svr_pred, _ = fit_and_extrapolate(
            base["accident_count"],
            cutoff,
            days=horizon,
        )
    except Exception as exc:
        glm_pred = svr_pred = None
        model_warnings.append(f"GLM/SVR 运行失败:{exc}")
    else:
        if glm_pred is None and svr_pred is None:
            model_warnings.append("GLM/SVR 预测未生成结果,建议缩短预测窗口或检查源数据。")

    return {
        "selected_date": selected_date,
        "horizon": horizon,
        "arima_df": arima_df,
        "knn_pred": knn_pred,
        "glm_pred": glm_pred,
        "svr_pred": svr_pred,
        "warnings": model_warnings,
    }


def _build_forecast_figure(base, results):
    """Build the Plotly figure overlaying each available forecast on the actuals.

    Models whose cached output is ``None`` are skipped.
    """
    first_date = pd.to_datetime(results["selected_date"])
    horizon_days = int(results["horizon"])

    fig_pred = go.Figure()
    fig_pred.add_trace(
        go.Scatter(x=base.index, y=base["accident_count"], name="实际", mode="lines")
    )

    arima_df = results["arima_df"]
    if arima_df is not None:
        fig_pred.add_trace(
            go.Scatter(x=arima_df.index, y=arima_df["forecast"], name="ARIMA", mode="lines")
        )

    # The remaining model outputs are plotted directly as y-values against
    # their own index.
    for label, key in (("KNN", "knn_pred"), ("GLM", "glm_pred"), ("SVR", "svr_pred")):
        pred = results[key]
        if pred is not None:
            fig_pred.add_trace(go.Scatter(x=pred.index, y=pred, name=label, mode="lines"))

    fig_pred.update_layout(
        title=f"多模型预测比较(起点:{first_date.date()},预测 {horizon_days} 天)",
        xaxis_title="日期",
        yaxis_title="事故数",
    )
    return fig_pred


def render_forecast(base: pd.DataFrame | None) -> None:
    """Render the multi-model forecast comparison section.

    Shows a form for choosing the intervention date and forecast horizon,
    runs the ARIMA, KNN and GLM/SVR forecasts when the form is submitted,
    caches the outcome in ``st.session_state["forecast_state"]`` and draws
    the cached results (chart, ARIMA CSV download, model warnings).

    Args:
        base: Accident data with an ``accident_count`` column. The code calls
            ``.date()`` on ``base.index.min()``/``max()`` and slices with
            ``base.loc[:timestamp]``, so a sorted datetime index is assumed
            (NOTE(review): confirm against the caller). ``None`` or an empty
            frame renders an informational message only.
    """
    st.subheader("多模型预测比较")

    if base is None or base.empty:
        st.info("暂无可用于预测的事故数据,请先在侧边栏上传数据并应用筛选。")
        st.session_state.setdefault(
            "forecast_state",
            {"results": None, "last_message": "暂无可用于预测的事故数据。"},
        )
        return

    forecast_state = st.session_state.setdefault(
        "forecast_state",
        {
            "selected_date": None,
            "horizon": 30,
            "results": None,
            "last_message": None,
            "data_signature": None,
        },
    )

    earliest_date = base.index.min().date()
    latest_date = base.index.max().date()
    # Default forecast origin: 30 days before the last observation, but
    # never earlier than the first one.
    fallback_date = max(
        (base.index.max() - pd.Timedelta(days=30)).date(),
        earliest_date,
    )
    # Cheap fingerprint of the dataset used to detect when it has changed.
    current_signature = (
        earliest_date.isoformat(),
        latest_date.isoformat(),
        int(len(base)),
        float(base["accident_count"].sum()),
    )

    # Reset cached results if the underlying dataset has changed.
    if forecast_state.get("data_signature") != current_signature:
        forecast_state.update(
            {
                "data_signature": current_signature,
                "results": None,
                "last_message": None,
                "selected_date": fallback_date,
            }
        )

    # Clamp the remembered date into the range accepted by the date picker.
    default_date = forecast_state.get("selected_date") or fallback_date
    default_date = min(max(default_date, earliest_date), latest_date)

    with st.form(key="predict_form"):
        selected_date = st.date_input(
            "选择干预日期 / 预测起点",
            value=default_date,
            min_value=earliest_date,
            max_value=latest_date,
        )
        horizon = st.number_input(
            "预测天数",
            min_value=7,
            max_value=90,
            value=int(forecast_state.get("horizon", 30)),
            step=1,
        )
        submit_predict = st.form_submit_button("应用预测参数")

    forecast_state["selected_date"] = selected_date
    forecast_state["horizon"] = int(horizon)

    if submit_predict:
        history = base.loc[: pd.to_datetime(selected_date)]
        if len(history) < 10:
            forecast_state.update(
                {
                    "results": None,
                    "last_message": "干预前数据不足(至少需要 10 个观测点)。",
                }
            )
        else:
            with st.spinner("正在生成预测结果…"):
                new_results = _run_forecast_models(
                    base, history, selected_date, int(horizon)
                )
            forecast_state.update({"results": new_results, "last_message": None})

    results = forecast_state.get("results")
    if not results:
        if forecast_state.get("last_message"):
            st.warning(forecast_state["last_message"])
        else:
            st.info("请设置预测参数并点击“应用预测参数”按钮。")
        return

    st.plotly_chart(_build_forecast_figure(base, results), use_container_width=True)

    arima_df = results["arima_df"]
    if arima_df is not None:
        st.download_button(
            "下载 ARIMA 预测 CSV",
            # The CSV is encoded as utf-8-sig, i.e. with a leading BOM.
            data=arima_df.to_csv().encode("utf-8-sig"),
            file_name="arima_forecast.csv",
            mime="text/csv",
        )

    for warning_text in results.get("warnings", []):
        st.warning(warning_text)
|
||||
Reference in New Issue
Block a user