Implementing Inverted Dropout and Experimenting with ResNet [C++ ArrayFire]

Rikka here.

This time we implement Inverted Dropout, a variant of the dropout layer, add a proper test phase, and run some ResNet experiments.

§Environment
Windows 10
Microsoft Visual Studio Community 2022 (64-bit) Version 17.3.4
ArrayFire v3.8.2

// Inverted Dropout
// https://data-analytics.fun/2021/11/13/understanding-dropout/
struct Dropout_layer : public Layer
{
    const af::array& x;
    af::array& y;
    af::array& dx;
    const af::array& dy;

    var_t rate_DO = 0.35;

    af::array filter_DO;

    Dropout_layer(const af::array& x, af::array& y, af::array& dx, const af::array& dy)
        : x(x)
        , y(y)
        , dx(dx)
        , dy(dy)
    {
        filter_DO = af::constant(0.0, x.dims(0), 1, dtype_t);
    }

    virtual void forward()
    {
        filter_DO = af::select(af::randu(x.dims(0), 1, dtype_t) <= rate_DO, filter_DO * 0.0, 1.0); // zero where the random draw is at or below the dropout rate
        filter_DO = af::select((af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0), filter_DO, 1.0); // all ones at test time
        y = x * af::select((af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0), af::tile(filter_DO, 1, x.dims(1)) / (1.0 - rate_DO), (var_t)1);

        filter_DO.eval();
        y.eval();
    }

    virtual void backward()
    {
        dy.eval();
        dx += dy * af::tile(filter_DO, 1, x.dims(1)) / (1.0 - rate_DO);
    }
};

https://deepage.net/deep_learning/2016/11/30/resnet.html

The site above recommends a dropout rate of 30-40%, so I use 35% here.
With Inverted Dropout, the division by (1.0 - dropout rate) happens at training time rather than at inference time.
Note that inserting a Dropout layer seems to break gradient checking.

Next comes a proper test phase.

for (auto& itm : layer)
{
    itm->isTrain = 0.0;
}

var_t counter_ok = (var_t)0;
var_t diff = (var_t)0;
var_t norm = (var_t)0;

// Random selection
// Build a shuffled index array
af::array idx_data;
{
    af::array vals_data;
    af::array sort_data = af::randu(size_data_test, 1, dtype_t);
    af::sort(vals_data, idx_data, sort_data, 0);
}

constexpr int size_data_in_test = size_data_test / size_batch;
for (int step = 0; step < size_data_in_test; ++step)
{
    // Reset the errors (gradients)
    for (auto& itm : grad) { itm = 0.0; }
    for (auto& itm : layer) { itm->init(); }

    // Samples targeted in this step
    // af::seq specifies a range of indices
    // Its constructor's range semantics are unusual, so see:
    // https://arrayfire.org/docs/classaf_1_1seq.htm
    af::array idx_target = idx_data(af::seq((step + 0) * size_batch, (step + 1) * size_batch - 1)) + size_data_train;

    // Set the input values
    data.front() = input(af::span, idx_target);

    // Forward pass
    for (auto& itm : layer) { itm->forward(); }

    // Apply softmax if needed
#ifdef DEF_use_softmax
    data.back() = af::exp(data.back()) / af::tile(af::sum(af::exp(data.back()), 0), data.back().dims(0));
#endif

    // Compute the error
    grad.back() = data.back() - output(af::span, idx_target); // y - t
    diff += af::mean<var_t>(af::abs(grad.back()));
    norm += af::norm(grad.back());

#ifdef DEF_use_softmax
    // Compute the accuracy
    af::array val, plabels, tlabels;
    af::max(val, tlabels, output(af::span, idx_target), 0);
    af::max(val, plabels, data.back(), 0);
    counter_ok += af::count<var_t>(plabels == tlabels);
#endif
}

const auto accuracy = (var_t)100 * counter_ok / (var_t)(size_data_in_test * size_batch);
diff /= (var_t)size_data_in_test;
norm /= (var_t)size_data_in_test;

