import torch
# Shared ReLU activation module; it has no parameters, so the same object can be
# reused wherever an activation is needed (the models below reuse it too).
relu = torch.nn.ReLU()


class Khrongkhai(torch.nn.Module):
    """Three-layer fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        m0: number of input features.
        m1: width of the first hidden layer.
        m2: width of the second hidden layer.
        m3: number of outputs; the last layer has no activation, so these are
            raw scores (logits).
    """

    def __init__(self, m0: int, m1: int, m2: int, m3: int) -> None:
        # Zero-argument super() is the idiomatic Python 3 form.
        super().__init__()
        # Assigning Module instances as attributes registers them as submodules,
        # so their weights and biases are returned by self.parameters().
        self.lin1 = torch.nn.Linear(m0, m1)
        self.lin2 = torch.nn.Linear(m1, m2)
        self.lin3 = torch.nn.Linear(m2, m3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw outputs for ``x``, whose last dimension must equal ``m0``."""
        a1 = self.lin1(x)
        h1 = relu(a1)
        a2 = self.lin2(h1)
        h2 = relu(a2)
        a3 = self.lin3(h2)
        return a3
# torch.nn.Linear is itself a torch.nn.Module subclass.
print(issubclass(torch.nn.Linear, torch.nn.Module))  # prints True

# A tiny 1 -> 1 -> 1 -> 1 network: each of the three Linear layers contributes
# one weight tensor and one bias tensor, so six parameter tensors in total.
khrongkhai = Khrongkhai(1, 1, 1, 1)
print(sum(1 for _ in khrongkhai.parameters()))  # prints 6
def __init__(self, m0, m1, m2, m3):
    """Alternative Khrongkhai constructor that registers layers via ``add_module``.

    Builds the same three Linear layers (``lin1``: m0->m1, ``lin2``: m1->m2,
    ``lin3``: m2->m3) but registers each one explicitly by name instead of by
    attribute assignment. As written this is a module-level function;
    presumably it is meant to stand in for ``Khrongkhai.__init__`` -- confirm.
    """
    # Zero-argument super() only works inside a class body, so the explicit
    # two-argument form is required here.
    super(Khrongkhai, self).__init__()
    widths = [m0, m1, m2, m3]
    for name, n_in, n_out in zip(('lin1', 'lin2', 'lin3'), widths[:-1], widths[1:]):
        self.add_module(name, torch.nn.Linear(n_in, n_out))
import numpy as np
import matplotlib.pyplot as plt
# Synthetic 4-class dataset, 40 points per class (labels 0..3).
# Class k sits on a noisy arc with radius ~ N(k + 1, 0.25) and an angle drawn
# uniformly from [0, pi), i.e. concentric semicircles.
z = np.repeat(np.arange(4), 40)
r = np.random.normal(z + 1, 0.25)
t = np.random.uniform(0, np.pi, 160)
x, y = r * np.cos(t), r * np.sin(t)
X = np.column_stack((x, y))  # one (x, y) row per sample

plt.scatter(x, y, c=z, edgecolor='k', cmap='jet')
plt.show()
# Train Khrongkhai as a 4-class classifier on the data above.
# CrossEntropyLoss takes raw scores (logits) and integer class-index targets.
# It is constructed once; the original built it twice.
ha_entropy = torch.nn.CrossEntropyLoss()

# torch.tensor with an explicit dtype replaces the legacy torch.Tensor /
# torch.LongTensor constructors: float32 features, int64 class labels.
X = torch.tensor(X, dtype=torch.float32)
z = torch.tensor(z, dtype=torch.long)

khrongkhai = Khrongkhai(2, 60, 40, 4)
opt = torch.optim.Adam(khrongkhai.parameters(), lr=0.1)

# 200 full-batch optimisation steps.
for _ in range(200):
    a = khrongkhai(X)       # logits for every sample
    J = ha_entropy(a, z)    # loss against the class labels
    J.backward()
    opt.step()
    opt.zero_grad()         # clear gradients before the next backward pass
# Visualise the learned decision regions: classify every point of a 200x200
# grid spanning the data's bounding box and shade it by predicted class.
mx, my = np.meshgrid(np.linspace(x.min(), x.max(), 200),
                     np.linspace(y.min(), y.max(), 200))
# float32 matches the dtype of the nn.Linear parameters under default settings.
mX = torch.tensor(np.array([mx.ravel(), my.ravel()]).T, dtype=torch.float32)

# Inference only: no_grad avoids recording an autograd graph for 40,000 points.
with torch.no_grad():
    mz = khrongkhai(mX).argmax(1)  # index of the highest-scoring class
mz = mz.numpy().reshape(200, 200)

plt.xlim(x.min(), x.max())
plt.ylim(y.min(), y.max())
plt.scatter(x, y, c=z, edgecolor='k', cmap='jet')
plt.contourf(mx, my, mz, alpha=0.2, cmap='jet')
plt.show()
# The same 2 -> 60 -> 40 -> 4 architecture as Khrongkhai(2, 60, 40, 4), written
# as a torch.nn.Sequential container; the shared ReLU module sits between layers.
khrongkhai = torch.nn.Sequential(*[
    torch.nn.Linear(2, 60), relu,
    torch.nn.Linear(60, 40), relu,
    torch.nn.Linear(40, 4),
])
print(khrongkhai)
Sequential(
(0): Linear(in_features=2, out_features=60, bias=True)
(1): ReLU()
(2): Linear(in_features=60, out_features=40, bias=True)
(3): ReLU()
(4): Linear(in_features=40, out_features=4, bias=True)
)
# Sequential children can be fetched by integer position; index 0 is the first Linear layer.
print(khrongkhai[0]) # prints Linear(in_features=2, out_features=60, bias=True)
# The same network built incrementally: start from an empty Sequential and
# attach named children in order. Assigning a Module as an attribute registers
# it just like add_module, e.g. khrongkhai.add_module('lin1', torch.nn.Linear(2, 60)).
khrongkhai = torch.nn.Sequential()
khrongkhai.lin1 = torch.nn.Linear(2, 60)
khrongkhai.relu1 = relu
khrongkhai.lin2 = torch.nn.Linear(60, 40)
khrongkhai.relu2 = relu
khrongkhai.lin3 = torch.nn.Linear(40, 4)
print(khrongkhai)

# Named children are reachable as attributes; positional indexing still works.
print(khrongkhai.lin2)  # or khrongkhai[2]
Sequential(
(lin1): Linear(in_features=2, out_features=60, bias=True)
(relu1): ReLU()
(lin2): Linear(in_features=60, out_features=40, bias=True)
(relu2): ReLU()
(lin3): Linear(in_features=40, out_features=4, bias=True)
)
Linear(in_features=60, out_features=40, bias=True)
# Layers can also be grouped: each hidden block is its own Sequential
# (Linear + ReLU), and the outer Sequential chains the blocks with the output layer.
l1 = torch.nn.Sequential(
    torch.nn.Linear(2, 60),
    relu,
)
l2 = torch.nn.Sequential(
    torch.nn.Linear(60, 40),
    relu,
)
l3 = torch.nn.Linear(40, 4)

khrongkhai = torch.nn.Sequential(l1, l2, l3)
print(khrongkhai)
Sequential(
(0): Sequential(
(0): Linear(in_features=2, out_features=60, bias=True)
(1): ReLU()
)
(1): Sequential(
(0): Linear(in_features=60, out_features=40, bias=True)
(1): ReLU()
)
(2): Linear(in_features=40, out_features=4, bias=True)
)
class Flatten(torch.nn.Module):
    """Module that collapses its input into a 1-D tensor.

    Every dimension is flattened, including any leading batch dimension, so an
    input of shape (N, 1) becomes shape (N,).
    """

    def forward(self, x):
        return torch.flatten(x)
# Flattening stage used at the end of the binary classifier below.
# An instance of the Flatten class is kept. The original immediately discarded
# it in favour of a bare torch.nn.Module with ``forward`` patched to a lambda.
# That object prints as an uninformative "Module()", and the lambda in its
# __dict__ cannot be pickled, so torch.save(model) would fail.
flat = Flatten()
# Binary classifier: a 2 -> 80 -> 50 -> 1 MLP with ReLU activations. The final
# ``flat`` stage collapses the (N, 1) output into shape (N,), presumably to match
# the 1-D label tensor used with the loss below.
khrongkhai = torch.nn.Sequential(*[
    torch.nn.Linear(2, 80), relu,
    torch.nn.Linear(80, 50), relu,
    torch.nn.Linear(50, 1), flat,
])
# Synthetic 2-class dataset, 120 points per class (labels 0 and 1).
# Class k lies on a noisy arc with radius ~ N(2k + 2, 0.5) and an angle drawn
# uniformly from [-pi/2, pi/2), i.e. two concentric half-rings.
z = np.repeat(np.arange(2), 120)
r = np.random.normal(2 * z + 2, 0.5)
t = np.pi * np.random.uniform(-0.5, 0.5, 240)
x = r * np.cos(t)
y = r * np.sin(t)

# Features and labels as default-dtype (float) tensors; the labels are floats
# because they are used as binary-loss targets below.
X = torch.Tensor(np.column_stack((x, y)))
z = torch.Tensor(z)
# Train the binary classifier with 200 full-batch Adam steps (lr = 0.1).
# BCEWithLogitsLoss applies the sigmoid internally, so the model emits raw
# logits; it expects z to hold float targets in [0, 1].
opt = torch.optim.Adam(khrongkhai.parameters(), lr=0.1)
ha_entropy = torch.nn.BCEWithLogitsLoss()

for i in range(200):
    a = khrongkhai(X)       # one logit per sample
    J = ha_entropy(a, z)    # binary cross-entropy against the labels
    J.backward()
    opt.step()
    opt.zero_grad()         # reset gradients for the next iteration
# Visualise the binary decision regions on a 200x200 grid spanning the data.
mx, my = np.meshgrid(np.linspace(x.min(), x.max(), 200),
                     np.linspace(y.min(), y.max(), 200))
# float32 matches the dtype of the nn.Linear parameters under default settings.
mX = torch.tensor(np.array([mx.ravel(), my.ravel()]).T, dtype=torch.float32)

# Inference only: no_grad avoids recording an autograd graph for 40,000 points.
with torch.no_grad():
    # A positive logit means sigmoid(logit) > 0.5, i.e. class 1 is predicted.
    mz = khrongkhai(mX) > 0
mz = mz.numpy().reshape(200, 200)

plt.xlim(x.min(), x.max())
plt.ylim(y.min(), y.max())
plt.scatter(x, y, c=z.numpy(), edgecolor='k', cmap='Paired')
plt.contourf(mx, my, mz, alpha=0.2, cmap='Paired')
plt.show()
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