
import numpy as np
class Phueanban:
    def __init__(self,nk=5):
        self.nk = nk # จำนวนเพื่อนบ้านที่จะพิจารณา
    def rianru(self,X,z):
        self.X = X # เก็บข้อมูลตำแหน่ง
        self.z = z # เก็บข้อมูลการแบ่งกลุ่ม
        self.n_klum = z.max()+1 # จำนวนกลุ่ม
    def thamnai(self,X):
        n = len(X) # จำนวนข้อมูลที่จะคำนวณหา
        n_batch = int(np.ceil(1000000./X.shape[1]/len(self.X)))
        z = np.empty(n,dtype=int)
        for c in range(0,n,n_batch):
            Xn = X[c:c+n_batch]
            n_Xn = len(Xn)
            raya2 = ((Xn[None]-self.X[:,None])**2).sum(2)
            klum_thi_klai = self.z[raya2.argsort(0)]
            n_nai_klum = np.stack([(klum_thi_klai[:self.nk]==k).sum(0) for k in range(self.n_klum)])
            mi_maksut = n_nai_klum.max(0)
            maksutmai = (n_nai_klum==mi_maksut)
            for i in range(n_Xn):
                for j in range(self.nk):
                    k = klum_thi_klai[j,i]
                    if(maksutmai[k,i]):
                        z[i+c] = k
                        break
        return zfrom sklearn import datasets
from sklearn.model_selection import train_test_split
mnist = datasets.fetch_openml('mnist_784')
X,z = mnist.data.astype(float),mnist.target.astype(int)
np.random.seed(0)
X_fuek,X_truat,z_fuek,z_truat = train_test_split(X,z,test_size=200/70000)
pb = Phueanban(nk=1)
pb.rianru(X_fuek,z_fuek)
zz = pb.thamnai(X_fuek[:200]) # ทำนายข้อมูลฝึก ดึงมาแค่ 200 ตัว ให้เท่ากับข้อมูลตรวจสอบ
print((zz==z_fuek[:200]).mean()) # ได้ 1.0
zz = pb.thamnai(X_truat) # ทำนายข้อมูลตรวจสอบ
print((zz==z_truat).mean()) # ได้ 0.975import time
import matplotlib.pyplot as plt
pb = Phueanban()
pb.rianru(X_fuek,z_fuek)
maen_fuek = []
maen_truat = []
t1 = time.time()
for nk in range(1,31):
    pb.nk = nk
    zz = pb.thamnai(X_fuek[:200])
    maen_fuek.append((zz==z_fuek[:200]).mean())
    zz = pb.thamnai(X_truat)
    maen_truat.append((zz==z_truat).mean())
    # เนื่องจากใช้เวลานาน ให้จับเวลาแล้วแสดงความคืบหน้าไปด้วยเพื่อไม่ให้เคว้ง
    print(u'รอบที่ %d เวลาผ่านไปแล้ว %.2f วินาที'%(nk,time.time()-t1))
plt.plot(np.arange(1,31),maen_fuek,'#771133')
plt.plot(np.arange(1,31),maen_truat,'#117733')
plt.xlabel(u'จำนวนเพื่อนบ้าน',family='Tahoma')
plt.ylabel(u'ความแม่นยำ',family='Tahoma')
plt.legend([u'ฝึกฝน',u'ตรวจสอบ'],prop={'family':'Tahoma'})
plt.show()
from sklearn.neighbors import KNeighborsClassifier as Knn
knn = Knn(n_jobs=-1)
knn.fit(X_fuek,z_fuek)
maen_fuek = []
maen_truat = []
for i in range(1,31):
    knn.set_params(n_neighbors=i)
    maen_fuek.append(knn.score(X_fuek[:200],z_fuek[:200]))
    maen_truat.append(knn.score(X_truat,z_truat))ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