Image

PyTorchの損失関数の使い方を解説!サンプルプログラムを紹介

PythonのPyTorchで、ディープラーニングを行う際に必要な損失関数の使い方を解説してみました。損失関数のサンプルプログラムも紹介しているのでPyTorchのディープラーニングの基本を学習することができます。

PyTorchの損失関数について

今回はディープラーニングでは必ず使う損失関数について押さえていきましょう

目次   

01. 損失関数とは

_02. バイナリ交差エントロピー損失

_03. ロジット付きバイナリ交差エントロピー損失

_04. ソフトマックス交差エントロピー損失

_05. 平均二乗誤差損失

_06. 平均絶対誤差損失

07. まとめ


01. 損失関数とは


まず損失関数とは、ニューラルネットワークの予測がうまく行ったのかどうか判断するために使用する関数です

この関数を使用して、予測と答えの誤差を求めます。

その誤差が最小になれば予測はより正確なものだったという評価がなされます

損失関数には下で触れるだけの種類があり、目的によって使い分けます

このような関数を用いて数学的なアプローチをすることで機械学習の予測の正確性を高めていきます。

以下では数学的な要素に踏み込みすぎない程度に、プログラミングでの活用方法をアウトプットしていきます。

ちなみにエントロピーとは不規則性の程度を表す量をいいます。

その通りといった感じですね。


_02. バイナリ交差エントロピー損失


バイナリ交差エントロピー損失データのクラスが2クラスの場合に使用します。

2クラスというのはデータの種類が2つであることを意味します。

バイナリ交差エントロピー損失は一種の距離を表すような指標で、ニューラルネットワークの出力と正解との間にどの程度の差があるのかを示す尺度です。

n個のデータがあったとしてバイナリ交差エントロピー損失L(y,t)はデータiに対するクラス1の予測確率yiと正解jクラスtiを表します。

クラス1の予測値yiニューラルネットワークの出力層から出力された値をシグモイド関数で変換した確率値を表しています。

出力層からの出力値をロジットといいます。

ロジットとはあるあるクラスの確率pとそうでない確率1-pnの比に対数をとった値です。

先に出てきたシグモイド関数はロジット関数の逆関数です。

そのためシグモイド関数にロジットを入力することでクラスの確率pを求めることができます。

要は出力値を0から1の範囲に抑えつつ扱いやすい確率の形に変換できる公式といった感じです。

バイナリ交差エントロピー損失の関数はnn.BCELoss()です。

シグモイド関数はnn.Sigmoid()です。

なお、nn.BCELossはtorch.float32型をデータ型として使用しなければなりません

そのため正解クラスのデータ型は本来intですがfloatに変換する必要があります



import torch
from torch import nn
m = nn.Sigmoid()
y = torch.rand(3)
t = torch.empty(3, dtype=torch.float32).random_(2)
criterion = nn.BCELoss()
loss = criterion(m(y), t)

print("y: {}".format(y))
print("m(y): {}".format(m(y)))
print("t: {}".format(t))
print("loss: {:.4f}".format(loss))
## 実行結果
>>>
y: tensor([0.2744, 0.9147, 0.3309])
m(y): tensor([0.5682, 0.7140, 0.5820])
t: tensor([0., 1., 0.])
loss: 0.6830


lossがバイナリ交差エントロピー損失です。


_03. ロジット付きバイナリ交差エントロピー損失


ロジット付きバイナリ交差エントロピー損失はバイナリ交差エントロピー損失に最初からシグモイド関数が加えられたものです

すなわち出力値をそのまま与えればバイナリ交差エントロピー損失が得られます。

n個のデータがあったとして、ロジット付きバイナリ交差エントロピー損失はデータiに対するロジットyiと正解のクラスtiをL(y, t)として表すことができます

ロジット付きバイナリ交差エントロピー損失の関数はnn.BCEWithLogitsLoss()です。

長いですね



