【tiny-dnn】ミニバッチやエポックの区切りで処理させたいとき
六花です。
前回までで、一応学習ができるようになったと思います。
しかし、おそらく次に出てきた欲求のせいでちょっとはまることになると思います。
今回はそれについてお話します。
■環境
Visual Studio Community 2017
tiny-dnn (2018/12/16ダウンロード)
■.train()は学習が終わるまで帰ってこない
前回のやり方だとおそらく超高速で終わったはずなので問題はなかったのですが、
学習データが大量になり、収束のためにエポック数を増やすと、当然処理時間は長くなります。
動いているんだかわからない状況、いやですよね。
途中経過を知りたくなると思います。
そこで、私はこういう処理を書きました。
■.train()は一話完結型
// オプティマイザの決定
tiny_dnn::adam optimizer;
for(int i = 0 ; i < 100; ++i)
{
// 学習実行
net.train<tiny_dnn::cross_entropy_multiclass>(optimizer, input_train, label_train, 64, 1);
// 誤差の計測
auto loss_train = net.get_loss<tiny_dnn::mse>(input_train, teach_train) / (int)((float)max_record * 0.7f);
auto loss_test = net.get_loss<tiny_dnn::mse>(input_test, teach_test ) / (int)((float)max_record * 0.3f);
// 出力
cout << loss_train << "\t" << loss_test << endl;
}
この処理だと、一見「1エポックの学習ごとに誤差の計測が出力される」と予想されると思います。
事実そうです。
この処理は確かにその通りの動作をします。
しかし、罠が存在します。
オプティマイザとして使用しているadamは、その仕様上変動する学習率を自前で所持しています。
これが、.train()が開始されたときに初期化されてしまうのです。
処理が進むに従い小さくなった学習率がまた大きな初期値に戻ってしまうのですから、学習が台無しになってしまいます!
■じゃあどうするの?
.train()は関数オーバーロードにより追加で引数を渡すことができます。
それが、OnBatchEnumerateとOnEpochEnumerateです。
つまりどういうことかというと、こう書けます。
auto onBatchEnumerate = [&]()
{
// ミニバッチが終わった時の処理
cout << "end mini batch" << endl;
};
auto onEpochEnumerate = [&]()
{
// エポックが終わった時の処理
cout << "end epoch" << endl;
};
// 学習実行
net.train<tiny_dnn::cross_entropy_multiclass>(optimizer, input_train, label_train, 64, 3, onBatchEnumerate, onEpochEnumerate);
実行してみると、沢山出力されるようになったのがわかると思います。
ちなみに私がこの機能に気づいたのは、リカレントネットワークのexampleっぽいソースコードを読んでいたときです。
tiny-dnnにドキュメントはないのだろうか。
皆さんも頑張ってtiny-dnnのうまい使い方を探してみてくださいね!