ในการใช้โครงข่ายประสาทเทียม จะเห็นว่าขั้นตอนประกอบไปด้วย
๑. สร้างแบบจำลอง
๒. ฝึกแบบจำลอง
๓. นำแบบจำลองที่ฝึกเสร็จแล้วมาใช้
ขั้นตอนที่กินเวลานานสุดคือขั้นตอนการฝึก แต่พอฝึกเสร็จแล้วเราอาจนำไปใช้ได้ตลอด ไม่ต้องกลับมาฝึกต่อแล้ว
ดังนั้นเพื่อให้ไม่ต้องฝึกใหม่ตลอด เราจำเป็นจะต้องบันทึกผลของการฝึก นั่นก็คือบันทึกพารามิเตอร์ที่ได้มาหลังเรียนรู้
pytorch ได้มีเตรียมคำสั่งสำหรับช่วยทำการบันทึกเก็บแบบจำลองหรือค่าพารามิเตอร์ที่เรียนรู้เสร็จแล้วได้อย่างง่าย มีอยู่ ๒ วิธี คือ เก็บทั้งตัวแบบจำลอง หรือจะเก็บแค่พารามิเตอร์
บันทึกตัวแบบจำลอง ใช้คำสั่ง torch.save กับแบบจำลองที่เรียนรู้เสร็จแล้ว แบบจำลองจะถูกบันทึกในรูปแบบของ pickle ซึ่งเป็นรูปแบบการเก็บไฟล์แบบมาตรฐานของไพธอน เพียงแต่คำสั่งนี้ของ pytorch ช่วยให้สามารถทำการบันทึกเก็บได้โดยเขียนแค่สั้นๆในบรรทัดเดียวจึงมีความสะดวกมาก
ตัวอย่าง ลองสร้างโครงข่ายง่ายๆแล้วทำการฝึก ฝึกเสร็จก็บันทึก
import torch
x = torch.rand(100)*6
z = x.sin() + torch.randn(100)/6
khrongkhai1 = torch.nn.Sequential(
torch.nn.Linear(1,32),
torch.nn.ReLU(),
torch.nn.Linear(32,1))
opt = torch.optim.Adam(khrongkhai1.parameters(), lr=0.05)
ha_mse = torch.nn.MSELoss()
for t in range(1000):
J = ha_mse(khrongkhai1(x[:,None]).flatten(),z)
J.backward()
opt.step()
opt.zero_grad()
torch.save(khrongkhai1,'khrongkhai.pkl')
จากนั้นก็ใช้ torch.load เพื่อทำการโหลดขึ้นมา แล้วก็นำมาใช้ได้
khrongkhai2 = torch.load('khrongkhai.pkl')
mx = torch.linspace(0,6,200)
mz = khrongkhai2(mx[:,None]).flatten().data
import matplotlib.pyplot as plt
plt.plot(mx.numpy(),mz.numpy(),'r')
plt.scatter(x.numpy(),z.numpy(),c='b',edgecolor='y')
plt.show()
โครงข่ายที่ถูกโหลดขึ้นมาใหม่สามารถนำมาใช้งานได้ทันทีเมื่อไหร่ก็ได้
การบันทึกแค่พารามิเตอร์ บางครั้งเราอาจไม่จำเป็นต้องบันทึกทั้งแบบจำลอง บันทึกแค่พารามิเตอร์ก็พอแล้ว กรณีแบบนี้ให้ใช้เมธอด .state_dict() จะเป็นการสร้างดิกที่เก็บค่าพารามิเตอร์ทั้งหมด
dic = khrongkhai1.state_dict()
print(dic.keys()) # ได้ ['0.weight', '0.bias', '2.weight', '2.bias']
print(dic['2.bias']) # ได้ tensor([-0.3450])
เอาดิกที่ได้มาบันทึกเก็บไว้ได้ด้วย torch.save() เช่นกัน
torch.save(dic,'param.pkl')
จากนั้นก็สร้างโครงข่ายใหม่ให้เหมือนเดิม แล้วโหลดดิกขึ้นมา แล้วใช้เมธอด .load_state_dict()
khrongkhai3 = torch.nn.Sequential(
torch.nn.Linear(1,32),
torch.nn.ReLU(),
torch.nn.Linear(32,1))
dic_param = torch.load('param.pkl')
khrongkhai3.load_state_dict(dic_param)
เท่านี้ก็นำกลับมาใช้ได้เช่นเดียวกัน
เพียงแต่ว่าถ้าโครงข่ายที่สร้างขึ้นมาใหม่มีโครงสร้างที่ต่างไปพอโหลดก็จะมีข้อผิดพลาด
แบบจำลองที่ฝึกด้วย GPU บางครั้งเราอาจทำการฝึกในคอมเครื่องที่มี GPU แล้วเอาแบบจำลองที่ฝึกเสร็จมาใช้ในเครื่องอื่นๆซึ่งไม่ได้ติดตั้ง GPU ไว้ ในกรณีแบบนั้นก่อนบันทึกให้ทำการแปลงแบบจำลองกลับมาอยู่ใน cpu ก่อน ไม่เช่นนั้นพอไปโหลดใหม่จะ error
แต่กรณีที่บันทึกโดยไม่ได้แปลงกลับไปแล้ว ก็ยังมีวิธีที่ทำให้สามารถเปิดได้ แต่ต้องเขียนเพิ่มโดยเติม lambda s,_:s ตอนโหลด
dic_param = torch.load('param.pkl',lambda s,_:s)
เพียงแต่วิธีนี้แก้ได้แค่ในกรณีที่ใช้ GPU ตัวเดียว