Implementing Inverted Dropout and Experimenting with ResNet [C++ ArrayFire]
Rikka here.
This time I'll implement Inverted Dropout, a variant of the dropout layer, and run ResNet experiments through a proper test phase.
§Environment
Windows 10
Microsoft Visual Studio Community 2022 (64 ビット) Version 17.3.4
ArrayFire v3.8.2
// Inverted Dropout
// https://data-analytics.fun/2021/11/13/understanding-dropout/
struct Dropout_layer : public Layer
{
    const af::array& x;
    af::array& y;
    af::array& dx;
    const af::array& dy;
    var_t rate_DO = 0.35;
    af::array filter_DO;
    Dropout_layer(const af::array& x, af::array& y, af::array& dx, const af::array& dy)
        : x(x)
        , y(y)
        , dx(dx)
        , dy(dy)
    {
        filter_DO = af::constant(0.0, x.dims(0), 1, dtype_t);
    }
    virtual void forward()
    {
        filter_DO = af::select(af::randu(x.dims(0), 1, dtype_t) <= rate_DO, filter_DO * 0.0, 1.0); // units drawn at or below the dropout rate are zeroed
        filter_DO = af::select((af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0), filter_DO, 1.0); // at test time the filter is all 1.0
        y = x * af::select((af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0), af::tile(filter_DO, 1, x.dims(1)) / (1.0 - rate_DO), (var_t)1); // inverted scaling only during training
        filter_DO.eval();
        y.eval();
    }
    virtual void backward()
    {
        dy.eval();
        dx += dy * af::tile(filter_DO, 1, x.dims(1)) / (1.0 - rate_DO); // same mask and scaling as the forward pass
    }
};
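For reference, the layer sits between two slots of the activation/gradient buffers. The line below is a hypothetical sketch of the wiring, based on how the other layers in this series take their arguments; the index i and the exact container types are my assumptions, not code from this post:

// hypothetical wiring: reads data[i], writes data[i + 1];
// backward() accumulates grad[i] from grad[i + 1]
layer.emplace_back(new Dropout_layer(data[i], data[i + 1], grad[i], grad[i + 1]));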
https://deepage.net/deep_learning/2016/11/30/resnet.html
The site above recommends a dropout rate of 30-40%, so I set it to 35%.
With Inverted Dropout, the division by (1.0 - dropout rate) apparently happens at training time rather than at inference time.
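To convince myself of this, here is a minimal standalone check (my own sketch, not part of the model code; the array size and f64 type are arbitrary). Each unit survives with probability (1 - p) and is scaled by 1/(1 - p), so the expected training-time activation equals the untouched test-time activation:

#include <arrayfire.h>
int main()
{
    const double p = 0.35;                                   // dropout rate
    af::array x    = af::constant(1.0, 10000, 1, f64);       // dummy activations
    af::array mask = (af::randu(10000, 1, f64) > p).as(f64); // keep with probability (1 - p)
    af::array y    = x * mask / (1.0 - p);                   // inverted scaling at training time
    af_print(af::mean(y)); // ~1.0: E[y] = (1 - p) * x / (1 - p) = x, matching test time
    return 0;
}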
Note that adding a Dropout layer seems to break gradient checking, presumably because the mask is re-randomized on every forward pass, so the two passes of a numerical check see different networks.
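If you do want the check to pass, one workaround I'd try (untested here) is to pin the RNG seed immediately before each forward pass during the check, so af::randu reproduces the same mask both times:

af::setSeed(12345);                         // arbitrary fixed seed, so the mask repeats
for (auto& itm : layer) { itm->forward(); } // both passes of the numerical check now see the same mask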
The rest is writing a proper test phase.
for (auto& itm : layer)
{
    itm->isTrain = 0.0; // switch every layer to test mode
}
var_t counter_ok = (var_t)0;
var_t diff = (var_t)0;
var_t norm = (var_t)0;
// random selection:
// build a shuffled index array by sorting random keys
af::array idx_data;
{
    af::array vals_data;
    af::array sort_data = af::randu(size_data_test, 1, dtype_t);
    af::sort(vals_data, idx_data, sort_data, 0);
}
constexpr int size_data_in_test = size_data_test / size_batch;
for (int step = 0; step < size_data_in_test; ++step)
{
    // reset the errors (gradients)
    for (auto& itm : grad) { itm = 0.0; }
    for (auto& itm : layer) { itm->init(); }
    // the samples targeted in this step
    // af::seq specifies a range; the ranges its constructors
    // describe are a little unusual, see:
    // https://arrayfire.org/docs/classaf_1_1seq.htm
    af::array idx_target = idx_data(af::seq((step + 0) * size_batch, (step + 1) * size_batch - 1)) + size_data_train; // test data is stored after the training data
    // set the input
    data.front() = input(af::span, idx_target);
    // forward pass
    for (auto& itm : layer) { itm->forward(); }
    // softmax if needed
#ifdef DEF_use_softmax
    data.back() = af::exp(data.back()) / af::tile(af::sum(af::exp(data.back()), 0), data.back().dims(0));
#endif
    // compute the error
    grad.back() = data.back() - output(af::span, idx_target); // y - t
    diff += af::mean<var_t>(af::abs(grad.back()));
    norm += af::norm(grad.back());
#ifdef DEF_use_softmax
    // compute the accuracy
    af::array val, plabels, tlabels;
    af::max(val, tlabels, output(af::span, idx_target), 0);
    af::max(val, plabels, data.back(), 0);
    counter_ok += af::count<var_t>(plabels == tlabels);
#endif
}
const auto accuracy = (var_t)100 * counter_ok / (var_t)(size_data_in_test * size_batch);
diff /= (var_t)size_data_in_test;
norm /= (var_t)size_data_in_test;
for (auto& itm : layer)
{
    itm->isTrain = 1.0; // switch back to training mode
}
Be careful: if you casually copy-paste the training-phase source, you'll carry over the gradient-update step along with it.
This time I prepared results for three patterns:
1. FC+BN+tanhExp
2. FC+BN+tanhExp+Shortcut
3. FC+BN+tanhExp+Shortcut+Dropout
// BN 300 5 0.01
epoch : 1 accuracy train: 5.95703 accuracy test : 3.09361
epoch : 101 accuracy train: 86.5234 accuracy test : 61.3491
epoch : 201 accuracy train: 96.8262 accuracy test : 80.6327
epoch : 301 accuracy train: 99.3164 accuracy test : 82.53
epoch : 401 accuracy train: 99.8535 accuracy test : 82.0696
epoch : 501 accuracy train: 99.7559 accuracy test : 81.7208
epoch : 601 accuracy train: 99.9512 accuracy test : 81.7418
epoch : 701 accuracy train: 100 accuracy test : 81.5011
epoch : 801 accuracy train: 99.4629 accuracy test : 81.1733
epoch : 901 accuracy train: 99.5605 accuracy test : 81.2012
epoch : 1001 accuracy train: 100 accuracy test : 81.1
epoch : 1101 accuracy train: 100 accuracy test : 80.8908
epoch : 1201 accuracy train: 100 accuracy test : 80.7059
epoch : 1301 accuracy train: 100 accuracy test : 80.678
epoch : 1401 accuracy train: 100 accuracy test : 80.6989
epoch : 1501 accuracy train: 99.2188 accuracy test : 80.4827
epoch : 1601 accuracy train: 100 accuracy test : 80.4583
epoch : 1701 accuracy train: 100 accuracy test : 80.5071
epoch : 1801 accuracy train: 100 accuracy test : 80.385
epoch : 1901 accuracy train: 100 accuracy test : 79.5619
epoch : 2001 accuracy train: 100 accuracy test : 80.2246
epoch : 2101 accuracy train: 100 accuracy test : 80.0991
epoch : 2201 accuracy train: 100 accuracy test : 80.0851
epoch : 2301 accuracy train: 100 accuracy test : 80.1862
epoch : 2401 accuracy train: 100 accuracy test : 79.9944
epoch : 2501 accuracy train: 100 accuracy test : 80.0049
epoch : 2601 accuracy train: 100 accuracy test : 80.1444
epoch : 2701 accuracy train: 100 accuracy test : 80.0119
epoch : 2801 accuracy train: 100 accuracy test : 79.9351
epoch : 2901 accuracy train: 99.9512 accuracy test : 79.841
epoch : 3001 accuracy train: 100 accuracy test : 79.7956
// BN SC 300 5 0.01
epoch : 1 accuracy train: 7.08008 accuracy test : 7.20564
epoch : 101 accuracy train: 77.0508 accuracy test : 52.2356
epoch : 201 accuracy train: 88.623 accuracy test : 78.8574
epoch : 301 accuracy train: 94.0918 accuracy test : 86.9245
epoch : 401 accuracy train: 95.2148 accuracy test : 88.637
epoch : 501 accuracy train: 94.9219 accuracy test : 89.1253
epoch : 601 accuracy train: 98.4375 accuracy test : 89.1497
epoch : 701 accuracy train: 97.8516 accuracy test : 88.8323
epoch : 801 accuracy train: 95.5566 accuracy test : 88.1731
epoch : 901 accuracy train: 97.8516 accuracy test : 87.7965
epoch : 1001 accuracy train: 97.0215 accuracy test : 87.4372
epoch : 1101 accuracy train: 99.6094 accuracy test : 86.7501
epoch : 1201 accuracy train: 99.8535 accuracy test : 86.5479
epoch : 1301 accuracy train: 99.9512 accuracy test : 86.3072
epoch : 1401 accuracy train: 99.9023 accuracy test : 86.0352
epoch : 1501 accuracy train: 97.5586 accuracy test : 85.8154
epoch : 1601 accuracy train: 99.9512 accuracy test : 85.5015
epoch : 1701 accuracy train: 100 accuracy test : 85.2923
epoch : 1801 accuracy train: 100 accuracy test : 85.0202
epoch : 1901 accuracy train: 99.9512 accuracy test : 84.4657
epoch : 2001 accuracy train: 99.9512 accuracy test : 84.7028
epoch : 2101 accuracy train: 100 accuracy test : 84.5459
epoch : 2201 accuracy train: 100 accuracy test : 84.4203
epoch : 2301 accuracy train: 100 accuracy test : 84.3924
epoch : 2401 accuracy train: 99.9512 accuracy test : 84.1099
epoch : 2501 accuracy train: 100 accuracy test : 84.1692
epoch : 2601 accuracy train: 99.9512 accuracy test : 84.0995
epoch : 2701 accuracy train: 100 accuracy test : 83.96
epoch : 2801 accuracy train: 100 accuracy test : 83.7646
epoch : 2901 accuracy train: 99.8535 accuracy test : 83.7925
epoch : 3001 accuracy train: 100 accuracy test : 83.6426
// BN SC DO 5 0.01
epoch : 1 accuracy train: 5.27344 accuracy test : 6.36858
epoch : 101 accuracy train: 61.8652 accuracy test : 44.6882
epoch : 201 accuracy train: 78.7598 accuracy test : 75.1918
epoch : 301 accuracy train: 86.1816 accuracy test : 83.81
epoch : 401 accuracy train: 86.4258 accuracy test : 87.8732
epoch : 501 accuracy train: 87.9395 accuracy test : 88.6509
epoch : 601 accuracy train: 91.1133 accuracy test : 90.2727
epoch : 701 accuracy train: 91.1621 accuracy test : 89.8054
epoch : 801 accuracy train: 89.1602 accuracy test : 91.5318
epoch : 901 accuracy train: 92.5293 accuracy test : 91.4586
epoch : 1001 accuracy train: 93.7012 accuracy test : 90.5622
epoch : 1101 accuracy train: 93.1641 accuracy test : 91.8771
epoch : 1201 accuracy train: 94.1895 accuracy test : 91.1482
epoch : 1301 accuracy train: 93.6523 accuracy test : 91.3609
epoch : 1401 accuracy train: 96.0449 accuracy test : 92.0061
epoch : 1501 accuracy train: 95.6055 accuracy test : 91.49
epoch : 1601 accuracy train: 91.3086 accuracy test : 91.3086
epoch : 1701 accuracy train: 87.5488 accuracy test : 90.0879
epoch : 1801 accuracy train: 94.7754 accuracy test : 92.048
epoch : 1901 accuracy train: 96.4844 accuracy test : 90.9877
epoch : 2001 accuracy train: 95.8984 accuracy test : 93.185
epoch : 2101 accuracy train: 93.5547 accuracy test : 92.613
epoch : 2201 accuracy train: 95.0195 accuracy test : 91.326
epoch : 2301 accuracy train: 95.9473 accuracy test : 92.055
epoch : 2401 accuracy train: 91.6504 accuracy test : 92.3165
epoch : 2501 accuracy train: 96.2402 accuracy test : 92.1143
epoch : 2601 accuracy train: 96.1426 accuracy test : 92.4037
epoch : 2701 accuracy train: 91.8945 accuracy test : 92.756
epoch : 2801 accuracy train: 92.334 accuracy test : 92.3131
epoch : 2901 accuracy train: 92.0898 accuracy test : 91.8666
epoch : 3001 accuracy train: 0 accuracy test : 0
Adding the shortcut layer clearly improves the test-phase accuracy.
Adding the dropout layer improves accuracy further, and the accuracy no longer degrades even when training continues for a long time.
However, with this model, NaN occurred when the dropout layer was included.
I don't know whether this is related to whether or not the training-phase accuracy reaches 100.
The model may simply be too small for the problem.
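To track down where the NaN first shows up, I would drop a check like the following into the training loop right after the forward pass (a debugging sketch, not something that ran in the experiments above):

for (size_t i = 0; i < data.size(); ++i)
{
    if (af::anyTrue<bool>(af::isNaN(data[i]))) // true if any element of this layer's output is NaN
    {
        printf("NaN first appears in data[%zu]\n", i);
        break;
    }
}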