Image

PyTorchの最適化関数の使い方を解説!サンプルプログラムを紹介

PythonのPyTorchで、ディープラーニングを行う際に必要な最適化関数の使い方を解説してみました。最適化関数のサンプルプログラムも紹介しているのでPyTorchのディープラーニングの基本を学習することができます。

PyTorchの最適化関数について

今回は損失関数に引き続きディープラーニングでは必ず使用する最適化関数を押さえましょう。

目次

01. 最適化関数とは

02. PyTorchの最適化関数の使い方

03. まとめ


01. 最適化関数とは


最適化関数とはューラルネットワークのパラメータを最適なものにするための関数です。

というのも損失関数が最小になるようなパラメータを選択すれば理論上最適解ではあります。

しかし、その組み合わせは膨大でとても手探りでは不可能です。

そこで最適化関数を使用して最適化アルゴリズムによって最適解を見つけようというものです

こうすることで予測の精度が上昇します。


最適化関数ではニューラルネットワークのパラメータで損失関数を微分した時の値がゼロになるようにパラメータを決定する作業を行います。

最適化関数のアルゴリズムを勾配降下法といいます。

このような処理を繰り返すことで損失を最小化します。

Adamを使います。


パラメータとは数学で2つ以上の変数間の関数関係を間接的に用いる補助の変数

プログラミング的にはソフトウェアやシステムの挙動に影響を与える外部から投入されるデータ


02. PyTorchの最適化関数の使い方


Adamを利用する方法

Adamを利用することで最適化を比較的簡単に行うことができます。



#パッケージのインポート
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt

#ニューラルネットワークの定義
class Net(nn.Module):
def __init__(self, D_in, H, D_out):
super(Net, self).__init__()
self.linear1 = nn.Linear(D_in, H)
self.linear2 = nn.Linear(H, D_out)
def forward(self, x):
x = F.relu(self.linear1(x))
x = self.linear2(x)
return x

N = 64
D_in = 1000
H = 100
D_out = 10
epoch = 100

x = torch.rand(N, D_in)
y = torch.rand(N, D_out)

net = Net(D_in, H, D_out)
criterion = nn.MSELoss()

optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.99), eps=1e-07
loss_list = []

for i in range(epoch):
y_pred = net(x)
loss = criterion(y_pred, y)
print("Epoch: {}, Loss: {:3f}".format(i+1, loss.item()))
loss_list.append(loss.item())

optimizer.zero_grad()
loss.backward()
optimizer.step()

plt.figure()
plt.title('Training Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(range(1,epoch+1), loss_list)
plt.show()
## 実行結果
Epoch: 1, Loss: 0.394883
Epoch: 2, Loss: 0.357445
Epoch: 3, Loss: 0.323615
Epoch: 4, Loss: 0.293376
Epoch: 5, Loss: 0.266553
Epoch: 6, Loss: 0.242410
Epoch: 7, Loss: 0.220740
Epoch: 8, Loss: 0.201153
Epoch: 9, Loss: 0.183613
Epoch: 10, Loss: 0.167950
Epoch: 11, Loss: 0.154022
Epoch: 12, Loss: 0.141958
...




04. まとめ


今回は最適化関数の使い方を学びました。

ざっくりとしたイメージができていればいいと思います。

お疲れ様でした。

PyTorchプログラミング入門を利用しながら実際に今回もやってみました。

まずは基本を押さえます。