GluonTS を使用した仮想通貨の価格予測
Gluon Time Series(GluonTS)は、確率的な予測もサポートした時系列データ向けのライブラリで、独自のデータに基づいて組み込みモデルをトレーニングおよび評価することができます。
今回はこの GluonTS を使用して、仮想通貨の価格予測をしてみます。
価格予測の実装
環境構築
まずは、必要な環境をインストールしましょう。GluonTS のインストールを行います。
pip install --upgrade mxnet==1.7.0.post2
pip install gluonts
また、データの整形・可視化などに Pandas の機能を使用しますので、こちらのインストールも行っておきます。
pip install pandas
学習データを用意
次に学習に使用するデータを用意します。
今回は、DOGE と BNB という二種類の銘柄の予測を行ってみます。2021 年 4 月 10 日における、1 分ごとの価格データを 2021-04-10_BNBUSDT_DOGEUSDT.csv に用意しました。
こちらを読み込んで、学習に使用するためのデータセットを作成します。
# 学習データの開始期間・終了期間
START_DATE = datetime(year=2021, month=4, day=10)
END_DATE = datetime(year=2021, month=4, day=10)
# 予測期間(分)
PREDICTION_LENGTH = 120
# CSV ファイルから価格データの読み込み
df = pd.read_csv('2021-04-10_BNBUSDT_DOGEUSDT.csv')
# 価格データを整形する
stock_dataset = df.pivot(index='symbol', columns='time', values='price')
# データの開始時間を銘柄数分作成
dates = [pd.Timestamp(START_DATE.strftime('%Y-%m-%d'), freq='1min') for _ in range(stock_dataset.shape[0])]
# 学習は総データ数の 120 分を除いたデータ
train_target_values = [ts[:-PREDICTION_LENGTH] for ts in stock_dataset.values]
# テストは全て含まれたデータ
test_target_values = stock_dataset.copy().values
ここで作成したデータを GluonTS で使用するために ListDataset へ格納します。
# 学習 ListDataset を作成
train_ds = ListDataset([
{
FieldName.TARGET: target,
FieldName.START: start,
FieldName.ITEM_ID: code,
}
for (target, start, code) in zip(train_target_values, dates, stock_dataset.index)
], freq='1min')
# テスト ListDataset を作成
test_ds = ListDataset([
{
FieldName.TARGET: target,
FieldName.START: start,
FieldName.ITEM_ID: code,
}
for (target, start, code) in zip(test_target_values, dates, stock_dataset.index)
], freq='1min')
学習させる
それでは用意したデータを学習させてみます。
学習に使用するのは DeepAR というアルゴリズムです。 このアルゴリズムを使用して学習させるには DeepAREstimator オブジェクトを作成します。今回は必須パラメータのみを設定し、残りはデフォルトのパラメータで学習をさせてみます。
estimator = gluonts.model.deepar.DeepAREstimator(
freq='1min'
prediction_length=PREDICTION_LENGTH
)
predictor = estimator.train(train_ds)
上記のコードのように estimator.train
を呼び出すと学習が開始されます。学習の進行中には、コンソールに以下のような表示が出力されます。
(.venv) PS D:\Develop\Python\coin-price-predictor-py> py .\main.py
Multiprocessing is not supported on Windows, num_workers will be set to None.
learning rate from ``lr_scheduler`` has been overwritten by ``learning_rate`` in optimizer.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00, 4.13it/s, epoch=1/100, avg_epoch_loss=1.29]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:11<00:00, 4.47it/s, epoch=2/100, avg_epoch_loss=-.501]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.56it/s, epoch=3/100, avg_epoch_loss=-.392]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.62it/s, epoch=4/100, avg_epoch_loss=-.864]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.71it/s, epoch=5/100, avg_epoch_loss=-.841]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.65it/s, epoch=6/100, avg_epoch_loss=-.65]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:11<00:00, 4.54it/s, epoch=7/100, avg_epoch_loss=-1.02]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:11<00:00, 4.28it/s, epoch=8/100, avg_epoch_loss=-.899]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:11<00:00, 4.46it/s, epoch=9/100, avg_epoch_loss=-.487]
74%|███████████████████████████████████████████████████████████████████████████████████████████████▍ | 37/50 [00:08<00:02, 4.60it/s, epoch=10/100, avg_epoch_loss=-1.12]
学習結果を評価
次に学習結果を評価させてみます。評価には gluonts.evaluation.Evaluator を使用します。
# 学習結果から推論を実行する
forecast_it, ts_it = gluonts.evaluation.backtest.make_evaluation_predictions(
dataset=test_ds,
predictor=predictor,
num_samples=num_samples
)
# 時系列条件付け値の取得
tss = list(ts_it)
# 時系列予測の取得
forecasts = list(forecast_it)
# 評価を実行する
evaluator = gluonts.evaluation.Evaluator(quantiles=[0.5])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(test_ds))
予測結果の可視化
最後に、予測結果を matplotlib を使用して可視化します。
# プロットの長さ(24時間分)
plot_length = 60 * 24
# サンプリングの 50% が含まれる区間、サンプリングの 90% が含まれる区間
prediction_intervals = (50, 90)
legend = ["実価格", "予測価格中央値"] + [f"{k}% 予測区間" for k in prediction_intervals][::-1]
plt.rcParams['font.family'] = 'BIZ UDGothic' # 日本語を表示できるように
_, ax = plt.subplots(1, 1, figsize=(10, 7))
ts_entry[-plot_length:].plot(ax=ax)
forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
ax.axvline(ts_entry.index[-PREDICTION_LENGTH], color='r')
plt.legend(legend, loc="upper left")
plt.title(forecast_entry.item_id)
plt.show()
plt.clf()
上記のコードを実行すると、以下のようなグラフが得られます。
BNB はわりといい線なのかな、という気がしますが、DOGE の方は大きく外れていますね。ただ「その予測をした気持ちはわかる」っていう外し方ですね(笑)どちらかというと、DOGE がイレギュラーな動きをしてしまったって感じがします。
ソースコード
今回書いたソースコードの全文は以下にあります。
コメント