import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# 100 points drawn from 4 Gaussian blobs; z holds each point's blob label.
X, z = datasets.make_blobs(n_samples=100, centers=4, cluster_std=0.8, random_state=0)
# Equal-aspect axes so distances look the same along x and y.
ax = plt.axes(aspect=1)
ax.scatter(X[:, 0], X[:, 1], c=z, s=30, edgecolor='k', cmap='rainbow')
plt.show()
from sklearn.neighbors import KNeighborsClassifier as Knn

# Fit a k-nearest-neighbours classifier with scikit-learn's default settings.
knn = Knn()
knn.fit(X, z)

# Predict a label for every node of an nmesh x nmesh grid spanning the data's
# bounding box, so the predicted class regions can be drawn.
nmesh = 200
x_lo, x_hi = X[:, 0].min(), X[:, 0].max()
y_lo, y_hi = X[:, 1].min(), X[:, 1].max()
mx, my = np.meshgrid(np.linspace(x_lo, x_hi, nmesh),
                     np.linspace(y_lo, y_hi, nmesh))
mX = np.stack([mx.ravel(), my.ravel()], axis=1)  # (nmesh*nmesh, 2) query points
mz = knn.predict(mX).reshape(nmesh, nmesh)

plt.axes(xlim=[x_lo, x_hi], ylim=[y_lo, y_hi], aspect=1)
# Faint fill of each predicted region, then dark contour lines of the
# predicted-label grid to outline the regions.
plt.contourf(mx, my, mz, alpha=0.1, cmap='rainbow')
plt.contour(mx, my, mz, colors='#222222')
plt.scatter(X[:, 0], X[:, 1], c=z, edgecolor='k', cmap='rainbow')
plt.show()
# One-nearest-neighbour classifier with Minkowski power p=1 (Manhattan / L1
# distance under the default metric). Writing Knn(1, p=1) is equivalent,
# because n_neighbors is the first positional parameter.
knn = Knn(p=1, n_neighbors=1)
import time

# Compare the wall-clock cost of fit + predict for each neighbour-search
# algorithm on a larger 10-feature, 4-cluster dataset.
X, z = datasets.make_blobs(n_samples=10000, n_features=10, centers=4, random_state=0)
for al in ['ball_tree', 'kd_tree', 'brute', 'auto']:
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    t1 = time.perf_counter()
    knn = Knn(algorithm=al)
    knn.fit(X, z)
    knn.predict(X)
    print('%s: %.3f วินาที' % (al, time.perf_counter() - t1))
ball_tree: 0.670 วินาที
kd_tree: 0.585 วินาที
brute: 3.151 วินาที
auto: 0.585 วินาที
# Measure fit + predict time for different n_jobs values (scikit-learn's
# parallel job count for the neighbour search) on a 20-feature dataset.
X, z = datasets.make_blobs(n_samples=20000, n_features=20, random_state=0)
for j in [1, 2, 3]:
    # Monotonic interval clock; unaffected by system clock adjustments.
    t1 = time.perf_counter()
    Knn(n_jobs=j).fit(X, z).predict(X)  # fit() returns the estimator, so calls chain
    print('n_jobs=%s: %.3f วินาที' % (j, time.perf_counter() - t1))
n_jobs=1: 8.393 วินาที
n_jobs=2: 4.825 วินาที
n_jobs=3: 4.235 วินาที
# Tiny 2-cluster dataset (8 points) so the neighbour tables can be read by eye.
X, z = datasets.make_blobs(n_samples=8, centers=2, random_state=5)
knn = Knn(n_neighbors=3)
knn.fit(X, z)
plt.axes(aspect=1)
plt.scatter(X[:, 0], X[:, 1], c=z, edgecolor='k', cmap='summer')
plt.show()

# kneighbors returns (distances, indices), one row per query point and one
# column per neighbour. The queries here are the training points themselves,
# so each point's first neighbour is itself at distance 0.
k = knn.kneighbors(X)
print(k[0])
print(k[1])
# Print "index > distance" pairs per point. Iterating the arrays directly,
# instead of hard-coding 8 rows and 3 columns, keeps this correct if
# n_samples or n_neighbors change.
for dists, idxs in zip(k[0], k[1]):
    print(', '.join('%d > %.2f' % (idx, dist) for idx, dist in zip(idxs, dists)))