for (auto& itm : layer)
{
    itm->isTrain = 1.0;
}

Be careful if you copy-paste from the training-phase code: you will carry over the gradient-update step as well.

This time I prepared results for three patterns:
1. FC+BN+tanhExp
2. FC+BN+tanhExp+Shortcut
3. FC+BN+tanhExp+Shortcut+Dropout

// BN 300 5 0.01
epoch : 1       accuracy train: 5.95703 accuracy test :        3.09361
epoch : 101     accuracy train: 86.5234 accuracy test :        61.3491
epoch : 201     accuracy train: 96.8262 accuracy test :        80.6327
epoch : 301     accuracy train: 99.3164 accuracy test :        82.53
epoch : 401     accuracy train: 99.8535 accuracy test :        82.0696
epoch : 501     accuracy train: 99.7559 accuracy test :        81.7208
epoch : 601     accuracy train: 99.9512 accuracy test :        81.7418
epoch : 701     accuracy train: 100     accuracy test :        81.5011
epoch : 801     accuracy train: 99.4629 accuracy test :        81.1733
epoch : 901     accuracy train: 99.5605 accuracy test :        81.2012
epoch : 1001    accuracy train: 100     accuracy test :        81.1
epoch : 1101    accuracy train: 100     accuracy test :        80.8908
epoch : 1201    accuracy train: 100     accuracy test :        80.7059
epoch : 1301    accuracy train: 100     accuracy test :        80.678
epoch : 1401    accuracy train: 100     accuracy test :        80.6989
epoch : 1501    accuracy train: 99.2188 accuracy test :        80.4827
epoch : 1601    accuracy train: 100     accuracy test :        80.4583
epoch : 1701    accuracy train: 100     accuracy test :        80.5071
epoch : 1801    accuracy train: 100     accuracy test :        80.385
epoch : 1901    accuracy train: 100     accuracy test :        79.5619
epoch : 2001    accuracy train: 100     accuracy test :        80.2246
epoch : 2101    accuracy train: 100     accuracy test :        80.0991
epoch : 2201    accuracy train: 100     accuracy test :        80.0851
epoch : 2301    accuracy train: 100     accuracy test :        80.1862
epoch : 2401    accuracy train: 100     accuracy test :        79.9944
epoch : 2501    accuracy train: 100     accuracy test :        80.0049
epoch : 2601    accuracy train: 100     accuracy test :        80.1444
epoch : 2701    accuracy train: 100     accuracy test :        80.0119
epoch : 2801    accuracy train: 100     accuracy test :        79.9351
epoch : 2901    accuracy train: 99.9512 accuracy test :        79.841
epoch : 3001    accuracy train: 100     accuracy test :        79.7956

// BN SC 300 5 0.01
epoch : 1       accuracy train: 7.08008 accuracy test :        7.20564
epoch : 101     accuracy train: 77.0508 accuracy test :        52.2356
epoch : 201     accuracy train: 88.623  accuracy test :        78.8574
epoch : 301     accuracy train: 94.0918 accuracy test :        86.9245
epoch : 401     accuracy train: 95.2148 accuracy test :        88.637
epoch : 501     accuracy train: 94.9219 accuracy test :        89.1253
epoch : 601     accuracy train: 98.4375 accuracy test :        89.1497
epoch : 701     accuracy train: 97.8516 accuracy test :        88.8323
epoch : 801     accuracy train: 95.5566 accuracy test :        88.1731
epoch : 901     accuracy train: 97.8516 accuracy test :        87.7965
epoch : 1001    accuracy train: 97.0215 accuracy test :        87.4372
epoch : 1101    accuracy train: 99.6094 accuracy test :        86.7501
epoch : 1201    accuracy train: 99.8535 accuracy test :        86.5479
epoch : 1301    accuracy train: 99.9512 accuracy test :        86.3072
epoch : 1401    accuracy train: 99.9023 accuracy test :        86.0352
epoch : 1501    accuracy train: 97.5586 accuracy test :        85.8154
epoch : 1601    accuracy train: 99.9512 accuracy test :        85.5015
epoch : 1701    accuracy train: 100     accuracy test :        85.2923
epoch : 1801    accuracy train: 100     accuracy test :        85.0202
epoch : 1901    accuracy train: 99.9512 accuracy test :        84.4657
epoch : 2001    accuracy train: 99.9512 accuracy test :        84.7028
epoch : 2101    accuracy train: 100     accuracy test :        84.5459
epoch : 2201    accuracy train: 100     accuracy test :        84.4203
epoch : 2301    accuracy train: 100     accuracy test :        84.3924
epoch : 2401    accuracy train: 99.9512 accuracy test :        84.1099
epoch : 2501    accuracy train: 100     accuracy test :        84.1692
epoch : 2601    accuracy train: 99.9512 accuracy test :        84.0995
epoch : 2701    accuracy train: 100     accuracy test :        83.96
epoch : 2801    accuracy train: 100     accuracy test :        83.7646
epoch : 2901    accuracy train: 99.8535 accuracy test :        83.7925
epoch : 3001    accuracy train: 100     accuracy test :        83.6426

