φυβλαςのβλογ
phyblasのブログ



[python] แยกแยะภาพตัวเลขที่เขียนด้วยลายมือด้วยการวิเคราะห์การถดถอยโลจิสติก
เขียนเมื่อ 2017/09/22 20:52
แก้ไขล่าสุด 2021/09/28 16:42
ในตอนที่แล้วได้แนะนำให้รู้จักชุดข้อมูลตัวเลขที่เขียนด้วยลายมือของ MNIST ไป https://phyblas.hinaboshi.com/20170920



สำหรับในตอนนี้จะลองนำข้อมูลนี้มาใช้ทดสอบการแยกแยะตัวเลขดู โดยใช้วิธีที่พื้นฐานที่สุด นั่นคือการถดถอยโลจิสติกแบบมัลติโนเมียล (การถดถอยซอฟต์แม็กซ์)

รายละเอียดเรื่องการถดถอยโลจิสติกได้เขียนถึงไปมากแล้วในบทความก่อนหน้านี้ เช่น https://phyblas.hinaboshi.com/20161205

โค้ดสำหรับแบบจำลองการถดถอยโลจิสติกที่ใช้ในนี้ดัดแปลงจากในบทความก่อนๆมา

ในบทความก่อนๆที่เคยเขียนถึงนั้นเรามักใช้กับข้อมูลที่มีจำนวนตัวแปรต้นแค่ ๒ แต่สำหรับข้อมูล MNIST นี้ ตัวแปรต้นมีมากถึง ๗๘๔ ตัว

แม้จะมีตัวแปรต้น (จำนวนมิติ) มากขนาดนั้น แต่ก็สามารถใช้วิธีเดียวกันในการแก้ปัญหาได้ แนวคิดก็เหมือนกัน คือสร้างโปรแกรมให้มันเรียนรู้เพื่อปรับค่ำน้ำหนักของแต่ละตัวแปร

แนวคิดในการออกแบบสร้างเป็นดังนี้

- เนื่องจากข้อมูลเดิมมีค่า 0~255 เพื่อให้เหมาะสมต่อการคำนวณจึงหารด้วย 255 ให้ค่าอยู่ระหว่าง 0~1

- จำนวนตัวแปรต้นคือ 784 ตัวแปรตาม 10 ตัว (คือผลการทายตัวเลขทั้งสิบในแบบ one-hot)

- ค่าเสียหายใช้เอนโทรปีไขว้

- ตั้งเงื่อนไขการหยุดทำซ้ำเป็นว่าถ้าค่าความแม่นไม่เพิ่มขึ้นเลยเกิน ๑๐ ครั้งก็ให้หยุดและใช้ค่าน้ำหนักที่ให้ค่าความแม่นสูงสุด

- เนื่องจากข้อมูลนำเข้ามีจำนวนมากและมีความหลากหลาย ควรใช้มินิแบตช์

โค้ด
import numpy as np

def softmax(x):
    exp_x = np.exp(x.T-x.max(1))
    return (exp_x/exp_x.sum(0)).T

