import numpy as np
import matplotlib.pyplot as plt
X = np.random.normal(2,4,[1000,2])
z = ((X[:,0]+X[:,1])<8)*2-((X[:,0]+X[:,1])<0)
X[:,0] *= 2
plt.axes(aspect=1)
plt.scatter(X[:,0],X[:,1],c=z,edgecolor='k',cmap='inferno',alpha=0.5)
plt.show()
from sklearn.preprocessing import StandardScaler as StaSke
X = StaSke().fit_transform(X) # ทำให้เป็นมาตรฐาน
praeruam_naiklum = [] # ความแปรปรวนร่วมเกี่ยวภายในแต่ละกลุ่ม
praeruam_rawangklum = [] # ความแปรปรวนร่วมเกี่ยวระหว่างกลุ่ม
for i in range(max(z)+1):
X_i = X[z==i] # ข้อมุลในแต่ละกลุ่ม
n_naiklum = X_i.shape[0] # จำนวนข้อมุลในแต่ละกลุ่ม
chalia_naiklum = X_i.mean(0) # ค่าเฉลี่ยภายในกลุ่ม
praeruam_naiklum.append(np.cov(X_i.T))
praeruam_rawangklum.append(n_naiklum*chalia_naiklum[:,None]*chalia_naiklum)
SW = np.sum(praeruam_naiklum,0)
SB = np.sum(praeruam_rawangklum,0)
praeruam = np.linalg.inv(SW).dot(SB)
kha_eig,vec_eig = np.linalg.eigh(praeruam) # หาเวกเตอร์และค่าลักษณะเฉพาะ
Xi = X.dot(vec_eig[:,::-1]) # คำนวณค่าในพิกัดใหม่
plt.axes(aspect=1)
plt.scatter(Xi[:,0],Xi[:,1],c=z,edgecolor='k',cmap='inferno',alpha=0.5)
plt.show()
class WikhroKanchamnaekPraphetChoengsen:
def rianru(self,X,z,sta=1):
if(sta):
self.staske = StaSke()
X = self.staske.fit_transform(X)
SW,SB = [],[]
for i in range(max(z)+1):
X_i = X[z==i]
m_i = X_i.mean(0)
SW.append(np.cov(X_i.T))
SB.append(X_i.shape[0]*m_i[:,None]*m_i)
SW = np.sum(SW,0)
SB = np.sum(SB,0)
kha_eig,vec_eig = np.linalg.eigh(np.linalg.inv(SW).dot(SB))
self.V = vec_eig[:,::-1]
self.a = kha_eig[::-1]/kha_eig.sum()
def plaeng(self,X,sta=1):
if(sta):
X = self.staske.transform(X)
return X.dot(self.V)
def rianru_plaeng(self,X,z,sta=1):
if(sta):
self.staske = StaSke()
X = self.staske.fit_transform(X)
self.rianru(X,z,sta=0)
return self.plaeng(X,sta=0)
X = (np.random.normal(0,0.4,[200,2,7]) + np.arange(7)).T.reshape(1400,2)
z = np.tile(np.arange(7),[200,1]).T.ravel()
plt.axes(aspect=1) # ระบบพิกัดเก่า
plt.scatter(X[:,0],X[:,1],c=z,edgecolor='k',cmap='rainbow',alpha=0.5)
Xi = WikhroKanchamnaekPraphetChoengsen().rianru_plaeng(X,z) # แปลงพิกัด
plt.figure()
plt.axes(aspect=1) # ระบบพิกัดใหม่
plt.scatter(Xi[:,0],Xi[:,1],c=z,edgecolor='k',cmap='rainbow',alpha=0.5)
plt.show()
ก่อนแปลงfrom sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn import datasets
X,z = datasets.make_blobs(1250,n_features=10,centers=5,cluster_std=2,random_state=5)
lda = LDA(n_components=2)
Xi = lda.fit_transform(X,z)
plt.axes(aspect=1)
plt.scatter(Xi[:,0],Xi[:,1],c=z,edgecolor='k',cmap='Spectral',alpha=0.4)
plt.show()
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