NeuralProphet之四:事件(Events)
在预测问题中,经常需要考虑反复出现的特殊事件。
1 生成数据
首先生成事件训练数据
# data_location = "https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/"
data_location = 'datasets/'
df = pd.read_csv(data_location + 'wp_log_peyton_manning.csv')
playoffs_history = pd.DataFrame({
'event': 'playoff',
'ds': pd.to_datetime(['2008-01-13', '2009-01-03', '2010-01-16',
'2010-01-24', '2010-02-07', '2011-01-08',
'2013-01-12', '2014-01-12', '2014-01-19',
'2014-02-02', '2015-01-11', '2016-01-17']),
})
superbowls_history = pd.DataFrame({
'event': 'superbowl',
'ds': pd.to_datetime(['2010-02-07', '2014-02-02']),
})
history_events_df = pd.concat((playoffs_history, superbowls_history))
print(history_events_df)
history_df = m.create_df_with_events(df, history_events_df)
print(history_df)
生成事件预测数据
playoffs_future = pd.DataFrame({
'event': 'playoff',
'ds': pd.to_datetime(['2016-01-21', '2016-02-07'])
})
superbowl_future = pd.DataFrame({
'event': 'superbowl',
'ds': pd.to_datetime(['2016-01-23', '2016-02-07'])
})
future_events_df = pd.concat((playoffs_future, superbowl_future))
print(future_events_df)
2 注册
通过add_events
函数为NeuralProphet
对象添加事件配置
m = NeuralProphet(
n_forecasts=10,
yearly_seasonality=False,
weekly_seasonality=False,
daily_seasonality=False,
)
m = m.add_events(["superbowl", "playoff"])
事件参数配置
- 模式:
additive
和multiplicative
m = m.add_events(["superbowl", "playoff"], mode="multiplicative")
- 事件窗口:将一个特定事件周围的日子也视为特殊事件
m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1)
也可以单独针对不同事件设置窗口
m = m.add_events("superbowl", lower_window=-1, upper_window=1)
m = m.add_events("playoff", upper_window=2)
- 国家法定假期:
neural_prophet
支持标准的特定国家的假期
m = m.add_country_holidays("US", mode="additive", lower_window=-1, upper_window=1)
- 正则化:
m = m.add_events(["superbowl", "playoff"], regularization=0.05)
也可以单独针对不同事件正则化
m = m.add_events("superbowl", regularization=0.05)
m = m.add_events("playoff", regularization=0.03)
3 训练与预测
metrics = m.fit(history_df, freq="D")
forecast = m.predict(df=history_df)
fig = m.plot(forecast)