Implementing Batch Normalization [C++ ArrayFire]

六花 here.

A while ago, the backward pass (the derivative) of the batch normalization I had written myself never matched the numbers, so I decided to scrap it. Searching the web for a good implementation instead, I found one that looks solid, so let me introduce it.

Deriving the Gradient for the Backward Pass of Batch Normalization
https://kevinzakka.github.io/2016/09/14/batch_normalization/

As the "Recap" section there shows, the whole gradient collapses into a remarkably compact expression. Impressive!
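
For reference, here is that result as I transcribe it (so any slip is mine): with batch size N, normalized input \hat{x}, \sigma = \sqrt{\mathrm{Var} + \epsilon}, and scale \gamma,

    \frac{\partial L}{\partial x_i}
      = \frac{1}{N\sigma}\left( N\,\frac{\partial L}{\partial \hat{x}_i}
        - \sum_{k=1}^{N} \frac{\partial L}{\partial \hat{x}_k}
        - \hat{x}_i \sum_{k=1}^{N} \frac{\partial L}{\partial \hat{x}_k}\,\hat{x}_k \right),
    \qquad
    \frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}

which corresponds term by term to dNorm, Std, and the dx expression in backward() below. Using this as a reference, I implemented the batch normalization layer in ArrayFire.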

§Environment
Windows 10
Microsoft Visual Studio Community 2022 (64 ビット) Version 17.3.4
ArrayFire v3.8.2


struct Layer
{
    var_t isTrain = 1.0; // 1.0 while training, 0.0 at test time

    virtual void init() {}          // reset the accumulated gradients
    virtual void forward() = 0;     // forward pass
    virtual void backward() = 0;    // backward pass
    virtual void SGD() {}           // update the weights
};

// https://kevinzakka.github.io/2016/09/14/batch_normalization/
struct BN_layer : public Layer
{
    const af::array& x;
    af::array& y;
    af::array& dx;
    const af::array& dy;

    const dim_t size_batch;

    af::array Std, Norm; // temporaries reused by backward()
    af::array Mean, Var; // running statistics, needed at inference time
    af::array G, B;      // learnable scale (gamma) and shift (beta)

    af::array dG, dB;

    BN_layer(const af::array& x, af::array& y, af::array& dx, const af::array& dy)
        : x(x)
        , y(y)
        , dx(dx)
        , dy(dy)
        , size_batch(x.dims(1))
    {
        Mean = af::constant(0.0, x.dims(0), 1, dtype_t);
        Var = af::constant(1.0, x.dims(0), 1, dtype_t);
        G = af::constant(1.0, x.dims(0), 1, dtype_t);
        B = af::constant(0.0, x.dims(0), 1, dtype_t);

        dG = G * 0.0;
        dB = B * 0.0;
    }

    virtual void init()
    {
        dG = 0.0;
        dB = 0.0;

        G.eval();
        B.eval();
    }

    virtual void forward()
    {
        const af::array Mean_b = af::mean(x, 1); // per-feature batch mean
        const af::array Var_b = af::mean(af::pow(x - af::tile(Mean_b, 1, size_batch), 2.0), 1); // per-feature (biased) batch variance

        // isTrain is a host-side scalar, so an ordinary branch is enough here.
        if (isTrain == 1.0)
        {
            constexpr var_t eta = 0.99; // momentum of the running statistics
            Mean = eta * Mean + (1.0 - eta) * Mean_b;
            Var = eta * Var + (1.0 - eta) * ((var_t)size_batch / (var_t)(size_batch - 1)) * Var_b; // Bessel's correction
        }
        const af::array Mean_calc = (isTrain == 1.0) ? Mean_b : Mean; // batch stats while training,
        const af::array Var_calc = (isTrain == 1.0) ? Var_b : Var;    // running stats at inference

        Std = af::sqrt(Var_calc + eps); // eps guards against division by zero
        Norm = (x - af::tile(Mean_calc, 1, size_batch)) / af::tile(Std, 1, size_batch); // normalized input

        y = af::tile(G, 1, size_batch) * Norm + af::tile(B, 1, size_batch);

        Mean.eval();
        Var.eval();
        Std.eval();
        Norm.eval();
    }

    virtual void backward()
    {
        dy.eval();
        const af::array dNorm = dy * af::tile(G, 1, size_batch); // dL/dNorm
        dB += af::sum(dy, 1);        // accumulate dL/dB over the batch
        dG += af::sum(dy * Norm, 1); // accumulate dL/dG

        // The compact expression from the post's Recap, evaluated per feature over the batch:
        dx += (
            (var_t)size_batch * dNorm
            - af::tile(af::sum(dNorm, 1), 1, size_batch)
            - Norm * af::tile(af::sum(dNorm * Norm, 1), 1, size_batch)
            )
            / ((var_t)size_batch * af::tile(Std, 1, size_batch));

        dB.eval();
        dG.eval();
    }

    virtual void SGD()
    {
        G -= alpha * dG / (var_t)size_batch; // plain SGD; alpha is the learning rate,
        B -= alpha * dB / (var_t)size_batch; // gradients averaged over the batch

        G.eval();
        B.eval();
    }
};
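
A note on the identifiers: var_t (the scalar type), dtype_t (the ArrayFire dtype), eps, and alpha (the learning rate) are presumably defined elsewhere in the project, so they don't appear in this listing.

Since my own derivation went wrong exactly at this backward pass, a finite-difference check is worth keeping around. The following is a minimal sketch of one (my addition, not something from the linked post); it assumes var_t is double and dtype_t is f64, and checks dx against central differences for the loss L = sum(y * dy):

#include <arrayfire.h>
#include <cstdio>

// Finite-difference check for BN_layer::backward. With a fixed upstream dy and
// the loss L = sum(y * dy), backward() should deposit dL/dx into dx.
void check_bn_gradient()
{
    af::array x  = af::randn(8, 4, dtype_t); // 8 features, batch of 4
    af::array y  = af::constant(0.0, 8, 4, dtype_t);
    af::array dx = af::constant(0.0, 8, 4, dtype_t);
    const af::array dy = af::randn(8, 4, dtype_t); // fixed upstream gradient

    BN_layer bn(x, y, dx, dy);
    auto loss = [&]() { bn.forward(); return af::sum<var_t>(y * dy); };

    // Central differences, one input element at a time.
    af::array dx_num = af::constant(0.0, x.dims(), dtype_t);
    const var_t h = 1e-5;
    for (int j = 0; j < (int)x.dims(1); ++j)
        for (int i = 0; i < (int)x.dims(0); ++i)
        {
            x(i, j) += h;     const var_t Lp = loss();
            x(i, j) -= 2 * h; const var_t Lm = loss();
            x(i, j) += h;     // restore the original value
            dx_num(i, j) = (Lp - Lm) / (2 * h);
        }

    loss();        // one clean forward pass
    bn.backward(); // accumulates the analytic gradient into dx (zeroed above)

    const var_t err = af::max<var_t>(af::abs(dx - dx_num));
    std::printf("max |dx - dx_num| = %g\n", err); // should be tiny (~1e-6)
}

Each forward() call here also nudges the running Mean/Var, but those are only consumed at inference time, so they don't affect the check itself.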

Since a batch normalization layer behaves differently during training and at inference time, I added an isTrain variable to the Layer base class.
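
Because isTrain is a plain member of the base class, switching the whole model between the two modes is just a loop (a sketch, assuming the std::vector<std::shared_ptr<Layer>> built in the next listing):

    for (auto& l : layer) l->isTrain = 1.0; // training: batch statistics, running stats updated
    // ... train ...
    for (auto& l : layer) l->isTrain = 0.0; // inference: the stored Mean / Var are used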

Next, let's wire this into the model.


    std::vector<af::array> data; // activations (layer outputs)
    data.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
    for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
    {
        data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
        data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
        data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    }
    data.push_back(af::constant(0.0, size_output, size_batch, dtype_t));

    std::vector<af::array> grad; // errors (gradients)
    grad.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
    for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
    {
        grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
        grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
        grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    }
    grad.push_back(af::constant(0.0, size_output, size_batch, dtype_t));

    std::vector<std::shared_ptr<Layer>> layer; // the layers
    layer.push_back(std::make_shared<FC_layer/* */>(data.at(0), data.at(1), grad.at(0), grad.at(1)));
    for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer * 3; i_hidden_layer += 3)
    {
        layer.push_back(std::make_shared<BN_layer/*     */>(data.at(i_hidden_layer + 1), data.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 1), grad.at(i_hidden_layer + 2)));
        layer.push_back(std::make_shared<tanhExp_layer/**/>(data.at(i_hidden_layer + 2), data.at(i_hidden_layer + 3), grad.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 3)));
        layer.push_back(std::make_shared<FC_layer/*     */>(data.at(i_hidden_layer + 3), data.at(i_hidden_layer + 4), grad.at(i_hidden_layer + 3), grad.at(i_hidden_layer + 4)));
    }
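
For completeness, one training step under this Layer interface would look roughly like the sketch below (my reconstruction of the driver; batch, label, and loss_grad are hypothetical placeholders for the input minibatch and however the output-layer loss gradient is produced):

    for (auto& l : layer) l->isTrain = 1.0; // use batch statistics
    for (auto& l : layer) l->init();        // clear accumulated dG, dB, ...
    for (auto& g : grad) g = af::constant(0.0, g.dims(), dtype_t); // backward() accumulates into grad

    data.front() = batch;                   // load the input minibatch (hypothetical)
    for (auto& l : layer) l->forward();     // activations flow to data.back()

    grad.back() = loss_grad(data.back(), label); // dL/doutput (hypothetical helper)
    for (auto it = layer.rbegin(); it != layer.rend(); ++it)
        (*it)->backward();                  // gradients flow back to grad.front()

    for (auto& l : layer) l->SGD();         // apply the update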

For a deeper model I would probably want to insert the Shortcut_layer I built the other day, but since this experiment uses size_hidden_layer = 2, it shouldn't be necessary.
Below are the accuracy curves without the batch normalization layer, with it, and with it at a 100x learning rate.

// tanhExp only (no BN), alpha = 0.0001
epoch : 1       diff :  0.0192724       norm :  11.1596 accuracy :      10.1562
epoch : 101     diff :  0.0179332       norm :  10.8161 accuracy :      14.8438
epoch : 201     diff :  0.017582        norm :  10.5947 accuracy :      22.6562
epoch : 301     diff :  0.0177274       norm :  10.6867 accuracy :      13.2812
epoch : 401     diff :  0.0177372       norm :  10.6986 accuracy :      21.0938
epoch : 501     diff :  0.0175152       norm :  10.584  accuracy :      23.4375
epoch : 601     diff :  0.0173587       norm :  10.4955 accuracy :      25
epoch : 701     diff :  0.0173365       norm :  10.5304 accuracy :      24.2188
epoch : 801     diff :  0.0171328       norm :  10.4384 accuracy :      25
epoch : 901     diff :  0.0171049       norm :  10.4661 accuracy :      28.125
epoch : 1001    diff :  0.0167975       norm :  10.3115 accuracy :      28.9062
epoch : 1101    diff :  0.0166932       norm :  10.2941 accuracy :      28.125
epoch : 1201    diff :  0.0162885       norm :  10.0633 accuracy :      39.8438
epoch : 1301    diff :  0.0162624       norm :  10.0967 accuracy :      35.9375
epoch : 1401    diff :  0.015995        norm :  9.95596 accuracy :      43.75
epoch : 1501    diff :  0.0158486       norm :  9.92333 accuracy :      41.4062
epoch : 1601    diff :  0.0155752       norm :  9.75906 accuracy :      47.6562
epoch : 1701    diff :  0.0154601       norm :  9.73563 accuracy :      49.2188
epoch : 1801    diff :  0.0153018       norm :  9.68003 accuracy :      53.9062
epoch : 1901    diff :  0.0151829       norm :  9.6462  accuracy :      48.4375
epoch : 2001    diff :  0.0149164       norm :  9.49316 accuracy :      50.7812
epoch : 2101    diff :  0.0146155       norm :  9.33597 accuracy :      58.5938
epoch : 2201    diff :  0.0145862       norm :  9.35398 accuracy :      53.9062
epoch : 2301    diff :  0.0144299       norm :  9.27817 accuracy :      56.25
epoch : 2401    diff :  0.0143156       norm :  9.24132 accuracy :      56.25
epoch : 2501    diff :  0.0138128       norm :  8.93688 accuracy :      63.2812
epoch : 2601    diff :  0.0137335       norm :  8.91735 accuracy :      64.0625
epoch : 2701    diff :  0.0136017       norm :  8.8589  accuracy :      60.9375
epoch : 2801    diff :  0.0135729       norm :  8.87159 accuracy :      61.7188
epoch : 2901    diff :  0.0134057       norm :  8.81603 accuracy :      64.8438
epoch : 3001    diff :  0.012934        norm :  8.52779 accuracy :      74.2188
epoch : 3101    diff :  0.0129994       norm :  8.62714 accuracy :      70.3125
epoch : 3201    diff :  0.0125961       norm :  8.39045 accuracy :      73.4375
epoch : 3301    diff :  0.0126337       norm :  8.42499 accuracy :      75.7812
epoch : 3401    diff :  0.0127017       norm :  8.50228 accuracy :      66.4062
epoch : 3501    diff :  0.0120642       norm :  8.09769 accuracy :      76.5625
epoch : 3601    diff :  0.0119556       norm :  8.0563  accuracy :      77.3438
epoch : 3701    diff :  0.0121662       norm :  8.24673 accuracy :      71.875
epoch : 3801    diff :  0.0119128       norm :  8.08681 accuracy :      77.3438
epoch : 3901    diff :  0.011671        norm :  7.92819 accuracy :      78.9062
epoch : 4001    diff :  0.0115751       norm :  7.91488 accuracy :      78.125
epoch : 4101    diff :  0.0113907       norm :  7.83296 accuracy :      75.7812
epoch : 4201    diff :  0.0112091       norm :  7.72397 accuracy :      84.375
epoch : 4301    diff :  0.0109818       norm :  7.57628 accuracy :      85.9375
epoch : 4401    diff :  0.0109191       norm :  7.56331 accuracy :      85.1562
epoch : 4501    diff :  0.0108726       norm :  7.58057 accuracy :      79.6875
epoch : 4601    diff :  nan     norm :  nan     accuracy :      0

// with BN, alpha = 0.0001
epoch : 1       diff :  0.0196451       norm :  11.3754 accuracy :      0
epoch : 101     diff :  0.0174411       norm :  10.6922 accuracy :      18.75
epoch : 201     diff :  0.0163546       norm :  10.141  accuracy :      32.0312
epoch : 301     diff :  0.0157799       norm :  9.96257 accuracy :      29.6875
epoch : 401     diff :  0.0156522       norm :  10.0103 accuracy :      31.25
epoch : 501     diff :  0.014774        norm :  9.51013 accuracy :      47.6562
epoch : 601     diff :  0.0141584       norm :  9.18429 accuracy :      49.2188
epoch : 701     diff :  0.0140881       norm :  9.24847 accuracy :      46.875
epoch : 801     diff :  0.0134546       norm :  8.9029  accuracy :      57.0312
epoch : 901     diff :  0.0133436       norm :  8.93784 accuracy :      55.4688
epoch : 1001    diff :  0.0131468       norm :  8.8971  accuracy :      50
epoch : 1101    diff :  0.0126429       norm :  8.65494 accuracy :      59.375
epoch : 1201    diff :  0.0118861       norm :  8.21505 accuracy :      67.9688
epoch : 1301    diff :  0.011964        norm :  8.31101 accuracy :      66.4062
epoch : 1401    diff :  0.0112667       norm :  7.88477 accuracy :      70.3125
epoch : 1501    diff :  0.010903        norm :  7.76457 accuracy :      68.75
epoch : 1601    diff :  0.0117846       norm :  8.57495 accuracy :      55.4688
epoch : 1701    diff :  0.0106091       norm :  7.71092 accuracy :      65.625
epoch : 1801    diff :  0.0109676       norm :  8.0774  accuracy :      63.2812
epoch : 1901    diff :  0.00975205      norm :  7.18112 accuracy :      76.5625
epoch : 2001    diff :  0.00960622      norm :  7.21011 accuracy :      75
epoch : 2101    diff :  0.0108849       norm :  8.28597 accuracy :      60.1562
epoch : 2201    diff :  0.00994902      norm :  7.64011 accuracy :      65.625
epoch : 2301    diff :  0.00859682      norm :  6.57081 accuracy :      81.25
epoch : 2401    diff :  0.00841644      norm :  6.45366 accuracy :      81.25
epoch : 2501    diff :  0.00817581      norm :  6.40977 accuracy :      84.375
epoch : 2601    diff :  0.00847368      norm :  6.75952 accuracy :      76.5625
epoch : 2701    diff :  0.00763888      norm :  5.97458 accuracy :      88.2812
epoch : 2801    diff :  0.00798304      norm :  6.32499 accuracy :      81.25
epoch : 2901    diff :  0.00754648      norm :  6.05507 accuracy :      84.375
epoch : 3001    diff :  0.00737502      norm :  5.94544 accuracy :      87.5
epoch : 3101    diff :  0.0100366       norm :  8.25889 accuracy :      58.5938
epoch : 3201    diff :  0.00735531      norm :  5.9466  accuracy :      85.9375
epoch : 3301    diff :  0.00678528      norm :  5.45495 accuracy :      92.1875
epoch : 3401    diff :  0.00905309      norm :  7.6949  accuracy :      62.5
epoch : 3501    diff :  0.00667509      norm :  5.53527 accuracy :      87.5
epoch : 3601    diff :  0.00728889      norm :  6.20548 accuracy :      80.4688
epoch : 3701    diff :  0.00713524      norm :  6.11726 accuracy :      79.6875
epoch : 3801    diff :  0.00912937      norm :  7.94254 accuracy :      62.5
epoch : 3901    diff :  0.00651565      norm :  5.46807 accuracy :      87.5
epoch : 4001    diff :  0.00728201      norm :  6.07996 accuracy :      83.5938
epoch : 4101    diff :  0.00588398      norm :  5.05563 accuracy :      89.8438
epoch : 4201    diff :  0.00651716      norm :  5.63472 accuracy :      82.8125
epoch : 4301    diff :  0.0063791       norm :  5.50415 accuracy :      88.2812
epoch : 4401    diff :  0.0061517       norm :  5.2178  accuracy :      89.8438
epoch : 4501    diff :  0.00752665      norm :  6.79134 accuracy :      70.3125
epoch : 4601    diff :  0.0058363       norm :  5.1437  accuracy :      87.5
epoch : 4701    diff :  0.00724119      norm :  6.41938 accuracy :      76.5625
epoch : 4801    diff :  0.00707362      norm :  6.23075 accuracy :      80.4688
epoch : 4901    diff :  0.00731898      norm :  6.56829 accuracy :      76.5625
epoch : 5001    diff :  0.0067601       norm :  6.21467 accuracy :      75

// with BN, alpha = 0.01 (100x)
epoch : 1       diff :  0.0174886       norm :  10.6918 accuracy :      15.625
epoch : 101     diff :  0.00526453      norm :  5.11817 accuracy :      85.9375
epoch : 201     diff :  0.00571623      norm :  6.26158 accuracy :      76.5625
epoch : 301     diff :  0.00175628      norm :  2.22323 accuracy :      98.4375
epoch : 401     diff :  0.0019938       norm :  2.80483 accuracy :      96.875
epoch : 501     diff :  0.000712807     norm :  1.19924 accuracy :      99.2188
epoch : 601     diff :  nan     norm :  nan     accuracy :      0

Reading the logs: at the same learning rate, the batch-normalized model converges faster and reaches higher accuracy than the plain tanhExp network, and it even tolerates a 100x learning rate, passing 99% accuracy within about 500 epochs. Both the no-BN run and the 100x run eventually diverge to nan, though, so the learning rate still needs care.

