φυβλαςのβλογ
บล็อกของ phyblas



pytorch เบื้องต้น บทที่ ๑๕: การบันทึกและอ่านแบบจำลองที่เรียนรู้เสร็จแล้ว
เขียนเมื่อ 2018/09/26 09:24
>> ต่อจาก บทที่ ๑๔



ในการใช้โครงข่ายประสาทเทียม จะเห็นว่าขั้นตอนประกอบไปด้วย
๑. สร้างแบบจำลอง
๒. ฝึกแบบจำลอง
๓. นำแบบจำลองที่ฝึกเสร็จแล้วมาใช้

ขั้นตอนที่กินเวลานานสุดคือขั้นตอนการฝึก แต่พอฝึกเสร็จแล้วเราอาจนำไปใช้ได้ตลอด ไม่ต้องกลับมาฝึกต่อแล้ว

ดังนั้นเพื่อให้ไม่ต้องฝึกใหม่ตลอด เราจำเป็นจะต้องบันทึกผลของการฝึก นั่นก็คือบันทึกพารามิเตอร์ที่ได้มาหลังเรียนรู้

pytorch ได้มีเตรียมคำสั่งสำหรับช่วยทำการบันทึกเก็บแบบจำลองหรือค่าพารามิเตอร์ที่เรียนรู้เสร็จแล้วได้อย่างง่าย มีอยู่ ๒ วิธี คือ เก็บทั้งตัวแบบจำลอง หรือจะเก็บแค่พารามิเตอร์



บันทึกตัวแบบจำลอง

ใช้คำสั่ง torch.save กับแบบจำลองที่เรียนรู้เสร็จแล้ว แบบจำลองจะถูกบันทึกในรูปแบบของ pickle ซึ่งเป็นรูปแบบการเก็บไฟล์แบบมาตรฐานของไพธอน เพียงแต่คำสั่งนี้ของ pytorch ช่วยให้สามารถทำการบันทึกเก็บได้โดยเขียนแค่สั้นๆในบรรทัดเดียวจึงมีความสะดวกมาก

ตัวอย่าง ลองสร้างโครงข่ายง่ายๆแล้วทำการฝึก ฝึกเสร็จก็บันทึก
import torch
x = torch.rand(100)*6
z = x.sin() + torch.randn(100)/6
khrongkhai1 = torch.nn.Sequential(
    torch.nn.Linear(1,32),
    torch.nn.ReLU(),
    torch.nn.Linear(32,1))
opt = torch.optim.Adam(khrongkhai1.parameters(), lr=0.05)
ha_mse =  torch.nn.MSELoss()

for t in range(1000):
    J = ha_mse(khrongkhai1(x[:,None]).flatten(),z)
    J.backward()
    opt.step()
    opt.zero_grad()

torch.save(khrongkhai1,'khrongkhai.pkl')

จากนั้นก็ใช้ torch.load เพื่อทำการโหลดขึ้นมา แล้วก็นำมาใช้ได้
khrongkhai2 = torch.load('khrongkhai.pkl')

mx = torch.linspace(0,6,200)
mz = khrongkhai2(mx[:,None]).flatten().data
import matplotlib.pyplot as plt
plt.plot(mx.numpy(),mz.numpy(),'r')
plt.scatter(x.numpy(),z.numpy(),c='b',edgecolor='y')
plt.show()



โครงข่ายที่ถูกโหลดขึ้นมาใหม่สามารถนำมาใช้งานได้ทันทีเมื่อไหร่ก็ได้



การบันทึกแค่พารามิเตอร์

บางครั้งเราอาจไม่จำเป็นต้องบันทึกทั้งแบบจำลอง บันทึกแค่พารามิเตอร์ก็พอแล้ว กรณีแบบนี้ให้ใช้เมธอด .state_dict() จะเป็นการสร้างดิกที่เก็บค่าพารามิเตอร์ทั้งหมด
dic = khrongkhai1.state_dict()
print(dic.keys()) # ได้ ['0.weight', '0.bias', '2.weight', '2.bias']
print(dic['2.bias']) # ได้ tensor([-0.3450])

