【tiny-dnn】クラス分類はtrain、回帰分析はfit
六花です。
クリスマスや年末年始でにわかに忙しくなってきました。
折角ブログを作ったので日刊で書いていましたが、年末年始はお休みするかもしれません。
■.fit()
さて、前回学習に使ったメソッドを覚えていらっしゃいますか?
.train()です。
実は、学習のためのメソッドはもう一つあります。
今回はその.fit()メソッドについて書こうと思います。
学習に使ったメソッドを覚えていらっしゃいますか?
.train()です。
実は、学習のためのメソッドはもう一つあります。
今回はその.fit()メソッドについて書こうと思います。
■環境
Visual Studio Community 2017
tiny-dnn (2018/12/16ダウンロード)
■内部的には.train()と同じもの
.train()はクラス分類問題に特化したメソッドですが、.fit()は回帰分析もできる汎用的なメソッドです。(内部の処理は合流しています。)
.train()では引数にラベル(データ一件に対してlabel_tを一つ)を取っていました。
.fit()では、引数にベクトル(データ一件に対してfloat_tの配列(要素数は最終層と同じ数)を一つ)を取ります。
例えば三つのクラス分類を行うとき、.train()に「1」というラベルデータを送ってもいいし、.fit()に「{0.0f, 1.0f, 0.0f}」という教師データを送ってもいいということですね。
前回のコードを流用するならば、以下のようになります。
誤差計測用の教師データが使えそうですね。
// 学習実行
//net.train<tiny_dnn::cross_entropy_multiclass>(optimizer, input_train, label_train, 64, 3, onBatchEnumerate, onEpochEnumerate);
net.fit<tiny_dnn::mse>(optimizer, input_train, teach_train, 64, 3, onBatchEnumerate, onEpochEnumerate);
これをビルドすると、前回の同じような0.04くらいの誤差になると思います。
回帰分析というと「関数の出力を学習するんだよ」というような説明がされることがあります。
それだけだと最終層の要素数は1であるべきかのように感じられますが、実際は複数個でも問題ありません。
(つまり戻り値がベクトルの関数ですね。)
学習データによっては、クラス分類として解くよりも精度が高くなる場合もあるかもしれません。