from torch.utils.data import SubsetRandomSampler as Sarasa

# Demonstrate SubsetRandomSampler: it yields exactly the given indices, but in
# a fresh random order on every pass, so the two printed lists usually differ.
srs = Sarasa([1, 3, 5, 9])
print(list(srs))
print(list(srs))
# Example output from one run:
# [3, 5, 9, 1]
# [9, 5, 1, 3]
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as ds
import torchvision.transforms as tf
from torch.utils.data import DataLoader as Dalo
data = ds.MNIST('~/pytorchdata/mnist', transform=tf.ToTensor(), download=True)

# Draw two strips of 16 MNIST images each: first the even dataset indices
# 0, 2, ..., 30, then the odd ones 1, 3, ..., 31. Each sampler visits its
# indices in a random order, so the images are shuffled within each strip.
for start in (0, 1):
    sampler = Sarasa(np.arange(start, 32, 2))
    plt.figure(figsize=[6.4, 0.4])
    rup = Dalo(data, batch_size=16, sampler=sampler)
    # The 16 sampled indices fit in exactly one batch of 16, so only the first
    # batch is needed; no need to materialise the whole loader into a list.
    imgs, _labels = next(iter(rup))
    # Drop the channel axis (index 0) and lay the images side by side.
    plt.axes([0, 0, 1, 1]).imshow(np.hstack(imgs[:, 0]), cmap='gray')
plt.show()
def triam_rup(folder_rup, satsuan_truat, n_batch_fuek, n_batch_truat):
    """Build training and validation DataLoaders over one ImageFolder tree.

    Both loaders read the same folder. One random permutation of the sample
    indices is split so that the first ``int(n * satsuan_truat)`` indices form
    the validation set and the rest form the training set; the two never overlap.

    Args:
        folder_rup: root folder containing one sub-folder of images per class.
        satsuan_truat: fraction of all samples held out for validation.
        n_batch_fuek: batch size of the training loader.
        n_batch_truat: batch size of the validation loader.

    Returns:
        ``(training_loader, validation_loader)``.
    """
    mean_std = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Training images get random flips: the shapes are free-form, so an
    # up-down or left-right flip is still a valid example of the same class.
    augmenting = tf.Compose([
        tf.RandomVerticalFlip(),
        tf.RandomHorizontalFlip(),
        tf.ToTensor(),
        tf.Normalize(*mean_std),
    ])
    plain = tf.Compose([tf.ToTensor(), tf.Normalize(*mean_std)])

    train_set = ds.ImageFolder(folder_rup, transform=augmenting)
    valid_set = ds.ImageFolder(folder_rup, transform=plain)

    total = len(train_set)  # number of images in the whole folder tree
    n_valid = int(total * satsuan_truat)  # size of the validation split
    order = np.random.permutation(total)  # random order used for the split
    valid_idx, train_idx = order[:n_valid], order[n_valid:]

    train_loader = Dalo(train_set, batch_size=n_batch_fuek, sampler=Sarasa(train_idx))
    valid_loader = Dalo(valid_set, batch_size=n_batch_truat, sampler=Sarasa(valid_idx))
    return train_loader, valid_loader
# Root folder containing one sub-folder of images per class.
# NOTE(review): the name presumably encodes 25x25 images, 1000 per class, 6 classes.
folder_rup = 'ruprang-misi-25x25x1000x6'
# Hold out 20% for validation; training batches of 64.
rup_fuek, rup_truat = triam_rup(folder_rup, 0.2, 64, 100000)
# The validation batch size (100000) is presumably chosen to exceed the size
# of the validation split, so the loader yields exactly one batch: the whole
# validation set as an (X, z) pair. Take that first batch directly instead of
# materialising every batch into a list and discarding all but the first.
rup_truat = next(iter(rup_truat))
import torch
import time
# Parameter-free module instances, shared by reference wherever they are used.
relu, maxp, ha_entropy = (
    torch.nn.ReLU(),              # element-wise max(0, x)
    torch.nn.MaxPool2d(2),        # max pooling with a 2x2 window
    torch.nn.CrossEntropyLoss(),  # classification loss on logits vs. class indices
)
class Plianrup(torch.nn.Module):
    """Reshape layer: keep the leading (batch) dimension and reshape the rest.

    ``Plianrup(-1)`` flattens every sample to a 1-D vector, which lets a
    convolutional stack feed a ``Linear`` layer inside ``nn.Sequential``.

    Args:
        *k: target per-sample shape (batch dimension excluded); ``-1`` may be
            used once to let PyTorch infer that dimension.
    """

    def __init__(self, *k):
        super().__init__()
        self.k = k  # per-sample target shape, stored as a tuple

    def forward(self, x):
        batch = x.size(0)
        return x.reshape(batch, *self.k)
