>> ต่อจาก
บทที่ ๙ ในบทที่ผ่านๆมาในการเรียนรู้แต่ละรอบของโครงข่ายประสาทเทียมเราใช้ข้อมูลที่มีอยู่ทั้งหมดมาคำนวณค่าเสียหายแล้วปรับพารามิเตอร์ทีเดียว
แต่โดยทั่วไปในการใช้งานจริงถ้าข้อมูลมาจำนวนมากวิธีที่นิยมกว่าก็คือการแบ่งข้อมูลออกเป็นกลุ่มย่อยๆแล้วคำนวณค่าเสียหายแล้วปรับพารามิเตอร์โดยใช้ข้อมูลทีละกลุ่ม วิธีการแบบนั้นเรียกว่า
มินิแบตช์ (minibatch) ในเนื้อหา
โครงข่ายประสาทเทียมเบื้องต้น บทที่ ๑๕ ได้เขียนวิธีการทำมินิแบตช์ไปแล้ว
แต่ใน pytorch ได้เตรียมคำสั่งที่สะดวกในการทำมินิแบตช์เอาไว้ นั่นคือใช้คลาส TensorDataset และ DataLoader ในมอดูลย่อย torch.utils.data
มอดูลย่อย torch.utils.data นี้จะไม่ถูก import มาถ้าเราแค่ import torch ดังนั้นจึงต้องทำการสั่ง import ตัวมันโดยตรง ในที่นี้จะทำการ import คลาสทั้ง ๒ ตัวนี้ในลักษณะนี้
import torch
from torch.utils.data import TensorDataset as Tenda
from torch.utils.data import DataLoader as Dalo
เพื่อความสะดวกจึงได้ย่อชื่อให้สั้นลงด้วย
ส่วนวิธีการใช้ก็คือ ก่อนอื่นสร้างเทนเซอร์ข้อมูลที่ต้องการทำมินิแบตช์แล้วนำมาใส่ในออบเจ็กต์ TensorDataset
เพื่อให้เห็นภาพง่ายขอใช้ข้อมูลง่ายๆนี้เป็นตัวอย่าง
X = torch.arange(8)[:,None]+torch.LongTensor([0,8])
z = torch.arange(8)*10
print(X)
print(z)
ได้
tensor([[ 0, 8],
[ 1, 9],
[ 2, 10],
[ 3, 11],
[ 4, 12],
[ 5, 13],
[ 6, 14],
[ 7, 15]])
tensor([ 0, 10, 20, 30, 40, 50, 60, 70])
จากนั้นนำข้อมูลนี้มาสร้างออบเจ็กต์ TensorDataset
data = Tenda(X,z)
print(data) # ได้ <torch.utils.data.dataset.TensorDataset object at 0x000001D2623848B0>
แล้วก็นำมาป้อนให้ DataLoader โดยกำหนดขนาดมินิแบตช์ที่ต้องการ
dalo = Dalo(data,batch_size=3)
print(dalo) # ได้ <torch.utils.data.dataloader.DataLoader object at 0x000001D260FD7FA0>
ออบเจ็กต์ที่ได้นี้เป็นเจเนอเรเตอร์ จะทำงานเมื่อถูกเรียกใช้ด้วย for แล้วจะวนซ้ำส่งข้อมูลออกมาตามลำดับ
ในที่นี้กำหนด batch_size=3 ข้อมูลจะออกมาทีละ 3 ตัว ยกเว้นรอบสุดท้ายเหลือแค่ 2 จะออกมาแค่ 2
for Xb,zb in dalo:
print(Xb,zb)
ได้
tensor([[ 0, 8],
[ 1, 9],
[ 2, 10]]) tensor([ 0, 10, 20])
tensor([[ 3, 11],
[ 4, 12],
[ 5, 13]]) tensor([30, 40, 50])
tensor([[ 6, 14],
[ 7, 15]]) tensor([60, 70])
ถ้านำมาเข้า for ซ้ำอีกรอบก็จะให้ข้อมูลแบบเดิมออกมาอีก
แต่ถ้ากำหนดให้ shuffle=True (หรือ 1) ไว้ก็จะทำให้ข้อมูลสุ่มลำดับ แต่ละรอบจะเรียงลำดับและแบ่งกลุ่มออกมาไม่ซ้ำกัน
dalo = Dalo(data,batch_size=3,shuffle=1)
for i in range(2):
for Xb,zb in dalo:
print(Xb,zb)
print('----------')
ได้
tensor([[ 1, 9],
[ 5, 13],
[ 6, 14]]) tensor([10, 50, 60])
tensor([[ 0, 8],
[ 2, 10],
[ 3, 11]]) tensor([ 0, 20, 30])
tensor([[ 7, 15],
[ 4, 12]]) tensor([70, 40])
----------
tensor([[ 2, 10],
[ 5, 13],
[ 0, 8]]) tensor([20, 50, 0])
tensor([[ 4, 12],
[ 7, 15],
[ 6, 14]]) tensor([40, 70, 60])
tensor([[ 1, 9],
[ 3, 11]]) tensor([10, 30])
----------
นอกจากนี้ยังมีค่าอื่นๆที่อาจปรับแต่งเพิ่มเติมอีกได้เช่น
drop_last ถ้าเป็น True จะตัดส่วนท้ายสุดทิ้งถ้ามีไม่ครบจำนวน batch_size
num_workers กำหนดจำนวนซับโพรเซสสำหรับรันคู่ขนาน
เพื่อความสะดวกจึงอาจลองตั้งเป็นฟังก์ชันไว้แล้วนำมาใช้แบบนี้ได้
def sang_minibatch(X,z,batch_size,shuffle=True):
dataset = Tenda(X,z)
return Dalo(dataset,batch_size,shuffle)
เนื่องจากปกติแล้วการให้แต่ละรอบมีการสุ่มสลับไม่เหมือนกันจะดีกว่า แต่ค่าตั้งต้นกลับให้ shuffle=False ดังนั้นในที่นี้จึงสร้างใหม่โดยตั้ง shuffle=True เป็นค่าตั้งต้น
เวลาใช้ก็แค่ใส่ข้อมูลตัวแปรต้น X และตัวแปรตาม z แล้วก็จำนวนมินิแบตช์ แล้วก็จะได้ตัวเจเนอเรเตอร์สำหรับวนซ้ำเพื่อทำมินิแบตช์มา
ต่อมา อาจลองสร้างคลาสที่ทำการเรียนรู้แบบมินิแบตช์ขึ้นมาได้ดังนี้
relu = torch.nn.ReLU()
ha_entropy = torch.nn.CrossEntropyLoss()
class Prasat(torch.nn.Sequential):
def __init__(self,m,eta=0.01):
super(Prasat,self).__init__()
for i in range(1,len(m)):
self.add_module('lin%d'%i,torch.nn.Linear(m[i-1],m[i]))
if(i<len(m)-1):
self.add_module('relu%d'%i,relu)
self.opt = torch.optim.Adam(self.parameters(),lr=eta)
def rianru(self,X,z,n_thamsam,n_batch=64):
X = torch.Tensor(X)
z = torch.LongTensor(z)
minibatch = sang_minibatch(X,z,n_batch)
for o in range(n_thamsam):
for Xb,zb in minibatch:
ha_entropy(self(Xb),zb).backward()
self.opt.step()
self.opt.zero_grad()
def thamnai(self,X):
X = torch.Tensor(X)
return self(X).argmax(1).numpy()
จากนั้นก็ลองนำมาใช้เพื่อวิเคราะห์จำแนกข้อมูล ๖ กลุ่มดู
import numpy as np
import matplotlib.pyplot as plt
z = np.arange(6).repeat(120)
r = np.random.normal(z+1,0.2)
t = np.random.uniform(-0.5,0.5,720)*np.pi
x = r*np.cos(t)
y = r*np.sin(t)
X = np.array([x,y]).T
prasat = Prasat([2,60,40,6],eta=0.01)
prasat.rianru(X,z,50)
mx,my = np.meshgrid(np.linspace(x.min(),x.max(),200),np.linspace(y.min(),y.max(),200))
mX = np.array([mx.ravel(),my.ravel()]).T
mz = prasat.thamnai(mX)
mz = mz.reshape(200,200)
plt.xlim(x.min(),x.max())
plt.ylim(y.min(),y.max())
plt.contourf(mx,my,mz,alpha=0.2,cmap='terrain')
plt.scatter(x,y,c=z,edgecolor='k',cmap='terrain')
plt.show()