เอาดิกที่ได้มาบันทึกเก็บไว้ได้ด้วย torch.save() เช่นกัน
torch.save(dic,'param.pkl')

จากนั้นก็สร้างโครงข่ายใหม่ให้เหมือนเดิม แล้วโหลดดิกขึ้นมา แล้วใช้เมธอด .load_state_dict()
khrongkhai3 = torch.nn.Sequential(
    torch.nn.Linear(1,32),
    torch.nn.ReLU(),
    torch.nn.Linear(32,1))
dic_param = torch.load('param.pkl')
khrongkhai3.load_state_dict(dic_param)

เท่านี้ก็นำกลับมาใช้ได้เช่นเดียวกัน

เพียงแต่ว่าถ้าโครงข่ายที่สร้างขึ้นมาใหม่มีโครงสร้างที่ต่างไปพอโหลดก็จะมีข้อผิดพลาด



แบบจำลองที่ฝึกด้วย GPU

บางครั้งเราอาจทำการฝึกในคอมเครื่องที่มี GPU แล้วเอาแบบจำลองที่ฝึกเสร็จมาใช้ในเครื่องอื่นๆซึ่งไม่ได้ติดตั้ง GPU ไว้ ในกรณีแบบนั้นก่อนบันทึกให้ทำการแปลงแบบจำลองกลับมาอยู่ใน cpu ก่อน ไม่เช่นนั้นพอไปโหลดใหม่จะ error

แต่กรณีที่บันทึกโดยไม่ได้แปลงกลับไปแล้ว ก็ยังมีวิธีที่ทำให้สามารถเปิดได้ แต่ต้องเขียนเพิ่มโดยเติม lambda s,_:s ตอนโหลด
dic_param = torch.load('param.pkl',lambda s,_:s)

เพียงแต่วิธีนี้แก้ได้แค่ในกรณีที่ใช้ GPU ตัวเดียว



>> อ่านต่อ บทที่ ๑๖


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

-----------------------------------------

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์ >> โครงข่ายประสาทเทียม
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> pytorch

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

สารบัญ

รวมคำแปลวลีเด็ดจากญี่ปุ่น
python
-- numpy
-- matplotlib

-- pandas
-- pytorch
maya
การเรียนรู้ของเครื่อง
-- โครงข่ายประสาทเทียม
บันทึกในญี่ปุ่น
บันทึกในจีน
-- บันทึกในปักกิ่ง
บันทึกในไต้หวัน
บันทึกในยุโรปเหนือ
บันทึกในประเทศอื่นๆ
เรียนภาษาจีน
บทความอื่นๆ

บทความแบ่งตามหมวด



ติดตามอัพเดตของบล็อกได้ที่แฟนเพจ

  ค้นหาบทความ

  บทความแนะนำ

หลักการเขียนทับศัพท์ภาษาจีนกลาง
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
บ้านเก่าของจางเสวียเหลียงในเทียนจิน
เที่ยวจิ่นโจว ๓ วัน ๒ คืน 23 - 25 พ.ค. 2015
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
บันทึกการเที่ยวสวีเดน 1-12 พ.ค. 2014
แนะนำองค์การวิจัยและพัฒนาการสำรวจอวกาศญี่ปุ่น (JAXA)
เที่ยวฮ่องกงในคืนคริสต์มาสอีฟ เดินทางไกลจากสนามบินมาทานติ่มซำอร่อยโต้รุ่ง
เล่าประสบการณ์ค่ายอบรมวิชาการทางดาราศาสตร์โดยโซวเคนได 10 - 16 พ.ย. 2013
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
บันทึกการเที่ยวญี่ปุ่นครั้งแรกในชีวิต - ทุกอย่างเริ่มต้นที่สนามบินนานาชาติคันไซ
หลักการเขียนคำทับศัพท์ภาษาญี่ปุ่น
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ
ทำไมถึงอยากมาเรียนต่อนอก
เหตุผลอะไรที่ต้องใช้ภาษาวิบัติ?

ไทย

日本語

中文