φυβλαςのβλογ
phyblas的博客



pytorch เบื้องต้น บทที่ ๑๕: การบันทึกและอ่านแบบจำลองที่เรียนรู้เสร็จแล้ว
เขียนเมื่อ 2018/09/26 09:24
แก้ไขล่าสุด 2022/07/09 19:03
>> ต่อจาก บทที่ ๑๔



ในการใช้โครงข่ายประสาทเทียม จะเห็นว่าขั้นตอนประกอบไปด้วย
๑. สร้างแบบจำลอง
๒. ฝึกแบบจำลอง
๓. นำแบบจำลองที่ฝึกเสร็จแล้วมาใช้

ขั้นตอนที่กินเวลานานสุดคือขั้นตอนการฝึก แต่พอฝึกเสร็จแล้วเราอาจนำไปใช้ได้ตลอด ไม่ต้องกลับมาฝึกต่อแล้ว

ดังนั้นเพื่อให้ไม่ต้องฝึกใหม่ตลอด เราจำเป็นจะต้องบันทึกผลของการฝึก นั่นก็คือบันทึกพารามิเตอร์ที่ได้มาหลังเรียนรู้

pytorch ได้มีเตรียมคำสั่งสำหรับช่วยทำการบันทึกเก็บแบบจำลองหรือค่าพารามิเตอร์ที่เรียนรู้เสร็จแล้วได้อย่างง่าย มีอยู่ ๒ วิธี คือ เก็บทั้งตัวแบบจำลอง หรือจะเก็บแค่พารามิเตอร์



บันทึกตัวแบบจำลอง

ใช้คำสั่ง torch.save กับแบบจำลองที่เรียนรู้เสร็จแล้ว แบบจำลองจะถูกบันทึกในรูปแบบของ pickle ซึ่งเป็นรูปแบบการเก็บไฟล์แบบมาตรฐานของไพธอน เพียงแต่คำสั่งนี้ของ pytorch ช่วยให้สามารถทำการบันทึกเก็บได้โดยเขียนแค่สั้นๆในบรรทัดเดียวจึงมีความสะดวกมาก

ตัวอย่าง ลองสร้างโครงข่ายง่ายๆแล้วทำการฝึก ฝึกเสร็จก็บันทึก
import torch
x = torch.rand(100)*6
z = x.sin() + torch.randn(100)/6
khrongkhai1 = torch.nn.Sequential(
    torch.nn.Linear(1,32),
    torch.nn.ReLU(),
    torch.nn.Linear(32,1))
opt = torch.optim.Adam(khrongkhai1.parameters(), lr=0.05)
ha_mse =  torch.nn.MSELoss()

for t in range(1000):
    J = ha_mse(khrongkhai1(x[:,None]).flatten(),z)
    J.backward()
    opt.step()
    opt.zero_grad()

torch.save(khrongkhai1,'khrongkhai.pkl')

จากนั้นก็ใช้ torch.load เพื่อทำการโหลดขึ้นมา แล้วก็นำมาใช้ได้
khrongkhai2 = torch.load('khrongkhai.pkl')

mx = torch.linspace(0,6,200)
mz = khrongkhai2(mx[:,None]).flatten().data
import matplotlib.pyplot as plt
plt.plot(mx.numpy(),mz.numpy(),'r')
plt.scatter(x.numpy(),z.numpy(),c='b',edgecolor='y')
plt.show()



โครงข่ายที่ถูกโหลดขึ้นมาใหม่สามารถนำมาใช้งานได้ทันทีเมื่อไหร่ก็ได้



การบันทึกแค่พารามิเตอร์

บางครั้งเราอาจไม่จำเป็นต้องบันทึกทั้งแบบจำลอง บันทึกแค่พารามิเตอร์ก็พอแล้ว กรณีแบบนี้ให้ใช้เมธอด .state_dict() จะเป็นการสร้างดิกที่เก็บค่าพารามิเตอร์ทั้งหมด
dic = khrongkhai1.state_dict()
print(dic.keys()) # ได้ ['0.weight', '0.bias', '2.weight', '2.bias']
print(dic['2.bias']) # ได้ tensor([-0.3450])

เอาดิกที่ได้มาบันทึกเก็บไว้ได้ด้วย torch.save() เช่นกัน
torch.save(dic,'param.pkl')

จากนั้นก็สร้างโครงข่ายใหม่ให้เหมือนเดิม แล้วโหลดดิกขึ้นมา แล้วใช้เมธอด .load_state_dict()
khrongkhai3 = torch.nn.Sequential(
    torch.nn.Linear(1,32),
    torch.nn.ReLU(),
    torch.nn.Linear(32,1))
dic_param = torch.load('param.pkl')
khrongkhai3.load_state_dict(dic_param)

เท่านี้ก็นำกลับมาใช้ได้เช่นเดียวกัน

เพียงแต่ว่าถ้าโครงข่ายที่สร้างขึ้นมาใหม่มีโครงสร้างที่ต่างไปพอโหลดก็จะมีข้อผิดพลาด



แบบจำลองที่ฝึกด้วย GPU

บางครั้งเราอาจทำการฝึกในคอมเครื่องที่มี GPU แล้วเอาแบบจำลองที่ฝึกเสร็จมาใช้ในเครื่องอื่นๆซึ่งไม่ได้ติดตั้ง GPU ไว้ ในกรณีแบบนั้นก่อนบันทึกให้ทำการแปลงแบบจำลองกลับมาอยู่ใน cpu ก่อน ไม่เช่นนั้นพอไปโหลดใหม่จะ error

แต่กรณีที่บันทึกโดยไม่ได้แปลงกลับไปแล้ว ก็ยังมีวิธีที่ทำให้สามารถเปิดได้ แต่ต้องเขียนเพิ่มโดยเติม lambda s,_:s ตอนโหลด
dic_param = torch.load('param.pkl',lambda s,_:s)

เพียงแต่วิธีนี้แก้ได้แค่ในกรณีที่ใช้ GPU ตัวเดียว



>> อ่านต่อ บทที่ ๑๖


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์ >> โครงข่ายประสาทเทียม
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> pytorch

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

目录

从日本来的名言
模块
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
机器学习
-- 神经网络
javascript
蒙古语
语言学
maya
概率论
与日本相关的日记
与中国相关的日记
-- 与北京相关的日记
-- 与香港相关的日记
-- 与澳门相关的日记
与台湾相关的日记
与北欧相关的日记
与其他国家相关的日记
qiita
其他日志

按类别分日志



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  查看日志

  推荐日志

ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