class ThotthoiLogistic:
    def __init__(self,eta):
        self.eta = eta

    def rianru(self,X,z,n_thamsam,n_batch=0,romaiphoem=10):
        n = len(z)
        if(n_batch==0 or n<n_batch):
            n_batch = n
        self.kiklum = int(z.max()+1)
        z_1h = z[:,None]==range(self.kiklum)
        self.w = np.zeros([X.shape[1]+1,self.kiklum])
        self.entropy = []
        self.thuktong = []
        thukmaksut = 0 # ค่าจำนวนที่ถูกมากสุด
        thukmaiphoem = 0 # นับว่าจำนวนที่ถูกไม่เพิ่มมาแล้วกี่ครั้ง
        for j in range(n_thamsam):
            lueak = np.random.permutation(n)
            for i in range(0,n,n_batch):
                Xn = X[lueak[i:i+n_batch]]
                zn = z_1h[lueak[i:i+n_batch]]
                phi = self.ha_softmax(Xn)
                eee = (zn-phi)/len(zn)*self.eta
                self.w[1:] += np.dot(eee.T,Xn).T
                self.w[0] += eee.sum(0)

            thukmai = self.thamnai(X)==z
            thukmak = thukmai.mean()*100

            if(thukmak > thukmaksut):
                # ถ้าจำนวนที่ถูกมากขึ้นกว่าเดิมก็บันทึกค่าจำนวนนั้น และน้ำหนักในตอนนั้นไว้
                thukmaksut = thukmak
                thukmaiphoem = 0
                w = self.w.copy()
            else:
                thukmaiphoem += 1 # ถ้าไม่ถูกมากขึ้นก็นับไว้ว่าไม่เพิ่มไปอีกครั้งแล้ว

            self.thuktong += [thukmak]
            self.entropy += [self.ha_entropy(X,z_1h)]
            print(u'ครั้งที่ %d ถูก %.3f%% สูงสุด %.3f%% ไม่เพิ่มมาแล้ว %d ครั้ง'%(j+1,self.thuktong[-1],thukmaksut,thukmaiphoem))

            if(romaiphoem!=0 and thukmaiphoem>=romaiphoem):
                break # ถ้าจำนวนที่ถูกไม่เพิ่มเลย 10 ครั้งก็เลิกทำ

        self.w = w # ค่าน้ำหนักที่ได้ในท้ายสุด เอาตามค่าที่ทำให้ทายถูกมากที่สุด

    def thamnai(self,X):
        return (np.dot(X,self.w[1:])+self.w[0]).argmax(1)

    def ha_softmax(self,X):
        return softmax(np.dot(X,self.w[1:])+self.w[0])

    def ha_entropy(self,X,z_1h):
        return -(z_1h*np.log(self.ha_softmax(X)+1e-7)).mean()



# ดึงข้อมูล MNIST
mnist = datasets.fetch_openml('mnist_784')
mnist.data = mnist.data/255. # ทำให้ค่าเป็น 0~1
np.random.seed(0)
sumriang = np.random.permutation(len(mnist.target)) # สุ่มเรียงลำดับข้อมูลใหม่
X = mnist.data[sumriang]
z = mnist.target.astype(int)[sumriang]

# เริ่มการเรียนรู้
eta = 0.24 # อัตราการเรียนรู้
n_thamsam = 1000 # จำนวนทำซ้ำสูงสุดถ้าไม่มีการหยุดเสียก่อน
n_batch = 100 # จำนวนมินิแบตช์
romaiphoem = 10 # จะให้หยุดเมื่อความแม่นยำไม่เพิ่มเกินกี่ครั้ง
tl = ThotthoiLogistic(eta)
tl.rianru(X,z,n_thamsam,n_batch,romaiphoem)

# กราฟแสดงความคืบหน้าในการเรียนรู้
ax = plt.subplot(211)
ax.set_title(u'เอนโทรปี',fontname='Tahoma')
plt.plot(tl.entropy)
plt.tick_params(labelbottom='off')
ax = plt.subplot(212)
ax.set_title(u'% ถูก',fontname='Tahoma')
plt.plot(tl.thuktong)
plt.show()

ผลที่ได้พบว่าความแม่นยำขึ้นไปได้ถึงที่ประมาณ 93% และไม่อาจสูงขึ้นไปกว่านี้แล้ว ซึ่งลองดูตัวอย่างที่คนอื่นๆทำก็พบว่าจะได้สูงสุดแค่ประมาณนั้นกันเหมือนกัน



ที่จริงแล้วการแยกภาพตัวเลขเป็นเรื่องที่ต้องพิจารณาอย่างซับซ้อนเกินกว่าที่แค่ใช้การถดถอยโลจิสติกธรรมดาแล้วจะทำได้ดีเพียงพอ

ระบบการคิดของการถดถอยโลจิสติกที่ใช้การคำนวณแค่ชั้นเดียวแบบนี้ค่อนข้างเข้าใจได้ไม่ยาก คือโดยทั่วไปแล้วจะพิจารณาจากว่าช่องไหนมีการเขียนแล้วจะมีโอกาสที่จะเป็นตัวเลขไหนมากกว่า เช่น "เลข 0 ต้องมีรอบๆแต่ต้องไม่มีตรงกลาง"