import torch
from torch import nn
y = torch.rand(3)
t = torch.empty(3, dype=torch.float32).random_(2)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(y, t)

print("y: {}".format(y))
print("t: {}".format(t))
print("loss: {:.4f}".format(loss))
## 実行結果
y: tensor([0.9709, 0.8976, 0.3228])
t: tensor([0., 1., 0.])
loss: 0.8338


lossがロジット付きバイナリ交差エントロピー損失です。

.format()では指定した変数を{}の中に代入してそれを出力しています。


_04. ソフトマックス交差エントロピー損失


ソフトマックス交差エントロピー損失もバイナリ交差エントロピー損失と同じように、ニューラルネットワークの出力と正解クラスがどのくらい離れているかを評価する尺度です

特に2クラス以上の多クラスに分類されている場合に用いられます。

2クラスの分類ではシグモイド関数を使用しましたが、2クラス以上ではソフトマックス交差エントロピー損失を使用します。

ソフトマックスエントロピー損失はn個のデータがあったとしてデータiに対するクラスkのロジットyiと正解クラスtiのデータを使用してL(y, t)で表すことが可能です。

ソフトマックス交差エントロピー損失はnn.CrossEntropyLossです。



import torch
from torch import nn
y = torch.rand(3, 5)
t = torch.empty(3, dtype=torch.int64).random_(5)
criterion = nn.CrossEntropyLoss()
loss = criterion(y, t)

print("y:{}".format(y))
print("t:{}".format(t))
print("loss: {:4f}".format(loss))
## 実行結果
y: tensor([[0.7775, 0.7587, 0.9474, 0.5149, 0.7741],
[0.5059, 0.4802, 0.9846, 0.6292, 0.0167],
[0.4339, 0.6873, 0.4253, 0.7067, 0.5678]])
t: tensor([1, 4, 1])
loss: 1.757074


データ数は3つで各クラスに出力します。

クラス数は5つです。

lossがソフトマックス交差エントロピー損失を表しています


_05. 平均二乗誤差損失


平均二乗誤差損失は回帰問題でよく用いられる損失関数です。

値の意味としてはニューラルネットワークの予測値と正解値の差を二乗した値の平均を求めたものです。

利用用途としては体重や性別で足のサイズや身長を予測するようなものに用います。

この値が小さければ誤差は小さいということで正しい予測ができていたことがわかります。

n個のデータがあったときデータiの予測値yiと正解値tiの二乗誤差平均損失はL(y, t)で表します。

関数はnn.MSELoss()を使用します。



import torch
from torch import nn
y = torch.rand(1, 10)
t = torch.rand(1, 10)
criterion = nn.MSELoss()
loss = criterion(y, t)

print(loss)
## 実行結果
tensor(0.2185)


tensor(0.2185)が平均二乗誤差損失です


_06. 平均絶対誤差損失


平均絶対誤差損失は平均二乗誤差損失と同じように回帰問題で利用されます。

平均二乗誤差損失では予測値と正解値の差を二乗した値の平均を求めます。

平均絶対誤差損失は予測値と正解値の差の絶対値の平均を求めます。

n個のデータがあったとしてデータiの予測値yiと正解値tiの平均絶対損失はL(y, t)で表されます。

関数はnn.L1Loss()を使用します。



import torch
from torch import nn
y = torch.rand(1, 10)
t = torch.rand(1, 10)
criterion = nn.L1Loss()
loss = criterion(y, t)

print(loss)
## 実行結果
tensor(0.4042)


07. まとめ


お疲れ様でした。

ここまで読んでいただきありがとうございました。

今回はニューラルネットワークの予測が正しいか評価する損失関数について学びました。

損失関数がどんなもので使い方の例が押さえられていれば問題ありません。

次回は最適化関数ついて学びます。

PyTorchプログラミング入門を利用しながら実際に今回もやってみました。

ありがとうございます。

まずは基本を押さえます。