φυβλαςのβλογ
phyblasのブログ



pytorch เบื้องต้น บทที่ ๕: ออปทิไมเซอร์
เขียนเมื่อ 2018/09/08 09:57
แก้ไขล่าสุด 2021/09/28 16:42
>> ต่อจาก บทที่ ๔



ออปทิไมเซอร์ใน pytorch

ออปทิไมเซอร์คืออุปกรณ์สำหรับช่วยในการปรับพารามิเตอร์ด้วยวิธีการเคลื่อนลงตามความชัน

ใน pytorch ได้เตรียมออปทิไมเซอร์ชนิดต่างๆไว้ในมอดูลย่อย torch.optim

ในบทที่ผ่านมาเป็นการใช้วิธีการเคลื่อนลงตามความชันธรรมดาโดยการเขียนเองโดยไม่ได้ใช้ออปทิไมเซอร์ช่วย

เมื่อใช้ออปทิไมเซอร์จะทำให้การเขียนดูเรียบง่ายลง อีกทั้งยังสามารถใช้วิธีที่เหนือกว่าแค่การเคลื่อนลงตามความชันธรรมดา เช่น Adagrad หรือ Adam ได้ง่ายๆ

เกี่ยวกับออปทิไมเซอร์ชนิดต่างๆอ่านรายละเอียดได้ใน https://phyblas.hinaboshi.com/20171002

ออปทิไมเซอร์ที่เตรียมไว้ใน pytorch ได้แก่
- torch.optim.SGD
- torch.optim.Adam
- torch.optim.Adamax
- torch.optim.Adadelta
- torch.optim.Adagrad
- torch.optim.Rprop
- torch.optim.RMSprop

SGD ก็คือวิธีแบบดั้งเดิมที่แค่คำนวณอนุพันธ์แล้วใช้ปรับค่าเลยโดยตรง

ส่วนวิธีการที่มักถูกแนะนำให้ใช้มากที่สุดคือ Adam ในที่นี้ก็จะใช้วิธีนี้เป็นหลักด้วย



การใช้ออปทิไมเซอร์

ขอยกตัวอย่างการนำออปทิไมเซอร์มาใช้ในการปรับพารามิเตอร์ใน pytorch

ตัวอย่างขอใช้เป็นการวิเคราะห์การถดถอยเชิงเส้นสองตัวแปรเช่นเดียวกับในบทที่แล้ว เขียนโดยใช้ออปทิไมเซอร์ได้ดังนี้
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch

X = np.random.uniform(-2,2,[300,2])
x,y = X.T
z = x*0.7+y*1.5 + np.random.normal(1,0.4,300)

lin = torch.nn.Linear(2,1)
X = torch.Tensor(X)
z = torch.Tensor(z)
eta = 0.05
ha_mse = torch.nn.MSELoss()
opt = torch.optim.Adam(lin.parameters(),lr=eta) # สร้างออปทิไมเซอร์
n_thamsam = 100
for i in range(n_thamsam):
    h = lin(X).flatten()
    J = ha_mse(h,z)
    J.backward()
    opt.step() # ปรับพารามิเตอร์
    opt.zero_grad() # ล้างอนุพันธ์

mx,my = np.meshgrid(np.linspace(-2,2,11),np.linspace(-2,2,11))
mX = torch.Tensor(np.array([mx.ravel(),my.ravel()]).T)
mz = lin(mX).data.numpy().reshape(11,11)
plt.figure(figsize=[6,6])
ax = plt.axes([0,0,1,1],projection='3d')
ax.scatter(x,y,z,c=z,edgecolor='k',cmap='coolwarm')
ax.plot_surface(mx,my,mz,color='g',rstride=1,cstride=1,alpha=0.1,edgecolor='k')
plt.show()


สิ่งที่เปลี่ยนแปลงจากคราวก่อนก็คือ มีการสร้างออปทิไมเซอร์ขึ้น แล้วส่งพารามิเตอร์เข้าไปให้ออปทิไมเซอร์บันทึกไว้ว่าตัวนี้คือพารามิเตอร์ที่ต้องการปรับค่า ส่วนอัตราการเรียนรู้ก็กำหนดโดยคีย์เวิร์ด lr ก็ป้อนเข้าไปด้วย

นอกจากนี้ค่าตัวเลือกอื่นๆตามชนิดของออปทิไมเซอร์ ก็สามารถกำหนดค่าตรงนี้ด้วย เช่นกรณีของ Adam จะมีค่า β1,β2 ซึ่งกำหนดโดยใส่คีย์เวิร์ด betas โดยใส่พร้อมกันทั้ง ๒ ตัว เป็น betas=(β12) กรณีที่ไม่ใส่ก็จะเป็นค่ามาตรฐานตั้งต้น คือ betas=(0.9,0.999)

