φυβλαςのβλογ
phyblas的博客



วิธีการเพื่อนบ้านใกล้สุด k ตัวโดยใช้ sklearn
เขียนเมื่อ 2017/10/31 18:27
แก้ไขล่าสุด 2022/07/21 14:59
หลังจากที่ได้แนะนำการเขียนคลาสของวิธีการเพื่อนบ้านใกล้สุด k ตัวด้วยตัวเองไปแล้ว https://phyblas.hinaboshi.com/20171028

แต่ว่าคลาสที่เขียนนั้นเป็นแค่ของที่ลองสร้างขึ้นเองเพื่อให้เข้าใจหลักการทำงานคร่าวๆเท่านั้น ประสิทธิภาพการใช้งานจริงๆไม่ค่อยดี ดังนั้นในการใช้งานจริงเราจะใช้ของที่มีคนทำมาไว้เป็นอย่างดีแล้วอย่าง sklearn ดีกว่า

การทำงานของวิธีการเพื่อนบ้านใกล้สุด k ตัวใน sklearn นั้นมีการใช้อัลกอริธึมที่เรียกว่า KD tree หรือ Ball tree ในการค้นหาจุดที่ใกล้ที่สุด แทนที่จะมาคำนวณระยะห่างจากทุกจุด ดังนั้นจึงเร็วกว่า เหมาะแก่การใช้งานจริงๆ

รายละเอียดตรงนั้นค่อนข้างซับซ้อนจึงจะไม่กล่าวถึง ในบทความนี้จะพูดถึงแค่การใช้วิธีการเพื่อนบ้านใกล้สุด k ตัวโดยใช้คลาสใน sklearn

ขอเริ่มจากยกตัวอย่างการใช้ให้ดูแล้วค่อยอธิบายโค้ดทีหลัง

สมมุติว่าต้องการจำแนกแบ่งเขตข้อมูลสองมิติที่มีการกระจายตัวแบบนี้



ซึ่งสร้างจากโค้ดนี้
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

X,z = datasets.make_blobs(n_samples=100,centers=4,cluster_std=0.8,random_state=0)
plt.axes(aspect=1).scatter(X[:,0],X[:,1],c=z,s=30,edgecolor='k',cmap='rainbow')
plt.show()

เมื่อลองทำการแบ่งเขตด้วยวิธีการเพื่อนบ้านใกล้สุด k ตัวก็ทำได้ดังนี้
from sklearn.neighbors import KNeighborsClassifier as Knn
knn = Knn()
knn.fit(X,z)

nmesh = 200
mx,my = np.meshgrid(np.linspace(X[:,0].min(),X[:,0].max(),nmesh),np.linspace(X[:,1].min(),X[:,1].max(),nmesh))
mX = np.stack([mx.ravel(),my.ravel()],1)
mz = knn.predict(mX).reshape(nmesh,nmesh)
plt.axes(xlim=[X[:,0].min(),X[:,0].max()],ylim=[X[:,1].min(),X[:,1].max()],aspect=1)
plt.contourf(mx,my,mz,alpha=0.1,cmap='rainbow')
plt.contour(mx,my,mz,colors='#222222')
plt.scatter(X[:,0],X[:,1],c=z,edgecolor='k',cmap='rainbow')
plt.show()



วิธีการใช้อาจเอาไปเทียบกับการถดถอยโลจิสติกที่ได้เขียนถึงไปใน https://phyblas.hinaboshi.com/20171010

นั่นคือสามารถสรุปง่ายๆว่าขั้นตอนการใช้งานหลักๆก็คือ

1. สร้างออบเจ็กต์จากคลาส KNeighborsClassifier (ในที่นี้ย่อเป็น Knn)
2. นำออบเจ็กต์ที่ได้มาใช้เมธอด fit เพื่อทำการเรียนรู้ข้อมูลที่ป้อนเข้าไป
3. ทำนายหมวดหมู่ของจุดที่ต้องการโดยใช้เมธอด predict

ตัวอย่างข้างต้นใช้ไปโดยที่ใช้ค่าต่างๆเป็นค่าตั้งต้นทั้งหมดไม่ได้ปรับแต่งอะไร แต่ว่าเราสามารถปรับแต่งอะไรได้หลายอย่าง อาร์กิวเมนต์มีดังนี้