class Prasat(torch.nn.Sequential):
    """Small CNN classifier with 6 outputs that owns its optimizer and training loop.

    Layers: two [Conv2d(4x4) -> BatchNorm2d -> ReLU -> MaxPool(2)] stages on a
    3-channel input, a flatten, a 32-unit hidden layer with BatchNorm1d and
    ReLU, and a final Linear layer producing 6 raw scores (logits).

    ``Linear(16*4*4, 32)`` requires the feature map after the second pooling
    to be 4x4, which holds for inputs between 25 and 28 pixels per side.

    Args:
        eta: Adam learning rate.
    """

    def __init__(self, eta=0.001):
        super().__init__(
            torch.nn.Conv2d(3, 16, 4, 1, 0),
            torch.nn.BatchNorm2d(16),
            relu,
            maxp,
            torch.nn.Conv2d(16, 16, 4, 1, 0),
            torch.nn.BatchNorm2d(16),
            relu,
            maxp,
            Plianrup(-1),  # flatten each sample for the fully connected part
            torch.nn.Linear(16 * 4 * 4, 32),
            torch.nn.BatchNorm1d(32),
            relu,
            torch.nn.Linear(32, 6),
        )
        self.opt = torch.optim.Adam(self.parameters(), lr=eta)

    def rianru(self, rup_fuek, rup_truat, n_thamsam=200, ro=10):
        """Train the network, recording training/validation accuracy per epoch.

        Args:
            rup_fuek: iterable of ``(X, z)`` training mini-batches; it is
                iterated once per epoch.
            rup_truat: a single ``(X, z)`` pair holding the validation data.
            n_thamsam: maximum number of epochs.
            ro: early-stopping patience. Training stops once validation
                accuracy has failed to exceed its best value for ``ro``
                consecutive epochs; ``ro <= 0`` disables early stopping.

        Side effects: resets and fills ``self.khanaen_fuek`` and
        ``self.khanaen_truat`` with per-epoch accuracies, prints one progress
        line per epoch, and leaves the model in eval mode.
        """
        X_truat, z_truat = rup_truat
        self.khanaen_truat = []
        self.khanaen_fuek = []
        khanaen_sungsut = 0  # best validation accuracy seen so far
        # Epochs since the last improvement. It is initialised before the loop
        # so that ``maiphoem += 1`` is valid even when the very first epoch does
        # not beat 0 (e.g. 0% validation accuracy).
        maiphoem = 0
        t_roem = time.time()
        for o in range(n_thamsam):
            self.train()
            X_fuek, z_fuek = [], []
            for Xb, zb in rup_fuek:
                a = self(Xb)
                J = ha_entropy(a, zb)
                J.backward()
                self.opt.step()
                self.opt.zero_grad()
                # Keep this epoch's batches so training accuracy is measured on
                # exactly the samples seen during the epoch.
                X_fuek.append(Xb)
                z_fuek.append(zb)
            X_fuek, z_fuek = torch.cat(X_fuek), torch.cat(z_fuek)
            # Eval mode so the BatchNorm layers use their running statistics
            # while scoring.
            self.eval()
            khanaen_fuek = self.ha_khanaen_(X_fuek, z_fuek)
            khanaen_truat = self.ha_khanaen_(X_truat, z_truat)
            self.khanaen_fuek.append(khanaen_fuek)
            self.khanaen_truat.append(khanaen_truat)
            print('%d ครั้งผ่านไป ใช้เวลาไป %.1f นาที ทำนายข้อมูลฝึกแม่น %.4f ข้อมูลตรวจสอบแม่น %.4f'%(o+1,(time.time()-t_roem)/60,khanaen_fuek,khanaen_truat))
            if khanaen_truat > khanaen_sungsut:
                khanaen_sungsut = khanaen_truat
                maiphoem = 0
            else:
                maiphoem += 1
                if ro > 0 and maiphoem >= ro:
                    break

    def thamnai_(self, X):
        """Return the predicted class index (argmax over the outputs) for each row of X."""
        # Prediction never needs gradients. Skipping autograd avoids keeping
        # every intermediate activation alive for these large whole-set passes.
        with torch.no_grad():
            return self(X).argmax(1)

    def ha_khanaen_(self, X, z):
        """Return the fraction of rows of X whose predicted class equals z (NumPy float)."""
        return (self.thamnai_(X) == z).numpy().mean()
# Train a fresh network (with early stopping on validation accuracy), then
# plot its per-epoch training and validation accuracy curves.
prasat = Prasat()
prasat.rianru(rup_fuek, rup_truat)

for khanaen, sii in ((prasat.khanaen_fuek, '#aa33bb'),
                     (prasat.khanaen_truat, '#33bb22')):
    plt.plot(khanaen, sii)
# Legend labels read "training data" / "validation data" in Thai; the Tahoma
# font is presumably chosen for its Thai glyph coverage.
plt.legend([u'ข้อมูลฝึก', u'ข้อมูลตรวจสอบ'], prop={'family': 'Tahoma', 'size': 14})
plt.show()
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