クラス分類問題でもGradientCheckingしよう[c++ Arrayfire]
六花です。
モデルを積層化できるようになったのでバッチ正規化層(Batch Normalization Layer)の実装について書こうと思いましたが、実装してみたところ回帰分析だと学習がうまく行かないようなので、まずクラス分類問題を解けるようにしたいと思います。
§環境
windows 10
Microsoft Visual Studio Community 2022 (64 ビット) Version 17.3.4
ArrayFire v3.8.2
まず問題を作ります。
inputの要素は0から1をランダムに代入します。
outputのラベル番号はinputをすべて足し算した結果の整数部分とします。
例えばinputが100個ある場合、足し算した結果は0から100になりますから、outputのラベルは101個必要になります。
(101番目のラベルが使用される可能性は限りなく低いですが、あった方が破綻しないでしょう。)
// inputの合計によってoutputのクラスが異なる
// 多分output = input + 1にしておいた方が問題がないと思う
af::array input = af::randu(size_input, size_data, dtype_t);
af::array output = af::constant(0.0, size_output, size_data, dtype_t);
// 最初に全て0の行列を作ってから、条件を満たしたところに1を入れていく
af::array temp_input_sum = af::sum(input, 0);
gfor(af::seq seq, size_output)
{
output(seq, af::span) =
af::select(
seq + 0.0 <= temp_input_sum(0, af::span) && temp_input_sum(0, af::span) < seq + 1.0
, 1.0
, output(0, af::span) * 0.0);
}
gforはArrayfireの特殊なfor文で、for文と異なり並列処理で計算するので処理が早くなります。
ただ、イテレータがaf::seqという特殊なもので、Arrayfire独自の型を使います。
さて、クラス分類で必要なものと言えば、softmaxと交差エントロピー誤差です。
これらの説明は検索すると詳しくわかりやすい説明が大量に出てくるので、Arrayfireにおいての実装について以下に記述します。
// 順伝播と逆伝播
auto forback = [&]()
{
// 順伝播
for (auto& itm : layer) { itm->forward(); }
// 必要ならsoftmax
#ifdef DEF_use_softmax
data.back() = af::exp(data.back()) / af::tile(af::sum(af::exp(data.back()), 0), data.back().dims(0));
#endif
// 誤差の計算
grad.back() = data.back() - output(af::span, af::seq(0, size_batch - 1)); // y - t
//逆伝播
std::for_each(layer.rbegin(), layer.rend(), [](auto& itm) { itm->backward(); });
};
順伝播と逆伝播をひとまとめにしたラムダ式です。
クラス分類したい時だけ、#ifdefの中身を有効にすれば良いです。
SoftMax層として層を作るというやり方もありますが、誤差はdx += dyで素通りなので、今回はこの方法で実装しました。
#ifdef DEF_use_softmax
grad_plus = -1.0 * af::sum<var_t>((output(af::span, af::seq(0, size_batch - 1)) * af::log(data.back()))); // cross entropy
#else
grad_plus = -1.0 * af::sum<var_t>(af::pow(grad.back(), 2.0) / (var_t)2); // mse
#endif
誤差関数です。
softmaxを使わない場合はMSEを使いますが、softmaxを使う場合は交差エントロピー誤差を使用します。
出力と教師信号を取り違えないように注意してください。
#ifdef DEF_use_softmax
// 正答率の算定
af::array val, plabels, tlabels;
af::max(val, tlabels, output(af::span, idx_target), 0);
af::max(val, plabels, data.back(), 0);
accuracy = (var_t)100 * af::count<var_t>(plabels == tlabels) / (var_t)tlabels.elements();
#endif
正答率の算定は、af::maxのリファレンスから飛んだサンプルソースコードをほぼコピーしました。
valはおそらく最大値が入っている行列で、今回の正答率の算出には必要ありません。(引数に指定する変数として必要)
https://arrayfire.org/docs/machine_learning_2neural_network_8cpp-example.htm#a4
constexpr var_t alpha = 0.01; // 学習率
constexpr int size_input = 100; // 入力層の要素数
constexpr int size_hidden = 300; // 隠れ層の要素数
constexpr int size_output = 101; // 出力層の要素数
constexpr int size_hidden_layer = 1; // 隠れ層の規模
std::vector<af::array> data; // 評価値
data.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
{
data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
data.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
}
data.push_back(af::constant(0.0, size_output, size_batch, dtype_t));
std::vector<af::array> grad; // 誤差(傾き)
grad.push_back(af::constant(0.0, size_input, size_batch, dtype_t));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer; ++i_hidden_layer)
{
grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
grad.push_back(af::constant(0.0, size_hidden, size_batch, dtype_t));
}
grad.push_back(af::constant(0.0, size_output, size_batch, dtype_t));
std::vector<std::shared_ptr<Layer>> layer; // 処理層
layer.push_back(std::make_shared<FC_layer/* */>(data.at(0), data.at(1), grad.at(0), grad.at(1)));
for (int i_hidden_layer = 0; i_hidden_layer < size_hidden_layer * 2; i_hidden_layer += 2)
{
layer.push_back(std::make_shared<tanhExp_layer/**/>(data.at(i_hidden_layer + 1), data.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 1), grad.at(i_hidden_layer + 2)));
layer.push_back(std::make_shared<FC_layer/* */>(data.at(i_hidden_layer + 2), data.at(i_hidden_layer + 3), grad.at(i_hidden_layer + 2), grad.at(i_hidden_layer + 3)));
}
上記のモデルで、学習させた結果がこちらです。
epoch : 1 diff : 0.0179113 norm : 10.8407 accuracy : 10.1562
epoch : 101 diff : 0.0152842 norm : 9.6419 accuracy : 46.0938
epoch : 201 diff : 0.013505 norm : 8.84851 accuracy : 68.75
epoch : 301 diff : 0.0120635 norm : 8.15111 accuracy : 77.3438
次第に正答率が上がっていくのが確認できます。