n_neighbors
คือจำนวนเพื่อนบ้านที่จะพิจารณา

ค่าตั้งต้นคือ 5

weights
วิธีการคิดน้ำหนัก กรณีที่จำนวนเพื่อนบ้านมากกว่า 1 โดยเลือกได้ ๒ แบบคือ
- uniform คือ คิดน้ำหนักแต่ละจุดเท่ากันไม่ว่าจะระยะห่างเท่าไหร่
- distance คิดน้ำหนักตามส่วนกลับของระยะทางของจุด ยิ่งใกล้ยิ่งมีน้ำหนักมาก

นอกจากนี้ยังอาจใส่เป็นฟังก์ชันบางอย่างสำหรับคิด โดยฟังก์ชันจะต้องรับอาเรย์ของค่าระยะทางไปคำนวณแล้วคืนค่ากลับมาเป็นอาเรย์รูปร่างเดิม

ค่าตั้งต้นคือ uniform

algorithm
วิธีการคิดเพื่อค้นหาจุดที่ใกล้ที่สุด ใส่ได้ ๔ อย่างคือ
- ball_tree คือ ใช้วิธี Ball tree
- kd_tree คือ ใช้วิธี KD tree
- brute คือ คำนวณทั้งหมด
- auto เลือกวิธีการเองตามความเหมาะสม ขึ้นกับข้อมูลที่ใส่เข้าไป

ค่าตั้งต้นคือ auto

metric
มาตรวัดระยะทาง (วิธีการคิดว่าจะคำนวณระยะทางยังไง)

มีอยู่หลายวิธี แต่ในกรณีส่วนใหญ่จะใช้ minkowski คือ sum(|x - y|^p)^(1/p)

ค่าตั้งต้นคือ minkowski

มาตรวัดแบบอื่นดูได้ที่ http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.DistanceMetric.html

p
จำนวนเลขชี้กำลัง สำหรับ minkowski

ค่าตั้งต้นคือ 2

n_jobs
จำนวนจ็อบในการรันแบบคู่ขนานในกรณีที่คอมมีหลายคอร์ ถ้าใส่ -1 จะใช้คอร์ทั้งหมดเท่าที่มี

ค่าตั้งต้นคือ 1



ตัวอย่าง ลองปรับบางค่า เช่นลองแก้บรรทัดที่สร้าง knn เป็น
knn = Knn(n_neighbors=1,p=1)
# หรือ knn = Knn(1,p=1) ก็ได้ เพราะ n_neighbors เป็นคีย์เวิร์ดลำดับแรกอยู่แล้ว

ผลที่ได้ก็จะเปลี่ยนไป กลายเป็นแบบนี้



สังเกตได้ว่าเส้นแบ่งเป็นเส้นตรง เพราะระยะทางถูกคำนวณจากผลรวมค่าสัมบูรณ์ แทนที่จะเป็นยกกำลังสอง

ต่อมาลองเปรียบเทียบอัลกอริธึมดูบ้าง ลองเขียนแบบนี้ ข้อมูลมี 1000 ตัว แล้วตัวแปรต้นมี 10 ชนิด
import time
X,z = datasets.make_blobs(n_samples=10000,n_features=10,centers=4,random_state=0)
for al in ['ball_tree','kd_tree','brute','auto']:
    t1 = time.time()
    knn = Knn(algorithm=al)
    knn.fit(X,z)
    knn.predict(X)
    print(u'%s: %.3f วินาที'%(al,time.time()-t1))

ได้
ball_tree: 0.670 วินาที
kd_tree: 0.585 วินาที
brute: 3.151 วินาที
auto: 0.585 วินาที

จะเห็นได้ว่าการใช้ ball_tree หรือ kd_tree นั้นได้ผลเร็วกว่า brute ซึ่งเป็นการคำนวณทั้งหมดมาก ส่วน auto ในที่นี้ kd_tree ถูกเลือก จึงได้ผลเท่ากัน

ต่อมาลองทดสอบเรื่องจำนวนจ็อบดูด้วย
X,z = datasets.make_blobs(n_samples=20000,n_features=20,random_state=0)
for j in [1,2,3]:
    t1 = time.time()
    Knn(n_jobs=j).fit(X,z).predict(X)
    print(u'n_jobs=%s: %.3f วินาที'%(j,time.time()-t1))

