import numpy as np
import matplotlib.pyplot as plt
from unagi import Affin,Sigmoid,Sigmoid_entropy,Adam,Relu
class Prasat:
    """Fully connected binary classifier assembled from `unagi` layers.

    ``self.chan`` holds the layers in order: ``Affin`` layers separated by an
    activation layer (``Relu`` when ``kratun == 'relu'``, ``Sigmoid`` for any
    other value), ending with a single output unit followed by a
    ``Sigmoid_entropy`` loss layer. All layer parameters are updated by ``Adam``.
    """

    def __init__(self, m, eta, kratun='relu'):
        """Build the layer stack and its optimizer.

        Parameters
        ----------
        m : sequence of int
            Input width followed by hidden-layer widths. A final width of 1 is
            appended internally; ``self.m`` holds the extended list.
        eta : float
            Forwarded to ``Adam`` as ``eta`` (presumably the learning rate).
        kratun : str
            ``'relu'`` selects ``Relu`` hidden activations; any other value
            selects ``Sigmoid``.
        """
        # Build a new list instead of m.append(1): appending mutated the
        # caller's list, so reusing it for a second network silently added an
        # extra 1-unit layer each time (and a tuple argument crashed).
        self.m = list(m) + [1]
        self.chan = []
        n_affin = len(self.m) - 1
        for i, (m_khao, m_ok) in enumerate(zip(self.m[:-1], self.m[1:])):
            # Third argument sqrt(2 / fan_in) looks like a He-style init scale.
            # NOTE(review): confirm its meaning against unagi.Affin.
            self.chan.append(Affin(m_khao, m_ok, np.sqrt(2. / m_khao)))
            if i < n_affin - 1:  # no activation after the output layer
                self.chan.append(Relu() if kratun == 'relu' else Sigmoid())
        self.chan.append(Sigmoid_entropy())
        self.opt = Adam(self.param(), eta=eta)

    def rianru(self, X, z, n_thamsam, n_batch=50):
        """Train the network with shuffled mini-batches.

        Parameters
        ----------
        X : array
            Input samples, indexed along the first axis.
        z : array
            One label per sample. Labels are compared against a boolean
            prediction when scoring, so 0/1 (or bool) values are expected.
        n_thamsam : int
            Number of passes over the whole data set.
        n_batch : int
            Mini-batch size; the last batch of each pass may be smaller.

        ``self.entropy`` and ``self.khanaen`` are reset on every call. After
        each optimizer step they receive, respectively, the loss object's
        ``.kha`` and the accuracy on that same mini-batch.
        """
        n = len(z)
        self.entropy = []
        self.khanaen = []
        for _ in range(n_thamsam):
            lueak = np.random.permutation(n)  # fresh sample order every pass
            for i in range(0, n, n_batch):
                idx = lueak[i:i + n_batch]
                Xb, zb = X[idx], z[idx]
                entropy = self.ha_entropy(Xb, zb)
                entropy.phraeyon()  # presumably back-propagates gradients
                self.opt()          # one optimizer step
                # Re-evaluate the same batch after the update so the recorded
                # curves reflect the post-step parameters.
                entropy, khanaen = self.ha_entropy(Xb, zb, ao_khanaen=1)
                self.entropy.append(entropy.kha)
                self.khanaen.append(khanaen)

    def _pai_na(self, X):
        """Pass X through every layer except the final loss layer."""
        for c in self.chan[:-1]:
            X = c(X)
        return X

    def ha_entropy(self, X, z, ao_khanaen=0):
        """Return the loss-layer output for inputs X and labels z.

        If ``ao_khanaen`` is truthy, return ``(loss, accuracy)`` instead.
        Accuracy is the fraction of samples whose prediction
        (final-layer ``kha`` >= 0) equals ``z``.
        """
        X = self._pai_na(X)
        entropy = self.chan[-1](X, z)
        if ao_khanaen:
            return entropy, ((X.kha >= 0).flatten() == z).mean()
        return entropy

    def param(self):
        """Collect the ``param`` entries of every layer that has them, in layer order."""
        return [p for c in self.chan if hasattr(c, 'param') for p in c.param]

    def thamnai(self, X):
        """Predict 0/1 labels for X.

        A sample is labelled 1 when the final layer's ``kha`` is >= 0
        (presumably a logit, i.e. sigmoid probability >= 0.5).
        """
        return (self._pai_na(X).kha >= 0).flatten().astype(int)
# --- Demo data: two interleaved spirals, one per class ---------------------
np.random.seed(7)  # fixed seed so data and weight init are reproducible
n_chut = 2000      # points per class
r = np.tile(np.sqrt(np.linspace(0.5, 25, n_chut)), 2)
t = np.random.normal(np.sqrt(r * 50), 0.5)
z = np.arange(2).repeat(n_chut)
t += z * np.pi     # class 1 is the class-0 spiral rotated by half a turn
X = np.array([r * np.cos(t), r * np.sin(t)]).T
plt.scatter(X[:, 0], X[:, 1], 50, c=z, alpha=0.1, edgecolor='k', cmap='RdYlGn')
plt.show()

# --- Train: one hidden layer of 70 units ------------------------------------
prasat = Prasat(m=[2, 70], eta=0.005)
prasat.rianru(X, z, n_thamsam=100, n_batch=50)

# Training curves, one point per mini-batch: loss on top, accuracy below.
plt.subplot(211, xticks=[])
plt.plot(prasat.entropy, '#772277')
plt.title(u'เอนโทรปี', family='Tahoma', size=12)  # "Entropy"; Tahoma presumably chosen to render Thai
plt.subplot(212)
plt.plot(prasat.khanaen, '#227777')
plt.title(u'คะแนน', family='Tahoma', size=12)  # "Score"

# --- Decision regions on a regular grid covering the data -------------------
plt.figure()
n_grid = 200
x_min, x_max = X[:, 0].min(), X[:, 0].max()
y_min, y_max = X[:, 1].min(), X[:, 1].max()
mx, my = np.meshgrid(np.linspace(x_min, x_max, n_grid),
                     np.linspace(y_min, y_max, n_grid))
mX = np.array([mx.ravel(), my.ravel()]).T
# Reshape to the grid's own shape rather than a repeated literal 200.
mz = prasat.thamnai(mX).reshape(mx.shape)
plt.axes(aspect=1, xlim=(x_min, x_max), ylim=(y_min, y_max))
plt.contourf(mx, my, mz, cmap='RdYlGn', alpha=0.2)
plt.scatter(X[:, 0], X[:, 1], 20, c=z, alpha=0.5, edgecolor='k', cmap='RdYlGn')
plt.show()
# ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ (Follow this blog's updates on its fan page.)