ได้ผลลัพธ์ดังนี้
[[ 0. 1.12763922 1.91394086]
[ 0. 0.68420112 1.16141094]
[ 0. 1.43561929 1.9628876 ]
[ 0. 0.77421595 1.16141094]
[ 0. 2.24451596 2.95862282]
[ 0. 1.12763922 1.35891489]
[ 0. 0.77421595 1.43561929]
[ 0. 0.68420112 1.35891489]]
[[0 5 6]
[1 7 3]
[2 6 3]
[3 6 1]
[4 0 6]
[5 0 7]
[6 3 2]
[7 1 5]]
0 > 0.00, 5 > 1.13, 6 > 1.91
1 > 0.00, 7 > 0.68, 3 > 1.16
2 > 0.00, 6 > 1.44, 3 > 1.96
3 > 0.00, 6 > 0.77, 1 > 1.16
4 > 0.00, 0 > 2.24, 6 > 2.96
5 > 0.00, 0 > 1.13, 7 > 1.36
6 > 0.00, 3 > 0.77, 2 > 1.44
7 > 0.00, 1 > 0.68, 5 > 1.36
print(knn.kneighbors(X,n_neighbors=8,return_distance=0))
ได้ผลลัพธ์ดังนี้
[[0 5 6 1 3 7 4 2]
[1 7 3 5 6 0 2 4]
[2 6 3 1 0 4 7 5]
[3 6 1 7 2 0 5 4]
[4 0 6 5 2 3 1 7]
[5 0 7 1 3 6 4 2]
[6 3 2 1 0 5 7 4]
[7 1 5 3 0 6 2 4]]
graph = knn.kneighbors_graph(X)
print(type(graph))  # show the object's type
print(graph.toarray())  # convert to an ordinary dense array
<class 'scipy.sparse.csr.csr_matrix'>
[[ 1. 0. 0. 0. 0. 1. 1. 0.]
[ 0. 1. 0. 1. 0. 0. 0. 1.]
[ 0. 0. 1. 1. 0. 0. 1. 0.]
[ 0. 1. 0. 1. 0. 0. 1. 0.]
[ 1. 0. 0. 0. 1. 0. 1. 0.]
[ 1. 0. 0. 0. 0. 1. 0. 1.]
[ 0. 0. 1. 1. 0. 0. 1. 0.]
[ 0. 1. 0. 0. 0. 1. 0. 1.]]
print(knn.kneighbors_graph(X,mode='distance',n_neighbors=8).toarray())
[[ 0. 2.00350493 3.21293728 2.07395977 2.24451596 1.12763922
1.91394086 2.23635633]
[ 2.00350493 0. 3.12214116 1.16141094 4.05546124 1.46770732
1.80934825 0.68420112]
[ 3.21293728 3.12214116 0. 1.9628876 3.45228072 3.83210647
1.43561929 3.80417203]
[ 2.07395977 1.16141094 1.9628876 0. 3.58705789 2.16656867
0.77421595 1.84527381]
[ 2.24451596 4.05546124 3.45228072 3.58705789 0. 3.35721827
2.95862282 4.43043208]
[ 1.12763922 1.46770732 3.83210647 2.16656867 3.35721827 0.
2.40100554 1.35891489]
[ 1.91394086 1.80934825 1.43561929 0.77421595 2.95862282 2.40100554
0. 2.46612122]
[ 2.23635633 0.68420112 3.80417203 1.84527381 4.43043208 1.35891489
2.46612122 0. ]]
# Two overlapping clusters: compare a small k (3) with a large k (30).
X, z = datasets.make_blobs(n_samples=200, centers=2, cluster_std=2.5, random_state=12)
nmesh = 200
x_lo, x_hi = X[:, 0].min(), X[:, 0].max()
y_lo, y_hi = X[:, 1].min(), X[:, 1].max()
mx, my = np.meshgrid(np.linspace(x_lo, x_hi, nmesh),
                     np.linspace(y_lo, y_hi, nmesh))
mX = np.stack([mx.ravel(), my.ravel()], axis=1)  # (nmesh*nmesh, 2) query points

# 2x2 grid: column i selects k; the top row (j=0) shows hard class predictions,
# the bottom row (j=1) shows the predicted probability of class 1.
for i in [0, 1]:
    n = 3 + 27 * i  # k = 3, then 30
    knn = Knn(n)
    knn.fit(X, z)
    for j in [0, 1]:
        if j == 1:
            mz = knn.predict_proba(mX)[:, 1].reshape(nmesh, nmesh)
        else:
            mz = knn.predict(mX).reshape(nmesh, nmesh)
        plt.subplot(2, 2, 1 + i + 2 * j, xlim=[x_lo, x_hi], ylim=[y_lo, y_hi], aspect=1)
        plt.scatter(X[:, 0], X[:, 1], 10, c=z, edgecolor='k', cmap='winter')
        # zorder=0 keeps the filled surface beneath the scatter points.
        plt.contourf(mx, my, mz, 100, cmap='winter', zorder=0)
        if j == 0:
            plt.title('n=%d' % n)
        else:
            plt.ylabel('proba')
plt.show()
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