ได้
n_jobs=1: 8.393 วินาที
n_jobs=2: 4.825 วินาที
n_jobs=3: 4.235 วินาที

ยิ่งจ็อบเยอะก็ยิ่งเร็ว แต่เครื่องก็ทำงานหนักเต็มที่ขึ้นตาม



นอกจากนี้ในตัวออบเจ็กต์ knn ยังมีเมธอดน่าสนในอื่นๆที่สามารถใช้ได้ ได้แก่

.kneighbors
จะคืนค่าระยะห่างของเพื่อนบ้านที่ถูกใช้ พร้อมกับดัชนีของจุดนั้น

ตัวอย่าง
X,z = datasets.make_blobs(n_samples=8,centers=2,random_state=5)
knn = Knn(n_neighbors=3)
knn.fit(X,z)
plt.axes(aspect=1)
plt.scatter(X[:,0],X[:,1],c=z,edgecolor='k',cmap='summer')
plt.show()
k = knn.kneighbors(X)
print(k[0])
print(k[1])
for i in range(8):
    print(', '.join(['%d > %.2f'%(k[1][i][j],k[0][i][j]) for j in range(3)]))
ได้
[[ 0.          1.12763922  1.91394086]
 [ 0.          0.68420112  1.16141094]
 [ 0.          1.43561929  1.9628876 ]
 [ 0.          0.77421595  1.16141094]
 [ 0.          2.24451596  2.95862282]
 [ 0.          1.12763922  1.35891489]
 [ 0.          0.77421595  1.43561929]
 [ 0.          0.68420112  1.35891489]]
[[0 5 6]
 [1 7 3]
 [2 6 3]
 [3 6 1]
 [4 0 6]
 [5 0 7]
 [6 3 2]
 [7 1 5]]
0 > 0.00, 5 > 1.13, 6 > 1.91
1 > 0.00, 7 > 0.68, 3 > 1.16
2 > 0.00, 6 > 1.44, 3 > 1.96
3 > 0.00, 6 > 0.77, 1 > 1.16
4 > 0.00, 0 > 2.24, 6 > 2.96
5 > 0.00, 0 > 1.13, 7 > 1.36
6 > 0.00, 3 > 0.77, 2 > 1.44
7 > 0.00, 1 > 0.68, 5 > 1.36


สามารถใส่คีย์เวิร์ดเพิ่มเติมได้แก่
- n_neighbors จำนวนเพื่อนบ้านที่จะให้คืนค่ากลับมา หากไม่ใส่จะคืนเท่ากับค่า n_neighbors ที่กำหนดให้ตั้งแต่ตอนแรก
- return_distance ถ้าใส่เป็น 0 จะไม่คืนค่าระยะทางกลับมา แต่จะคืนแค่อย่างเดียวคือดัชนีของจุด

ลองใส่ดู
print(knn.kneighbors(X,n_neighbors=8,return_distance=0))
ได้
[[0 5 6 1 3 7 4 2]
 [1 7 3 5 6 0 2 4]
 [2 6 3 1 0 4 7 5]
 [3 6 1 7 2 0 5 4]
 [4 0 6 5 2 3 1 7]
 [5 0 7 1 3 6 4 2]
 [6 3 2 1 0 5 7 4]
 [7 1 5 3 0 6 2 4]]


.kneighbors_graph
จะคืนค่าที่แสดงว่าตัวไหนเป็นเพื่อนบ้านของตัวนั้นบ้าง

เพียงแต่ว่าผลที่ได้จะอยู่ในรูปของ sparse matrix ดังนั้นเพื่อให้เห็นชัดสามารถแปลงเป็นอาเรย์ธรรมดาได้โดยเมธอด .toarray อีกที
print(type(knn.kneighbors_graph(X))) # ให้แสดงชนิด
print(knn.kneighbors_graph(X).toarray()) # แปลงเป็นอาเรย์ธรรมดา

