140 lines
5.0 KiB
Python
140 lines
5.0 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import warnings
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
import streamlit as st
|
||
|
|
import statsmodels.api as sm
|
||
|
|
from statsmodels.tsa.arima.model import ARIMA
|
||
|
|
from statsmodels.tools.sm_exceptions import ValueWarning
|
||
|
|
from sklearn.neighbors import KNeighborsRegressor
|
||
|
|
from sklearn.svm import SVR
|
||
|
|
from sklearn.pipeline import make_pipeline
|
||
|
|
from sklearn.preprocessing import StandardScaler
|
||
|
|
|
||
|
|
from config.settings import ARIMA_P, ARIMA_D, ARIMA_Q, MAX_PRE_DAYS
|
||
|
|
|
||
|
|
|
||
|
|
@st.cache_data(show_spinner=False)
|
||
|
|
def evaluate_arima_model(series, arima_order):
|
||
|
|
try:
|
||
|
|
model = ARIMA(series, order=arima_order)
|
||
|
|
model_fit = model.fit()
|
||
|
|
return model_fit.aic
|
||
|
|
except Exception:
|
||
|
|
return float("inf")
|
||
|
|
|
||
|
|
|
||
|
|
@st.cache_data(show_spinner=False)
|
||
|
|
def arima_forecast_with_grid_search(accident_series: pd.Series,
|
||
|
|
start_date: pd.Timestamp,
|
||
|
|
horizon: int = 30,
|
||
|
|
p_values: list = tuple(ARIMA_P),
|
||
|
|
d_values: list = tuple(ARIMA_D),
|
||
|
|
q_values: list = tuple(ARIMA_Q)) -> pd.DataFrame:
|
||
|
|
series = accident_series.asfreq('D').fillna(0)
|
||
|
|
start_date = pd.to_datetime(start_date)
|
||
|
|
|
||
|
|
warnings.filterwarnings("ignore", category=ValueWarning)
|
||
|
|
best_score, best_cfg = float("inf"), None
|
||
|
|
for p in p_values:
|
||
|
|
for d in d_values:
|
||
|
|
for q in q_values:
|
||
|
|
order = (p, d, q)
|
||
|
|
try:
|
||
|
|
aic = evaluate_arima_model(series, order)
|
||
|
|
if aic < best_score:
|
||
|
|
best_score, best_cfg = aic, order
|
||
|
|
except Exception:
|
||
|
|
continue
|
||
|
|
|
||
|
|
model = ARIMA(series, order=best_cfg)
|
||
|
|
fit = model.fit()
|
||
|
|
forecast_index = pd.date_range(start=start_date, periods=horizon, freq='D')
|
||
|
|
res = fit.get_forecast(steps=horizon)
|
||
|
|
df = res.summary_frame()
|
||
|
|
df.index = forecast_index
|
||
|
|
df.index.name = 'date'
|
||
|
|
df.rename(columns={'mean': 'forecast'}, inplace=True)
|
||
|
|
return df
|
||
|
|
|
||
|
|
|
||
|
|
def knn_forecast_counterfactual(accident_series: pd.Series,
|
||
|
|
intervention_date: pd.Timestamp,
|
||
|
|
lookback: int = 14,
|
||
|
|
horizon: int = 30):
|
||
|
|
series = accident_series.asfreq('D').fillna(0)
|
||
|
|
intervention_date = pd.to_datetime(intervention_date).normalize()
|
||
|
|
|
||
|
|
df = pd.DataFrame({'y': series})
|
||
|
|
for i in range(1, lookback + 1):
|
||
|
|
df[f'lag_{i}'] = df['y'].shift(i)
|
||
|
|
|
||
|
|
train = df.loc[:intervention_date - pd.Timedelta(days=1)].dropna()
|
||
|
|
if len(train) < 5:
|
||
|
|
return None, None
|
||
|
|
X_train = train.filter(like='lag_').values
|
||
|
|
y_train = train['y'].values
|
||
|
|
knn = KNeighborsRegressor(n_neighbors=5)
|
||
|
|
knn.fit(X_train, y_train)
|
||
|
|
|
||
|
|
history = df.loc[:intervention_date - pd.Timedelta(days=1), 'y'].tolist()
|
||
|
|
preds = []
|
||
|
|
for _ in range(horizon):
|
||
|
|
if len(history) < lookback:
|
||
|
|
return None, None
|
||
|
|
x = np.array(history[-lookback:][::-1]).reshape(1, -1)
|
||
|
|
pred = knn.predict(x)[0]
|
||
|
|
preds.append(pred)
|
||
|
|
history.append(pred)
|
||
|
|
|
||
|
|
pred_index = pd.date_range(intervention_date, periods=horizon, freq='D')
|
||
|
|
return pd.Series(preds, index=pred_index, name='knn_pred'), None
|
||
|
|
|
||
|
|
|
||
|
|
def fit_and_extrapolate(series: pd.Series,
|
||
|
|
intervention_date: pd.Timestamp,
|
||
|
|
days: int = 30,
|
||
|
|
max_pre_days: int = MAX_PRE_DAYS):
|
||
|
|
series = series.asfreq('D').fillna(0)
|
||
|
|
series.index = pd.to_datetime(series.index).tz_localize(None).normalize()
|
||
|
|
intervention_date = pd.to_datetime(intervention_date).tz_localize(None).normalize()
|
||
|
|
|
||
|
|
pre = series.loc[:intervention_date - pd.Timedelta(days=1)]
|
||
|
|
if len(pre) > max_pre_days:
|
||
|
|
pre = pre.iloc[-max_pre_days:]
|
||
|
|
if len(pre) < 3:
|
||
|
|
return None, None, None
|
||
|
|
|
||
|
|
x_pre = np.arange(len(pre))
|
||
|
|
x_future = np.arange(len(pre), len(pre) + days)
|
||
|
|
|
||
|
|
try:
|
||
|
|
X_pre_glm = sm.add_constant(np.column_stack([x_pre, x_pre**2]))
|
||
|
|
glm = sm.GLM(pre.values, X_pre_glm, family=sm.families.Poisson())
|
||
|
|
glm_res = glm.fit()
|
||
|
|
X_future_glm = sm.add_constant(np.column_stack([x_future, x_future**2]))
|
||
|
|
glm_pred = glm_res.predict(X_future_glm)
|
||
|
|
except Exception:
|
||
|
|
glm_pred = None
|
||
|
|
|
||
|
|
try:
|
||
|
|
svr = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10, gamma=0.1))
|
||
|
|
svr.fit(x_pre.reshape(-1, 1), pre.values)
|
||
|
|
svr_pred = svr.predict(x_future.reshape(-1, 1))
|
||
|
|
except Exception:
|
||
|
|
svr_pred = None
|
||
|
|
|
||
|
|
post_index = pd.date_range(intervention_date, periods=days, freq='D')
|
||
|
|
|
||
|
|
glm_pred = pd.Series(glm_pred, index=post_index, name='glm_pred') if glm_pred is not None else None
|
||
|
|
svr_pred = pd.Series(svr_pred, index=post_index, name='svr_pred') if svr_pred is not None else None
|
||
|
|
|
||
|
|
post = series.reindex(post_index)
|
||
|
|
residuals = None
|
||
|
|
if svr_pred is not None:
|
||
|
|
residuals = pd.Series(post.values - svr_pred[:len(post)], index=post_index, name='residual')
|
||
|
|
|
||
|
|
return glm_pred, svr_pred, residuals
|
||
|
|
|