φυβλαςのβλογ
บล็อกของ phyblas



pytorch เบื้องต้น บทที่ ๑๐: มินิแบตช์
เขียนเมื่อ 2018/09/10 09:06
แก้ไขล่าสุด 2022/07/09 16:00
>> ต่อจาก บทที่ ๙



ในบทที่ผ่านๆมาในการเรียนรู้แต่ละรอบของโครงข่ายประสาทเทียมเราใช้ข้อมูลที่มีอยู่ทั้งหมดมาคำนวณค่าเสียหายแล้วปรับพารามิเตอร์ทีเดียว

แต่โดยทั่วไปในการใช้งานจริงถ้าข้อมูลมาจำนวนมากวิธีที่นิยมกว่าก็คือการแบ่งข้อมูลออกเป็นกลุ่มย่อยๆแล้วคำนวณค่าเสียหายแล้วปรับพารามิเตอร์โดยใช้ข้อมูลทีละกลุ่ม วิธีการแบบนั้นเรียกว่ามินิแบตช์ (minibatch)

ในเนื้อหาโครงข่ายประสาทเทียมเบื้องต้น บทที่ ๑๕ ได้เขียนวิธีการทำมินิแบตช์ไปแล้ว

แต่ใน pytorch ได้เตรียมคำสั่งที่สะดวกในการทำมินิแบตช์เอาไว้ นั่นคือใช้คลาส TensorDataset และ DataLoader ในมอดูลย่อย torch.utils.data

มอดูลย่อย torch.utils.data นี้จะไม่ถูก import มาถ้าเราแค่ import torch ดังนั้นจึงต้องทำการสั่ง import ตัวมันโดยตรง ในที่นี้จะทำการ import คลาสทั้ง ๒ ตัวนี้ในลักษณะนี้
import torch
from torch.utils.data import TensorDataset as Tenda
from torch.utils.data import DataLoader as Dalo
เพื่อความสะดวกจึงได้ย่อชื่อให้สั้นลงด้วย

ส่วนวิธีการใช้ก็คือ ก่อนอื่นสร้างเทนเซอร์ข้อมูลที่ต้องการทำมินิแบตช์แล้วนำมาใส่ในออบเจ็กต์ TensorDataset

เพื่อให้เห็นภาพง่ายขอใช้ข้อมูลง่ายๆนี้เป็นตัวอย่าง
X = torch.arange(8)[:,None]+torch.LongTensor([0,8])
z = torch.arange(8)*10
print(X)
print(z)

ได้
tensor([[ 0,  8],
        [ 1,  9],
        [ 2, 10],
        [ 3, 11],
        [ 4, 12],
        [ 5, 13],
        [ 6, 14],
        [ 7, 15]])
tensor([ 0, 10, 20, 30, 40, 50, 60, 70])

จากนั้นนำข้อมูลนี้มาสร้างออบเจ็กต์ TensorDataset
data = Tenda(X,z)
print(data) # ได้ <torch.utils.data.dataset.TensorDataset object at 0x000001D2623848B0>
แล้วก็นำมาป้อนให้ DataLoader โดยกำหนดขนาดมินิแบตช์ที่ต้องการ
dalo = Dalo(data,batch_size=3)
print(dalo) # ได้ <torch.utils.data.dataloader.DataLoader object at 0x000001D260FD7FA0>

ออบเจ็กต์ที่ได้นี้เป็นเจเนอเรเตอร์ จะทำงานเมื่อถูกเรียกใช้ด้วย for แล้วจะวนซ้ำส่งข้อมูลออกมาตามลำดับ

ในที่นี้กำหนด batch_size=3 ข้อมูลจะออกมาทีละ 3 ตัว ยกเว้นรอบสุดท้ายเหลือแค่ 2 จะออกมาแค่ 2
for Xb,zb in dalo:
    print(Xb,zb)

ได้
tensor([[ 0,  8],
        [ 1,  9],
        [ 2, 10]]) tensor([ 0, 10, 20])
tensor([[ 3, 11],
        [ 4, 12],
        [ 5, 13]]) tensor([30, 40, 50])
tensor([[ 6, 14],
        [ 7, 15]]) tensor([60, 70])

ถ้านำมาเข้า for ซ้ำอีกรอบก็จะให้ข้อมูลแบบเดิมออกมาอีก

แต่ถ้ากำหนดให้ shuffle=True (หรือ 1) ไว้ก็จะทำให้ข้อมูลสุ่มลำดับ แต่ละรอบจะเรียงลำดับและแบ่งกลุ่มออกมาไม่ซ้ำกัน
dalo = Dalo(data,batch_size=3,shuffle=1)
for i in range(2):
    for Xb,zb in dalo:
        print(Xb,zb)
    print('----------')
ได้
tensor([[ 1,  9],
        [ 5, 13],
        [ 6, 14]]) tensor([10, 50, 60])
tensor([[ 0,  8],
        [ 2, 10],
        [ 3, 11]]) tensor([ 0, 20, 30])