ได้
<class 'scipy.sparse.csr.csr_matrix'>
[[ 1.  0.  0.  0.  0.  1.  1.  0.]
 [ 0.  1.  0.  1.  0.  0.  0.  1.]
 [ 0.  0.  1.  1.  0.  0.  1.  0.]
 [ 0.  1.  0.  1.  0.  0.  1.  0.]
 [ 1.  0.  0.  0.  1.  0.  1.  0.]
 [ 1.  0.  0.  0.  0.  1.  0.  1.]
 [ 0.  0.  1.  1.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  0.  1.  0.  1.]]


ถ้าใส่คีย์เวิร์ด mode='distance' ลงไปจะคืนค่าเป็นระยะทางแทน แล้วก็มีคีย์เวิร์ด n_neighbors สามารถเปลี่ยนจำนวนเพื่อนบ้านที่พิจารณาได้ด้วย

ตัวอย่าง
print(knn.kneighbors_graph(X,mode='distance',n_neighbors=8).toarray())

ได้
[[ 0.          2.00350493  3.21293728  2.07395977  2.24451596  1.12763922
   1.91394086  2.23635633]
 [ 2.00350493  0.          3.12214116  1.16141094  4.05546124  1.46770732
   1.80934825  0.68420112]
 [ 3.21293728  3.12214116  0.          1.9628876   3.45228072  3.83210647
   1.43561929  3.80417203]
 [ 2.07395977  1.16141094  1.9628876   0.          3.58705789  2.16656867
   0.77421595  1.84527381]
 [ 2.24451596  4.05546124  3.45228072  3.58705789  0.          3.35721827
   2.95862282  4.43043208]
 [ 1.12763922  1.46770732  3.83210647  2.16656867  3.35721827  0.
   2.40100554  1.35891489]
 [ 1.91394086  1.80934825  1.43561929  0.77421595  2.95862282  2.40100554
   0.          2.46612122]
 [ 2.23635633  0.68420112  3.80417203  1.84527381  4.43043208  1.35891489
   2.46612122  0.        ]]



.predict_proba
เป็นการคำนวณความน่าจะเป็นว่าจะอยู่กลุ่มไหน แทนที่จะทำนายชัดลงไปเลยว่าอยู่กลุ่มไหน

ค่าความน่าจะเป็นในที่นี่คล้ายกับในการถดถอยโลจิสติก แต่ต่างกันตรงที่วิธีการคำนวณ

ตัวอย่างเทียบระหว่าง predict กับ predict_proba โดยจะเทียบให้เห็นด้วยว่าค่าจำนวนเพื่อนบ้านต่างกันก็มีความละเอียดในการแบ่งเขตทำนายที่ต่างกัน
X,z = datasets.make_blobs(n_samples=200,centers=2,cluster_std=2.5,random_state=12)
nmesh = 200
mx,my = np.meshgrid(np.linspace(X[:,0].min(),X[:,0].max(),nmesh),np.linspace(X[:,1].min(),X[:,1].max(),nmesh))
mX = np.stack([mx.ravel(),my.ravel()],1)
for i in [0,1]:
    n = 3+27*i
    knn = Knn(n)
    knn.fit(X,z)
    k = knn.kneighbors(X)
    for j in [0,1]:
        if(j==1):
            mz = knn.predict_proba(mX)[:,1].reshape(nmesh,nmesh)
        else:
            mz = knn.predict(mX).reshape(nmesh,nmesh)
        plt.subplot(221+i+2*j,xlim=[X[:,0].min(),X[:,0].max()],ylim=[X[:,1].min(),X[:,1].max()],aspect=1)
        plt.scatter(X[:,0],X[:,1],10,c=z,edgecolor='k',cmap='winter')
        plt.contourf(mx,my,mz,100,cmap='winter',zorder=0)
        if(j==0):
            plt.title('n=%d'%n)
        else:
            plt.ylabel('proba')
plt.show()


โดยรวมแล้วถือว่า sklearn สามารถใช้งานได้อย่างสะดวกดีมาก



อ้างอิง


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> numpy
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> matplotlib
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> sklearn

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

目录

从日本来的名言
模块
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
机器学习
-- 神经网络
javascript
蒙古语
语言学
maya
概率论
与日本相关的日记
与中国相关的日记
-- 与北京相关的日记
-- 与香港相关的日记
-- 与澳门相关的日记
与台湾相关的日记
与北欧相关的日记
与其他国家相关的日记
qiita
其他日志

按类别分日志



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  查看日志

  推荐日志

ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