Implementing Batch Normalization [C++ ArrayFire]
六花 here.
A while back, the backward pass (derivative) I had written for batch normalization never matched the numerical check, so I decided to throw it away. I went looking for a good implementation on the web and found one that looks solid, so I'd like to introduce it here.
Deriving the Gradient for the Backward Pass of Batch Normalization
https://kevinzakka.github.io/2016/09/14/batch_normalization/
If you look at the "Recap" section there, you can see that the whole gradient collapses into a remarkably compact expression. Impressive!
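For reference, here is my transcription of the forward pass and the gradients from that Recap (N is the batch size, γ and β are the learnable scale and shift; see the original post for the authoritative version):

\[
\mu_B = \frac{1}{N}\sum_{i=1}^{N} x_i,\qquad
\sigma_B^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i-\mu_B)^2,\qquad
\hat{x}_i = \frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}},\qquad
y_i = \gamma\,\hat{x}_i+\beta
\]
\[
\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i},\qquad
\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i}\,\hat{x}_i,\qquad
\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}
\]
\[
\frac{\partial L}{\partial x_i} = \frac{1}{N\sqrt{\sigma_B^2+\epsilon}}
\left(N\,\frac{\partial L}{\partial \hat{x}_i}
 - \sum_{j}\frac{\partial L}{\partial \hat{x}_j}
 - \hat{x}_i\sum_{j}\frac{\partial L}{\partial \hat{x}_j}\,\hat{x}_j\right)
\]

The last line is exactly what backward() below computes, with dNorm playing the role of ∂L/∂x̂.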
Using it as a reference, I implemented a batch normalization layer with ArrayFire.
§ Environment
Windows 10
Microsoft Visual Studio Community 2022 (64-bit) Version 17.3.4
ArrayFire v3.8.2
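The snippets below also use var_t, dtype_t, eps and alpha, which are defined elsewhere in my code and not shown here. As a rough guide, assume something along these lines (the exact values are my choice, not part of the BN layer itself):

#include <arrayfire.h>
using var_t = double;              // scalar type used throughout (assumed)
const af::dtype dtype_t = f64;     // ArrayFire element type matching var_t (assumed)
const var_t eps = 1e-7;            // small constant added to the variance before sqrt (assumed)
var_t alpha = 0.0001;              // SGD learning rate; 0.0001 or 0.01 in the experiments below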
struct Layer
{
    var_t isTrain = 1.0;         // 1.0 while training, 0.0 while testing
    virtual void init() {}       // resets the accumulated gradients
    virtual void forward() = 0;  // forward pass
    virtual void backward() = 0; // backward pass
    virtual void SGD() {}        // weight update
};
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
struct BN_layer : public Layer
{
    const af::array& x;
    af::array& y;
    af::array& dx;
    const af::array& dy;
    const dim_t size_batch;
    af::array Std, Norm; // temporaries reused by backward()
    af::array Mean, Var; // running statistics, needed at inference time
    af::array G, B;      // learnable scale (gamma) and shift (beta)
    af::array dG, dB;
    BN_layer(const af::array& x, af::array& y, af::array& dx, const af::array& dy)
        : x(x)
        , y(y)
        , dx(dx)
        , dy(dy)
        , size_batch(x.dims(1))
    {
        Mean = af::constant(0.0, x.dims(0), 1, dtype_t);
        Var = af::constant(1.0, x.dims(0), 1, dtype_t);
        G = af::constant(1.0, x.dims(0), 1, dtype_t);
        B = af::constant(0.0, x.dims(0), 1, dtype_t);
        dG = G * 0.0;
        dB = B * 0.0;
    }
    virtual void init()
    {
        dG = 0.0;
        dB = 0.0;
        G.eval();
        B.eval();
    }
    virtual void forward()
    {
        // per-feature mean and (biased) variance over the current batch
        const af::array Mean_b = af::mean(x, 1);
        const af::array Var_b = af::mean(af::pow(x - af::tile(Mean_b, 1, size_batch), 2.0), 1);
        {
            // exponential moving average of the running statistics (updated only while training);
            // the variance is rescaled by N/(N-1) to make it unbiased
            constexpr var_t eta = 0.99;
            Mean = af::select(af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0, eta * Mean + (1.0 - eta) * Mean_b, Mean);
            Var = af::select(af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0, eta * Var + (1.0 - eta) * ((var_t)size_batch / ((var_t)(size_batch - 1))) * Var_b, Var);
        }
        // while training, normalize with the batch statistics; at inference, with the running ones
        const af::array Mean_calc = af::select(af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0, Mean_b, Mean);
        const af::array Var_calc = af::select(af::constant(isTrain, x.dims(0), 1, dtype_t) == 1.0, Var_b, Var);
        Std = af::sqrt(Var_calc + eps);
        Norm = (x - af::tile(Mean_calc, 1, size_batch)) / af::tile(Std, 1, size_batch);
        y = af::tile(G, 1, size_batch) * Norm + af::tile(B, 1, size_batch);
        Mean.eval();
        Var.eval();
        Std.eval();
        Norm.eval();
    }
    virtual void backward()
    {
        dy.eval();
        const af::array dNorm = dy * af::tile(G, 1, size_batch); // dL/dNorm
        dB += af::sum(dy, 1);
        dG += af::sum(dy * Norm, 1);
        // the compact formula from the linked post:
        // dx = (N*dNorm - sum(dNorm) - Norm*sum(dNorm*Norm)) / (N*Std)
        dx += (
            (var_t)size_batch * dNorm
            - af::tile(af::sum(dNorm, 1), 1, size_batch)
            - Norm * af::tile(af::sum(dNorm * Norm, 1), 1, size_batch)
            )
            / ((var_t)size_batch * af::tile(Std, 1, size_batch));
        dB.eval();
        dG.eval();
    }
    virtual void SGD()
    {
        G -= alpha * dG / (var_t)size_batch;
        B -= alpha * dB / (var_t)size_batch;
        G.eval();
        B.eval();
    }
};
Since a batch normalization layer behaves differently during training and inference, I added an isTrain variable to the Layer base class.
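The driver code that flips this flag isn't shown in this post, but switching to the stored running statistics before evaluating the test set would presumably look something like this (hypothetical usage, written against the layer container built below):

// before evaluating: use the stored Mean/Var instead of the batch statistics
for (auto& l : layer)
    l->isTrain = 0.0;
// ... run forward() over the test batches and measure accuracy ...
// back to training mode: use the batch statistics again
for (auto& l : layer)
    l->isTrain = 1.0;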
Now let's wire this into the model.
std::vector<af::array> data; // activations
data.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
{
    data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
}
data.push_back(af::constant(0.0, size_output, size_batch, dtype_t));
std::vector<af::array> grad; // gradients of the loss w.r.t. each activation
grad.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
{
    grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
    grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
}
grad.push_back(af::constant(0.0, size_output, size_batch, dtype_t));
std::vector<std::shared_ptr<Layer>> layer; // layers
layer.push_back(std::make_shared<FC_layer/*     */>(data.at(0), data.at(1), grad.at(0), grad.at(1)));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer * 3; i_hidden_layer += 3)
{
    // each hidden block is BN -> tanhExp -> FC; the buffer indices step by 3 per block
    layer.push_back(std::make_shared<BN_layer/*     */>(data.at(i_hidden_layer + 1), data.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 1), grad.at(i_hidden_layer + 2)));
    layer.push_back(std::make_shared<tanhExp_layer/**/>(data.at(i_hidden_layer + 2), data.at(i_hidden_layer + 3), grad.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 3)));
    layer.push_back(std::make_shared<FC_layer/*     */>(data.at(i_hidden_layer + 3), data.at(i_hidden_layer + 4), grad.at(i_hidden_layer + 3), grad.at(i_hidden_layer + 4)));
}
For a deeper model I would probably insert the Shortcut_layer I wrote the other day, but since this experiment uses size_hidden_layer = 2, I don't think it's necessary.
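The training loop itself isn't shown in this post; as a rough sketch of how these pieces would be driven (my own guess; size_epoch and the loss-gradient step are placeholders):

// Hypothetical outline of the training loop; not from the original code.
for (int epoch = 0; epoch < size_epoch; ++epoch)
{
    // ... load a mini-batch into data.at(0) and keep the targets at hand ...
    for (auto& g : grad)
        g = af::constant(0.0, g.dims(), g.type()); // dx is accumulated with +=, so clear it first
    for (auto& l : layer)
        l->init();                                 // clear dG/dB inside each layer
    for (auto& l : layer)
        l->forward();                              // propagate activations front to back
    // ... write dLoss/dOutput into grad.back() ...
    for (auto it = layer.rbegin(); it != layer.rend(); ++it)
        (*it)->backward();                         // propagate gradients back to front
    for (auto& l : layer)
        l->SGD();                                  // apply the accumulated gradients
}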
Below are the accuracy curves for three runs: without the batch normalization layer, with it, and with it at a 100x learning rate.
// without BN (FC + tanhExp only), alpha = 0.0001
epoch : 1 diff : 0.0192724 norm : 11.1596 accuracy : 10.1562
epoch : 101 diff : 0.0179332 norm : 10.8161 accuracy : 14.8438
epoch : 201 diff : 0.017582 norm : 10.5947 accuracy : 22.6562
epoch : 301 diff : 0.0177274 norm : 10.6867 accuracy : 13.2812
epoch : 401 diff : 0.0177372 norm : 10.6986 accuracy : 21.0938
epoch : 501 diff : 0.0175152 norm : 10.584 accuracy : 23.4375
epoch : 601 diff : 0.0173587 norm : 10.4955 accuracy : 25
epoch : 701 diff : 0.0173365 norm : 10.5304 accuracy : 24.2188
epoch : 801 diff : 0.0171328 norm : 10.4384 accuracy : 25
epoch : 901 diff : 0.0171049 norm : 10.4661 accuracy : 28.125
epoch : 1001 diff : 0.0167975 norm : 10.3115 accuracy : 28.9062
epoch : 1101 diff : 0.0166932 norm : 10.2941 accuracy : 28.125
epoch : 1201 diff : 0.0162885 norm : 10.0633 accuracy : 39.8438
epoch : 1301 diff : 0.0162624 norm : 10.0967 accuracy : 35.9375
epoch : 1401 diff : 0.015995 norm : 9.95596 accuracy : 43.75
epoch : 1501 diff : 0.0158486 norm : 9.92333 accuracy : 41.4062
epoch : 1601 diff : 0.0155752 norm : 9.75906 accuracy : 47.6562
epoch : 1701 diff : 0.0154601 norm : 9.73563 accuracy : 49.2188
epoch : 1801 diff : 0.0153018 norm : 9.68003 accuracy : 53.9062
epoch : 1901 diff : 0.0151829 norm : 9.6462 accuracy : 48.4375
epoch : 2001 diff : 0.0149164 norm : 9.49316 accuracy : 50.7812
epoch : 2101 diff : 0.0146155 norm : 9.33597 accuracy : 58.5938
epoch : 2201 diff : 0.0145862 norm : 9.35398 accuracy : 53.9062
epoch : 2301 diff : 0.0144299 norm : 9.27817 accuracy : 56.25
epoch : 2401 diff : 0.0143156 norm : 9.24132 accuracy : 56.25
epoch : 2501 diff : 0.0138128 norm : 8.93688 accuracy : 63.2812
epoch : 2601 diff : 0.0137335 norm : 8.91735 accuracy : 64.0625
epoch : 2701 diff : 0.0136017 norm : 8.8589 accuracy : 60.9375
epoch : 2801 diff : 0.0135729 norm : 8.87159 accuracy : 61.7188
epoch : 2901 diff : 0.0134057 norm : 8.81603 accuracy : 64.8438
epoch : 3001 diff : 0.012934 norm : 8.52779 accuracy : 74.2188
epoch : 3101 diff : 0.0129994 norm : 8.62714 accuracy : 70.3125
epoch : 3201 diff : 0.0125961 norm : 8.39045 accuracy : 73.4375
epoch : 3301 diff : 0.0126337 norm : 8.42499 accuracy : 75.7812
epoch : 3401 diff : 0.0127017 norm : 8.50228 accuracy : 66.4062
epoch : 3501 diff : 0.0120642 norm : 8.09769 accuracy : 76.5625
epoch : 3601 diff : 0.0119556 norm : 8.0563 accuracy : 77.3438
epoch : 3701 diff : 0.0121662 norm : 8.24673 accuracy : 71.875
epoch : 3801 diff : 0.0119128 norm : 8.08681 accuracy : 77.3438
epoch : 3901 diff : 0.011671 norm : 7.92819 accuracy : 78.9062
epoch : 4001 diff : 0.0115751 norm : 7.91488 accuracy : 78.125
epoch : 4101 diff : 0.0113907 norm : 7.83296 accuracy : 75.7812
epoch : 4201 diff : 0.0112091 norm : 7.72397 accuracy : 84.375
epoch : 4301 diff : 0.0109818 norm : 7.57628 accuracy : 85.9375
epoch : 4401 diff : 0.0109191 norm : 7.56331 accuracy : 85.1562
epoch : 4501 diff : 0.0108726 norm : 7.58057 accuracy : 79.6875
epoch : 4601 diff : nan norm : nan accuracy : 0
// BN alpha = 0.0001
epoch : 1 diff : 0.0196451 norm : 11.3754 accuracy : 0
epoch : 101 diff : 0.0174411 norm : 10.6922 accuracy : 18.75
epoch : 201 diff : 0.0163546 norm : 10.141 accuracy : 32.0312
epoch : 301 diff : 0.0157799 norm : 9.96257 accuracy : 29.6875
epoch : 401 diff : 0.0156522 norm : 10.0103 accuracy : 31.25
epoch : 501 diff : 0.014774 norm : 9.51013 accuracy : 47.6562
epoch : 601 diff : 0.0141584 norm : 9.18429 accuracy : 49.2188
epoch : 701 diff : 0.0140881 norm : 9.24847 accuracy : 46.875
epoch : 801 diff : 0.0134546 norm : 8.9029 accuracy : 57.0312
epoch : 901 diff : 0.0133436 norm : 8.93784 accuracy : 55.4688
epoch : 1001 diff : 0.0131468 norm : 8.8971 accuracy : 50
epoch : 1101 diff : 0.0126429 norm : 8.65494 accuracy : 59.375
epoch : 1201 diff : 0.0118861 norm : 8.21505 accuracy : 67.9688
epoch : 1301 diff : 0.011964 norm : 8.31101 accuracy : 66.4062
epoch : 1401 diff : 0.0112667 norm : 7.88477 accuracy : 70.3125
epoch : 1501 diff : 0.010903 norm : 7.76457 accuracy : 68.75
epoch : 1601 diff : 0.0117846 norm : 8.57495 accuracy : 55.4688
epoch : 1701 diff : 0.0106091 norm : 7.71092 accuracy : 65.625
epoch : 1801 diff : 0.0109676 norm : 8.0774 accuracy : 63.2812
epoch : 1901 diff : 0.00975205 norm : 7.18112 accuracy : 76.5625
epoch : 2001 diff : 0.00960622 norm : 7.21011 accuracy : 75
epoch : 2101 diff : 0.0108849 norm : 8.28597 accuracy : 60.1562
epoch : 2201 diff : 0.00994902 norm : 7.64011 accuracy : 65.625
epoch : 2301 diff : 0.00859682 norm : 6.57081 accuracy : 81.25
epoch : 2401 diff : 0.00841644 norm : 6.45366 accuracy : 81.25
epoch : 2501 diff : 0.00817581 norm : 6.40977 accuracy : 84.375
epoch : 2601 diff : 0.00847368 norm : 6.75952 accuracy : 76.5625
epoch : 2701 diff : 0.00763888 norm : 5.97458 accuracy : 88.2812
epoch : 2801 diff : 0.00798304 norm : 6.32499 accuracy : 81.25
epoch : 2901 diff : 0.00754648 norm : 6.05507 accuracy : 84.375
epoch : 3001 diff : 0.00737502 norm : 5.94544 accuracy : 87.5
epoch : 3101 diff : 0.0100366 norm : 8.25889 accuracy : 58.5938
epoch : 3201 diff : 0.00735531 norm : 5.9466 accuracy : 85.9375
epoch : 3301 diff : 0.00678528 norm : 5.45495 accuracy : 92.1875
epoch : 3401 diff : 0.00905309 norm : 7.6949 accuracy : 62.5
epoch : 3501 diff : 0.00667509 norm : 5.53527 accuracy : 87.5
epoch : 3601 diff : 0.00728889 norm : 6.20548 accuracy : 80.4688
epoch : 3701 diff : 0.00713524 norm : 6.11726 accuracy : 79.6875
epoch : 3801 diff : 0.00912937 norm : 7.94254 accuracy : 62.5
epoch : 3901 diff : 0.00651565 norm : 5.46807 accuracy : 87.5
epoch : 4001 diff : 0.00728201 norm : 6.07996 accuracy : 83.5938
epoch : 4101 diff : 0.00588398 norm : 5.05563 accuracy : 89.8438
epoch : 4201 diff : 0.00651716 norm : 5.63472 accuracy : 82.8125
epoch : 4301 diff : 0.0063791 norm : 5.50415 accuracy : 88.2812
epoch : 4401 diff : 0.0061517 norm : 5.2178 accuracy : 89.8438
epoch : 4501 diff : 0.00752665 norm : 6.79134 accuracy : 70.3125
epoch : 4601 diff : 0.0058363 norm : 5.1437 accuracy : 87.5
epoch : 4701 diff : 0.00724119 norm : 6.41938 accuracy : 76.5625
epoch : 4801 diff : 0.00707362 norm : 6.23075 accuracy : 80.4688
epoch : 4901 diff : 0.00731898 norm : 6.56829 accuracy : 76.5625
epoch : 5001 diff : 0.0067601 norm : 6.21467 accuracy : 75
// BN, alpha = 0.01 (100x the learning rate)
epoch : 1 diff : 0.0174886 norm : 10.6918 accuracy : 15.625
epoch : 101 diff : 0.00526453 norm : 5.11817 accuracy : 85.9375
epoch : 201 diff : 0.00571623 norm : 6.26158 accuracy : 76.5625
epoch : 301 diff : 0.00175628 norm : 2.22323 accuracy : 98.4375
epoch : 401 diff : 0.0019938 norm : 2.80483 accuracy : 96.875
epoch : 501 diff : 0.000712807 norm : 1.19924 accuracy : 99.2188
epoch : 601 diff : nan norm : nan accuracy : 0