tensor([[ 7, 15],
        [ 4, 12]]) tensor([70, 40])
----------
tensor([[ 2, 10],
        [ 5, 13],
        [ 0,  8]]) tensor([20, 50,  0])
tensor([[ 4, 12],
        [ 7, 15],
        [ 6, 14]]) tensor([40, 70, 60])
tensor([[ 1,  9],
        [ 3, 11]]) tensor([10, 30])
----------

นอกจากนี้ยังมีค่าอื่นๆที่อาจปรับแต่งเพิ่มเติมอีกได้เช่น

drop_last ถ้าเป็น True จะตัดส่วนท้ายสุดทิ้งถ้ามีไม่ครบจำนวน batch_size

num_workers กำหนดจำนวนซับโพรเซสสำหรับรันคู่ขนาน



เพื่อความสะดวกจึงอาจลองตั้งเป็นฟังก์ชันไว้แล้วนำมาใช้แบบนี้ได้
def sang_minibatch(X,z,batch_size,shuffle=True):
    dataset = Tenda(X,z)
    return Dalo(dataset,batch_size,shuffle)

เนื่องจากปกติแล้วการให้แต่ละรอบมีการสุ่มสลับไม่เหมือนกันจะดีกว่า แต่ค่าตั้งต้นกลับให้ shuffle=False ดังนั้นในที่นี้จึงสร้างใหม่โดยตั้ง shuffle=True เป็นค่าตั้งต้น

เวลาใช้ก็แค่ใส่ข้อมูลตัวแปรต้น X และตัวแปรตาม z แล้วก็จำนวนมินิแบตช์ แล้วก็จะได้ตัวเจเนอเรเตอร์สำหรับวนซ้ำเพื่อทำมินิแบตช์มา

ต่อมา อาจลองสร้างคลาสที่ทำการเรียนรู้แบบมินิแบตช์ขึ้นมาได้ดังนี้
relu = torch.nn.ReLU()
ha_entropy = torch.nn.CrossEntropyLoss()

class Prasat(torch.nn.Sequential):
    def __init__(self,m,eta=0.01):
        super(Prasat,self).__init__()
        for i in range(1,len(m)):
            self.add_module('lin%d'%i,torch.nn.Linear(m[i-1],m[i]))
            if(i<len(m)-1):
                self.add_module('relu%d'%i,relu)
        self.opt = torch.optim.Adam(self.parameters(),lr=eta)
        
    
    def rianru(self,X,z,n_thamsam,n_batch=64):
        X = torch.Tensor(X)
        z = torch.LongTensor(z)
        minibatch = sang_minibatch(X,z,n_batch)
        for o in range(n_thamsam):
            for Xb,zb in minibatch:
                ha_entropy(self(Xb),zb).backward()
                self.opt.step()
                self.opt.zero_grad()
            
    def thamnai(self,X):
        X = torch.Tensor(X)
        return self(X).argmax(1).numpy()

จากนั้นก็ลองนำมาใช้เพื่อวิเคราะห์จำแนกข้อมูล ๖ กลุ่มดู
import numpy as np
import matplotlib.pyplot as plt

z = np.arange(6).repeat(120)
r = np.random.normal(z+1,0.2)
t = np.random.uniform(-0.5,0.5,720)*np.pi
x = r*np.cos(t)
y = r*np.sin(t)
X = np.array([x,y]).T

prasat = Prasat([2,60,40,6],eta=0.01)
prasat.rianru(X,z,50)

mx,my = np.meshgrid(np.linspace(x.min(),x.max(),200),np.linspace(y.min(),y.max(),200))
mX = np.array([mx.ravel(),my.ravel()]).T
mz = prasat.thamnai(mX)
mz = mz.reshape(200,200)
plt.xlim(x.min(),x.max())
plt.ylim(y.min(),y.max())
plt.contourf(mx,my,mz,alpha=0.2,cmap='terrain')
plt.scatter(x,y,c=z,edgecolor='k',cmap='terrain')
plt.show()
 


>> อ่านต่อ บทที่ ๑๑


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์ >> โครงข่ายประสาทเทียม
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> pytorch

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

สารบัญ

รวมคำแปลวลีเด็ดจากญี่ปุ่น
มอดูลต่างๆ
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
การเรียนรู้ของเครื่อง
-- โครงข่าย
     ประสาทเทียม
ภาษา javascript
ภาษา mongol
ภาษาศาสตร์
maya
ความน่าจะเป็น
บันทึกในญี่ปุ่น
บันทึกในจีน
-- บันทึกในปักกิ่ง
-- บันทึกในฮ่องกง
-- บันทึกในมาเก๊า
บันทึกในไต้หวัน
บันทึกในยุโรปเหนือ
บันทึกในประเทศอื่นๆ
qiita
บทความอื่นๆ

บทความแบ่งตามหมวด



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  ค้นหาบทความ

  บทความแนะนำ

ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ

ไทย

日本語

中文