ลองดูค่าน้ำหนักของเลข 0 ที่ได้เป็นผลลัพธ์ออกมา
plt.imshow(tl.w[1:,0].reshape(28,28),cmap='gray_r')
plt.show()



ในที่นี้สีดำคือช่องที่ถ้าถูกเขียนแล้วจะมีโอกาสเป็น 0 มาก ส่วนสีขาวคือถ้าถูกเขียนจะมีโอกาสเป็น 0 น้อย

จากนั้นลองดูตัวเลขอื่นๆที่เหลือ
for i in range(1,10):
    plt.subplot(330+i)
    plt.imshow(tl.w[1:,i].reshape(28,28),cmap='gray_r')
plt.show()



แต่ละภาพก็บอกแนวโน้มได้คร่าวๆว่าเวลาคนเราเขียนตัวเลขต่างๆนั้น ดินสอมักจะถูกขีดที่ตำแหน่งไหนมากหรือน้อย

แต่ในความเป็นจริงแล้วระบบการคิดของมนุษย์เราไม่ควรจะง่ายๆแค่นั้น เช่นว่าเลข 0 บางครั้งอาจถูกเขียนเบี้ยวมาอยู่ทางขวามากหน่อย จนเส้นไปพาดผ่านตรงกลาง พอมีกรณีแบบนี้ขึ้นโอกาสทายผิดก็มีสูง

ปกติควรพิจารณาว่าเลข 0 เป็นสิ่งที่มีขอบและรูตรงกลาง ตรงไหนอาจจะเป็นขอบหรือรูก็ไม่รู้ แม้รูจะมักอยู่ตรงกลางสุด แต่ก็ไม่เสมอไป

ดังนั้นจำเป็นต้องสามารถคิดในลักษณะที่ว่า "ถ้าตรงนี้มีการเขียนแล้วตรงนั้นต้องไม่มีถึงจะควรเป็นเลขนี้ แต่ถ้าตรงนี้ไม่มีก็ควรต้องมีตรงนี้"

ซึ่งระบบที่เขียนขึ้นข้างต้นนี้จะยังไม่สามารถพิจารณาอะไรซับซ้อนหลายตลบแบบนั้นได้เพราะมีแค่ชั้นเดียว

อีกทั้งข้อมูลของทั้ง 784 ช่องถูกป้อนในลักษณะที่เป็นมิติเดียว แถมไม่มีอะไรเป็นตัวบอกว่าช่องไหนอยู่ติดกัน ช่องที่ 1,2,3 ก็เป็นตัวแปรต้นต้นตัวนึงเหมือนๆกัน ไม่ได้ถูกพิจารณาว่าช่อง 2 กับช่อง 1 มีความใกล้ชิดกันมากกว่าช่อง 3 กับช่อง 1

เพื่อที่จะให้เครื่องเรียนรู้อะไรซับซ้อนขึ้น จึงจำเป็นต้องสร้างระบบการเรียนรู้ที่ซับซ้อนขึ้น

วิธีการหนึ่งที่นิยมใช้กันมากก็คือ การเอาส่วนคำนวณการถดถอยโลจิสติกมาซ้อนต่อกันเป็น ๒ ชั้นขึ้นไป ซึ่งก็คือการสร้างเครือข่ายประสาทเทียม (神经网路, neural network)

นี่ก็เป็นเรื่องที่ตั้งใจจะเขียนถึงในบทความต่อๆไป



อ้างอิง


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> numpy
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> matplotlib

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

目次

日本による名言集
モジュール
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
機械学習
-- ニューラル
     ネットワーク
javascript
モンゴル語
言語学
maya
確率論
日本での日記
中国での日記
-- 北京での日記
-- 香港での日記
-- 澳門での日記
台灣での日記
北欧での日記
他の国での日記
qiita
その他の記事

記事の類別



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  記事を検索

  おすすめの記事

รวมร้านราเมงในเมืองฟุกุโอกะ
ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ

月別記事

2025年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2024年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2023年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2022年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2021年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

もっと前の記事

ไทย

日本語

中文