import torch
from torch.utils.data import TensorDataset as Tenda
from torch.utils.data import DataLoader as Dalo
เพื่อความสะดวกจึงได้ย่อชื่อให้สั้นลงด้วย
X = torch.arange(8)[:,None]+torch.LongTensor([0,8])
# One label per row of X: 0, 10, 20, ..., 70.
z = 10 * torch.arange(8)
print(X)
print(z)
tensor([[ 0, 8],
[ 1, 9],
[ 2, 10],
[ 3, 11],
[ 4, 12],
[ 5, 13],
[ 6, 14],
[ 7, 15]])
tensor([ 0, 10, 20, 30, 40, 50, 60, 70])
# Pair X and z in a single dataset object.
data = Tenda(X, z)
print(data)  # prints <torch.utils.data.dataset.TensorDataset object at 0x000001D2623848B0>
แล้วก็นำมาป้อนให้ DataLoader โดยกำหนดขนาดมินิแบตช์ที่ต้องการ
# Wrap the dataset in a DataLoader that hands out mini-batches of 3 samples.
dalo = Dalo(data, batch_size=3)
print(dalo)  # prints <torch.utils.data.dataloader.DataLoader object at 0x000001D260FD7FA0>
# Iterating the loader yields one (inputs, labels) pair per mini-batch;
# with 8 samples the last batch holds only the remaining 2.
for Xb, zb in dalo:
    print(Xb, zb)
tensor([[ 0, 8],
[ 1, 9],
[ 2, 10]]) tensor([ 0, 10, 20])
tensor([[ 3, 11],
[ 4, 12],
[ 5, 13]]) tensor([30, 40, 50])
tensor([[ 6, 14],
[ 7, 15]]) tensor([60, 70])
# With shuffle enabled the sample order is redrawn on every pass over the
# loader, so the two passes below print differently composed batches.
dalo = Dalo(data, batch_size=3, shuffle=True)
for _ in range(2):
    for Xb, zb in dalo:
        print(Xb, zb)
    # Separator printed once per full pass over the data.
    print('----------')
ได้
tensor([[ 1, 9],
[ 5, 13],
[ 6, 14]]) tensor([10, 50, 60])
tensor([[ 0, 8],
[ 2, 10],
[ 3, 11]]) tensor([ 0, 20, 30])
tensor([[ 7, 15],
[ 4, 12]]) tensor([70, 40])
----------
tensor([[ 2, 10],
[ 5, 13],
[ 0, 8]]) tensor([20, 50, 0])
tensor([[ 4, 12],
[ 7, 15],
[ 6, 14]]) tensor([40, 70, 60])
tensor([[ 1, 9],
[ 3, 11]]) tensor([10, 30])
----------
def sang_minibatch(X, z, batch_size, shuffle=True):
    """Build a DataLoader that serves ``X`` and ``z`` in mini-batches.

    Args:
        X: Input tensor; its first dimension indexes the samples.
        z: Target tensor with the same first-dimension size as ``X``.
        batch_size: Number of samples per mini-batch.
        shuffle: Passed to the DataLoader; when true the sample order is
            redrawn on every pass over the loader.

    Returns:
        A DataLoader over ``TensorDataset(X, z)`` yielding ``(Xb, zb)`` pairs.
    """
    # Keyword arguments keep the call robust against positional mix-ups.
    return Dalo(Tenda(X, z), batch_size=batch_size, shuffle=shuffle)
# Activation and loss modules shared by the whole file. Neither holds
# trainable parameters, so a single instance of each can be reused.
relu = torch.nn.ReLU()
ha_entropy = torch.nn.CrossEntropyLoss()


class Prasat(torch.nn.Sequential):
    """Fully connected classifier with ReLU between its linear layers.

    The network is a ``torch.nn.Sequential`` whose children are named
    ``lin1``, ``relu1``, ``lin2``, ... and which carries its own Adam
    optimizer in ``self.opt``.
    """

    def __init__(self, m, eta=0.01):
        """Build the layer stack and its optimizer.

        Args:
            m: Sequence of layer widths; ``m[i-1]`` and ``m[i]`` are the
                input and output sizes of the i-th linear layer.
            eta: Learning rate handed to Adam.
        """
        super().__init__()
        n_linear = len(m) - 1
        for i in range(1, n_linear + 1):
            self.add_module(f'lin{i}', torch.nn.Linear(m[i - 1], m[i]))
            # No activation after the last linear layer: its raw outputs
            # are fed to the cross-entropy loss.
            if i < n_linear:
                self.add_module(f'relu{i}', relu)
        # Created after all layers exist so that every parameter is covered.
        self.opt = torch.optim.Adam(self.parameters(), lr=eta)

    def rianru(self, X, z, n_thamsam, n_batch=64):
        """Train the network with mini-batch gradient steps.

        Args:
            X: Input samples, anything ``torch.as_tensor`` accepts
                (for example a NumPy array).
            z: Integer class label for each sample.
            n_thamsam: Number of full passes over the data.
            n_batch: Mini-batch size.
        """
        X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        z = torch.as_tensor(z, dtype=torch.long)
        minibatch = sang_minibatch(X, z, n_batch)
        for _ in range(n_thamsam):
            for Xb, zb in minibatch:
                ha_entropy(self(Xb), zb).backward()
                self.opt.step()
                # Clear gradients so they do not accumulate into the next step.
                self.opt.zero_grad()

    def thamnai(self, X):
        """Return the predicted class index for every row of ``X``.

        Args:
            X: Input samples, anything ``torch.as_tensor`` accepts.

        Returns:
            NumPy array holding the argmax over dimension 1 of the
            network output.
        """
        X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            return self(X).argmax(1).numpy()
import numpy as np
import matplotlib.pyplot as plt
# Synthetic data: 6 classes of 120 points each. A point's radius is drawn
# around (class index + 1) and its angle uniformly from [-pi/2, pi/2).
z = np.repeat(np.arange(6), 120)
r = np.random.normal(z + 1, 0.2)
t = np.pi * np.random.uniform(-0.5, 0.5, z.size)
x, y = r * np.cos(t), r * np.sin(t)
X = np.array([x, y]).T

# Train a network with 2 inputs, two hidden layers and 6 outputs for 50 passes.
prasat = Prasat([2, 60, 40, 6], eta=0.01)
prasat.rianru(X, z, 50)

# Evaluate the trained network on a 200 x 200 grid covering the data range.
grid_x = np.linspace(x.min(), x.max(), 200)
grid_y = np.linspace(y.min(), y.max(), 200)
mx, my = np.meshgrid(grid_x, grid_y)
mX = np.array([mx.ravel(), my.ravel()]).T
mz = prasat.thamnai(mX).reshape(200, 200)

# Draw the predicted regions with the training points on top.
plt.xlim(x.min(), x.max())
plt.ylim(y.min(), y.max())
plt.contourf(mx, my, mz, alpha=0.2, cmap='terrain')
plt.scatter(x, y, c=z, edgecolor='k', cmap='terrain')
plt.show()
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