จากนั้นในระหว่างวนซ้ำ จะใช้คำสั่ง .step() เพื่อทำการปรับค่าพารามิเตอร์ทั้งหมดตามความชันที่คำนวณได้หลังจาก .backward() ไป

สุดท้ายก็สั่ง .zero_grad() เพื่อล้างอนุพันธ์ทั้งหมดที่หามาในรอบนั้นให้เป็น 0 ก่อนที่จะคำนวณรอบถัดไป

ในการใช้งาน pytorch โดยทั่วไปก็จะใช้ออปทิไมเซอร์ในลักษณะนี้ตลอด เป็นขั้นตอนที่ค่อนข้างตายตัว (วิธีที่ผ่านมาในบทก่อนๆแค่เพื่ออธิบายหลักการที่อยู่เบื้องหลัง เพื่อความเข้าใจ)

.zero_grad() นี้บางคนก็วางไว้ก่อน .backward() ผลที่ได้ก็เหมือนกัน ที่สำคัญคือก่อนจะ backward รอบใหม่ต้องมี zero_grad ก่อน



ดูข้อมูลของตัวออปทิไมเซอร์

ถ้าสั่ง print ตัวออปทิไมเซอร์ ข้อมูลของตัวออปทิไมเซอร์นั้นจะแสดงออกมา
lin = torch.nn.Linear(3,1)
opt = torch.optim.Adam(lin.parameters())
print(opt)

ได้
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

หากต้องการดูว่าพารามิเตอร์ที่ออปทิไมเซอร์ตัวนั้นดูแลอยู่มีอะไรบางให้ดูที่แอตทริบิวต์ .param_groups
print(opt.param_groups)

ได้
[{'amsgrad': False, 'eps': 1e-08, 'lr': 0.001, 'betas': (0.9, 0.999), 'params': [Parameter containing:
tensor([[-0.4728, -0.1602, -0.0321]], requires_grad=True), Parameter containing:
tensor([-0.4072], requires_grad=True)], 'weight_decay': 0}]



การทำให้พารามิเตอร์บางตัวไม่ปรับค่า

พารามิเตอร์โดยทั่วไปจะมีการคำนวณและบันทึกค่าอนุพันธ์เสมอ นั่นคือเป็นเทนเซอร์ที่ถูกตั้ง requires_grad=True

แต่เราสามารถตั้ง requires_grad=False ได้ ซึ่งจะทำให้พารามิเตอร์ตัวนั้นไม่มีการคำนวณอนุพันธ์ และต่อให้ถูกส่งให้ออปทิไมเซอร์ก็จะไม่มีการเปลี่ยนแปลงค่า

ถ้าพารามิเตอร์ที่ส่งให้ออปทิไมเซอร์ทุกตัวตั้ง requires_grad=False หมด จะ error เพราะไม่มีพารามิเตอร์ให้ปรับค่า
lin = torch.nn.Linear(2,2)
lin.weight.requires_grad = False
lin.bias.requires_grad = False
opt = torch.optim.Adam(lin.parameters()) # RuntimeError


>> อ่านต่อ บทที่ ๖


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์ >> โครงข่ายประสาทเทียม
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> pytorch

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

目次

日本による名言集
モジュール
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
機械学習
-- ニューラル
     ネットワーク
maya
javascript
確率論
日本での日記
中国での日記
-- 北京での日記
-- 香港での日記
-- 澳門での日記
台灣での日記
北欧での日記
他の国での日記
qiita
その他の記事

記事の類別



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  記事を検索

  おすすめの記事

การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
หลักการเขียนทับศัพท์ภาษาจีนกวางตุ้ง
การใช้ unix shell เบื้องต้น ใน linux และ mac
หลักการเขียนทับศัพท์ภาษาจีนกลาง
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
บันทึกการเที่ยวสวีเดน 1-12 พ.ค. 2014
แนะนำองค์การวิจัยและพัฒนาการสำรวจอวกาศญี่ปุ่น (JAXA)
เล่าประสบการณ์ค่ายอบรมวิชาการทางดาราศาสตร์โดยโซวเคนได 10 - 16 พ.ย. 2013
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
บันทึกการเที่ยวญี่ปุ่นครั้งแรกในชีวิต - ทุกอย่างเริ่มต้นที่สนามบินนานาชาติคันไซ
หลักการเขียนทับศัพท์ภาษาญี่ปุ่น
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ
ทำไมถึงอยากมาเรียนต่อนอก
เหตุผลอะไรที่ต้องใช้ภาษาวิบัติ?

ไทย

日本語

中文