// BN SC DO 5 0.01
epoch : 1       accuracy train: 5.27344 accuracy test :        6.36858
epoch : 101     accuracy train: 61.8652 accuracy test :        44.6882
epoch : 201     accuracy train: 78.7598 accuracy test :        75.1918
epoch : 301     accuracy train: 86.1816 accuracy test :        83.81
epoch : 401     accuracy train: 86.4258 accuracy test :        87.8732
epoch : 501     accuracy train: 87.9395 accuracy test :        88.6509
epoch : 601     accuracy train: 91.1133 accuracy test :        90.2727
epoch : 701     accuracy train: 91.1621 accuracy test :        89.8054
epoch : 801     accuracy train: 89.1602 accuracy test :        91.5318
epoch : 901     accuracy train: 92.5293 accuracy test :        91.4586
epoch : 1001    accuracy train: 93.7012 accuracy test :        90.5622
epoch : 1101    accuracy train: 93.1641 accuracy test :        91.8771
epoch : 1201    accuracy train: 94.1895 accuracy test :        91.1482
epoch : 1301    accuracy train: 93.6523 accuracy test :        91.3609
epoch : 1401    accuracy train: 96.0449 accuracy test :        92.0061
epoch : 1501    accuracy train: 95.6055 accuracy test :        91.49
epoch : 1601    accuracy train: 91.3086 accuracy test :        91.3086
epoch : 1701    accuracy train: 87.5488 accuracy test :        90.0879
epoch : 1801    accuracy train: 94.7754 accuracy test :        92.048
epoch : 1901    accuracy train: 96.4844 accuracy test :        90.9877
epoch : 2001    accuracy train: 95.8984 accuracy test :        93.185
epoch : 2101    accuracy train: 93.5547 accuracy test :        92.613
epoch : 2201    accuracy train: 95.0195 accuracy test :        91.326
epoch : 2301    accuracy train: 95.9473 accuracy test :        92.055
epoch : 2401    accuracy train: 91.6504 accuracy test :        92.3165
epoch : 2501    accuracy train: 96.2402 accuracy test :        92.1143
epoch : 2601    accuracy train: 96.1426 accuracy test :        92.4037
epoch : 2701    accuracy train: 91.8945 accuracy test :        92.756
epoch : 2801    accuracy train: 92.334  accuracy test :        92.3131
epoch : 2901    accuracy train: 92.0898 accuracy test :        91.8666
epoch : 3001    accuracy train: 0       accuracy test :        0

Adding the Shortcut layer clearly raises test-phase accuracy.
Adding the Dropout layer raises accuracy further, and accuracy no longer degrades even after long training.
However, with the Dropout layer this model eventually produced NaN.
I don't know whether this is related to the training accuracy reaching 100% or not.
The model may simply be too small for the problem.
