φυβλαςのβλογ
phyblas的博客



สร้างภาพอธิบายการคำนวณที่เกิดขึ้นในชั้นคอนโวลูชันสองมิติของโครงข่ายประสาทเทียมโดยใช้ manim
เขียนเมื่อ 2021/03/26 13:22
แก้ไขล่าสุด 2021/09/28 16:42
ช่วงนี้กำลังพยายามเขียนเนื้อหาเรื่องโครงข่ายประสาทเทียมแบบคอนโวลูชันอยู่ จึงได้พยายามสร้างภาพต่างๆขึ้นมาอธิบาย ซึ่งก็โชคดีที่ได้มารู้จักมอดูลที่ใช้งานได้ดีตัวหนึ่ง นั่นคือ manim (คลิกลิงก์เพื่อดูเนื้อหาสอนการใช้งาน manim ได้)

การคำนวณที่เกิดขึ้นในชั้นคอนโวลูชันนั้นเป็นอะไรที่เข้าใจยาก มองเห็นภาพยากอยู่ การวาดภาพดีๆจะช่วยให้เข้าใจได้ง่ายขึ้น โดยเฉพาะถ้าทำเป็นภาพเคลื่อนไหวได้น่าจะช่วยอธิบายได้ดี

ก่อนหน้านี้ได้เริ่มทำการเขียนเนื้อหาโครงข่ายประสาทเทียมเบื้องต้น ลงไว้ในบล็อก แต่ก็ค้างอยู่ตรงบทที่ ๒๐ เรื่องของโครงข่ายประสาทเทียมแบบคอนโวลูชัน เนื่องจากคิดหาวิธีในการเขียนอธิบายให้เข้าใจง่ายๆไม่ได้

งานนี้เลยค้างมาตั้งแต่ปี 2018 เคยคิดจะเขียนต่อ แต่ก็ไม่ได้เขียนสักที จนเวลาก็ผ่านมาถึง ๒ ปีครึ่งแล้ว

พอได้รู้จัก manim เลยคิดว่าเป็นโอกาสที่ดีที่จะนำมาใช้ในการอธิบายโครงค่ายประสาทเทียมแบบคอนโวลูชัน จึงสามารถเขียนบทที่ ๒๐ ออกมาได้ในที่สุด



จากนั้นจึงได้ฝึก manim จนเริ่มใช้งานได้คล่อง แล้วก็ได้สร้างวีดีโอตัวนี้ขึ้นมา ลงไว้ใน facebook >> https://www.facebook.com/watch/?v=232382701949345

แต่วีดีโอนี้ไม่ได้มีอะไรเพิ่มเติมไปจากเดิมที่เคยทำโดยเขียนด้วย matplotlib ซึ่งเคยใช้ลงไว้ในในบทความที่อธิบายเรื่องคอนโวลูชัน

ที่จริงแล้วการคำนวณที่เกิดขึ้นจริงภายในโครงข่ายประสาทเทียมนั้นยิ่งซับซ้อนกว่าในรูปนั้นไปอีก เพราะประกอบด้วยตัวกรองหลายตัวซึ่งต้องมาจับคู่คูณกับค่าป้อนเข้าหลายตัว และยังมีการบวกด้วยไบแอสด้วย ถึงจะคำนวณออกมาเป็นผลลัพธ์ในแต่ละชั้น

ดังนั้นต่อมาจึงได้สร้างภาพนี้ขึ้นมาเพื่ออธิบายคือภาพนี้ >> https://www.facebook.com/watch/?v=904503580382726

ในภาพนี้มี
- จำนวนช่องขาเข้า 3
- จำนวนช่องขาออก 3
- ขนาดตัวกรองเป็น 3×3
- ขนาดข้อมูลขาเข้าเป็น 4×4
- ดังนั้นผลลัพธ์ที่ได้ก็จะเป็นขนาด 2×2
- พารามิเตอร์น้ำหนักบนตัวกรองมีทั้งหมด 3×3×3×3=27 ตัว
- บวกกับพารามิเตอร์ไบแอสอีก 3 ตัว (เท่ากับจำนวนช่องขาออก)

สำหรับกรอบต่างๆที่เห็นในรูปนี้
- สีเขียวคือค่าป้อนเข้า (input)
- สีน้ำเงินคือตัวกรอง (kernel)
- สีม่วงคือไบแอส (bias)
- สีแดงคือผลลัพธ์ (output)



ต่อไปนี้เป็นโค้ดที่ใช้สร้างวีดีโอนั้นขึ้น
import numpy as np
import manimlib as mnm

class Manimala(mnm.Scene):
    def construct(self):
        nx = [3,4,4]
        nk = [3,3,3,3]
        nz = [nk[0],nk[1],nx[1]-nk[2]+1,nx[2]-nk[3]+1]
        
        x = np.random.randint(0,10,nx)
        k = np.random.randint(0,10,nk)
        b = np.empty(nk[0])
        zh = np.empty(nz)
        zh_ruam = np.empty([nz[0],nz[1]+1,nz[2],nz[3]])
        zk = np.empty([nz[0],nz[2],nz[3],nz[1],nk[2],nk[3]])
        for g in range(nz[0]):
            for h in range(nz[1]):
                for j in range(nz[2]):
                    for i in range(nz[3]):
                        z_ghji = (x[h,j:j+nk[2],i:i+nk[3]]*k[g,h])
                        zk[g,j,i,h] = z_ghji
                        zh[g,h,j,i] = z_ghji.sum()
            b[g] = np.random.randint(10,100)
            zh_ruam[g,0] = b[g]
            zh_ruam[g,1:] = zh[g].cumsum(0)
        
        
        lek_x = []
        krop_x = []
        for h in range(nx[0]):
            lek_x_h = []
            krop_x_h = []
            for j in range(nx[1]):
                for i in range(nx[2]):
                    lek_x_hji = mnm.Tex('%d'%x[h,j,i],font_size=60)
                    tamnaeng = [-1.5+i+h*0.5,-j+h*1.5,0]
                    lek_x_hji.move_to(tamnaeng)
                    lek_x_h.append(lek_x_hji)
                    
                    krop_x_hji = mnm.Square(1,fill_opacity=1,fill_color='#466422',color='#2c3711')
                    krop_x_hji.move_to(tamnaeng)
                    krop_x_h.append(krop_x_hji)
                    
            lek_x.append(mnm.VGroup(*lek_x_h))
            krop_x.append(mnm.VGroup(*krop_x_h))
        lek_x = mnm.VGroup(*lek_x)
        krop_x = mnm.VGroup(*krop_x)
        
        
        lek_k = []
        krop_k = []
        for g in range(nk[0]):
            lek_k_g = []
            krop_k_g = []
            for h in range(nk[1]):
                lek_k_gh = []
                krop_k_gh = []
                for j in range(nk[2]):
                    for i in range(nk[3]):
                        lek_k_ghji = mnm.Tex('%d\\times'%k[g,h,j,i],color='#d6e9a8',font_size=36)
                        tamnaeng = [-1.9+i+h*0.5-4.5+g*0.5/3,0.1-j+h*1.5-1+g*0.5,0]
                        lek_k_ghji.move_to(tamnaeng)
                        lek_k_gh.append(lek_k_ghji)
                        
                        krop_k_ghji = mnm.Square(1,fill_opacity=1,fill_color='#277ad7',color='#101240',stroke_opacity=0.3)
                        krop_k_ghji.move_to(tamnaeng)
                        krop_k_gh.append(krop_k_ghji)
                        
                lek_k_g.append(mnm.VGroup(*lek_k_gh))
                krop_k_g.append(mnm.VGroup(*krop_k_gh))
            lek_k.append(mnm.VGroup(*lek_k_g))
            krop_k.append(mnm.VGroup(*krop_k_g))
        lek_k = mnm.VGroup(*lek_k)
        krop_k = mnm.VGroup(*krop_k)
        
        
        lek_zk = [[[],[],[]],[[],[],[]],[[],[],[]]]
        krop_zk = [[[],[],[]],[[],[],[]],[[],[],[]]]
        for g in range(nz[0]):
            for jz in range(nz[2]):
                for iz in range(nz[3]):
                    for h in range(nz[1]):
                        lek_zk_gh = []
                        krop_zk_gh = []
                        for jk in range(nk[2]):
                            for ik in range(nk[3]):
                                tex = '%d'%zk[g,jz,iz,h,jk,ik]
                                lek_zk_ghji = mnm.Tex(tex,color='#d59ee7',stroke_opacity=1,stroke_width=2,font_size=80)
                                tamnaeng = [-1.9+ik+iz+h*0.5,0.1-jk-jz+h*1.5,0]
                                lek_zk_ghji.move_to(tamnaeng)
                                lek_zk_gh.append(lek_zk_ghji)
                                
                                krop_zk_ghji = mnm.Square(1,fill_opacity=0.8,fill_color='#277ad7',color='#101240',stroke_opacity=0.3)
                                krop_zk_ghji.move_to(tamnaeng)
                                krop_zk_gh.append(krop_zk_ghji)
                        
                        lek_zk[g][h].append(mnm.VGroup(*lek_zk_gh))
                        krop_zk[g][h].append(mnm.VGroup(*krop_zk_gh))
        
        
        lek_z = []
        krop_z = []
        lek_b = []
        krop_b = []
        for g in range(nz[0]):
            lek_z_g = []
            krop_z_g = []
            for h in range(nz[1]+1):
                lek_z_gh = []
                for j in range(nz[2]):
                    for i in range(nz[3]):
                        color = ['#1b8680','#48a8a2','#95d5d1','#d7f2f0'][h]
                        lek_z_ghji = mnm.Tex('%d'%zh_ruam[g,h,j,i],font_size=54,color=color)
                        tamnaeng = [i+4.5+g*0.5,-j+(g-1)*1.9,0]
                        lek_z_ghji.move_to(tamnaeng)
                        lek_z_gh.append(lek_z_ghji)
                        
                        if(h==0):
                            tamnaeng[0] += 3.5
                            krop_z_ghji = mnm.Square(1,fill_opacity=1,fill_color='#9a2c36',color='#541117')
                            krop_z_ghji.move_to(tamnaeng)
                            krop_z_g.append(krop_z_ghji)
                        
                lek_z_g.append(mnm.VGroup(*lek_z_gh))
            lek_z.append(mnm.VGroup(*lek_z_g))
            krop_z.append(mnm.VGroup(*krop_z_g))
            
            lek_b_g = mnm.Tex('%+d'%b[g],font_size=42,color='#d2feb1')
            tamnaeng = [4.5+g,3+1./3,0]
            lek_b_g.move_to(tamnaeng)
            lek_b.append(lek_b_g)
            
            krop_b_g = mnm.Square(1,fill_opacity=1,fill_color='#aa81c9',color='#50296f')
            krop_b_g.move_to(tamnaeng)
            krop_b.append(krop_b_g)
            
        lek_z = mnm.VGroup(*lek_z)
        krop_z = mnm.VGroup(*krop_z)
        lek_b = mnm.VGroup(*lek_b)
        krop_b = mnm.VGroup(*krop_b)
        
        
        
        for h in [2,1,0]:
            self.add(krop_x[h],lek_x[h])
            for g in [2,1,0]:
                self.add(krop_k[g][h],lek_k[g][h])
        for g in [2,1,0]:
            self.add(krop_z[g],krop_b[g],lek_b[g])
        
        for g in range(3):
            lis_play1 = [
                krop_z[g].animate.shift([-3.5,0,0]),
                mnm.ReplacementTransform(lek_b[g].copy(),lek_z[g][0]),
            ]
            for h in [2,1,0]:
                lis_play1.extend([
                    krop_k[g][h].animate.shift([4.5-g*0.5/3,-g*0.5+1,0]),
                    lek_k[g][h].animate.shift([4.5-g*0.5/3,-g*0.5+1,0]),
                ])
            
            self.play(*lis_play1,run_time=0.5)
            
            lis_play2 = [
                krop_b[g].animate.shift([0,2./3,0]),
                lek_b[g].animate.shift([0,2./3,0]),
            ]
            for h in [2,1,0]:
                lis_play2.append(krop_k[g][h].animate.set_fill(opacity=0.2))
                
            self.play(*lis_play2,run_time=0.1)
            
            for h in range(3):
                for j in range(nz[2]):
                    for i in range(nz[3]):
                        ji = j*nz[3]+i
                        self.play(
                            mnm.FadeIn(krop_zk[g][h][ji]),
                            mnm.ReplacementTransform(lek_k[g][h].copy(),lek_zk[g][h][ji]),
                            run_time=0.5
                        )
                        
                        self.play(
                            mnm.ReplacementTransform(mnm.VGroup(lek_zk[g][h][ji],lek_z[g][h][ji]),lek_z[g][h+1][ji]),
                            mnm.FadeOut(krop_zk[g][h][ji]),
                            run_time=0.5
                        )
                        
                        if(i!=nz[3]-1):
                            luean = [1,0,0]
                        else:
                            luean = [1-nz[3],-1,0]
                        if(ji!=nz[2]*nz[3]-1):
                            self.play(
                                lek_k[g][h].animate.shift(luean),
                                krop_k[g][h].animate.shift(luean),
                                run_time=0.25
                            )
                if(h!=2):
                    self.play(
                        krop_x[h].animate.shift([0,-3,0]),
                        lek_x[h].animate.shift([0,-3,0]),
                        run_time=0.25
                    )
                else:
                    self.play(
                        krop_x[1].animate.shift([0,3,0]),
                        lek_x[1].animate.shift([0,3,0]),
                        krop_x[0].animate.shift([0,3,0]),
                        lek_x[0].animate.shift([0,3,0]),
                        run_time=0.25
                    )
                self.play(
                    krop_k[g][h].animate.set_fill(opacity=1),
                    run_time=0.05)
                self.play(
                    krop_k[g][h].animate.shift([-8+g*0.5/3,g*0.5,0]),
                    lek_k[g][h].animate.shift([-8+g*0.5/3,g*0.5,0]),
                    run_time=0.2
                )
        
        lis_play3 = [
            krop_b.animate.shift([0,-2./3,0]),
            lek_b.animate.shift([0,-2./3,0])
        ]
        for h in [2,1,0]:
            for g in [2,1,0]:
                lis_play3.extend([
                    krop_k[g][h].animate.shift([2.5,0,0]),
                    lek_k[g][h].animate.shift([2.5,0,0]),
                ])
                
        self.play(*lis_play3,run_time=0.2)
        self.wait(0.1)


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์ >> โครงข่ายประสาทเทียม
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> numpy
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> manim

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

目录

从日本来的名言
模块
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
机器学习
-- 神经网络
javascript
蒙古语
语言学
maya
概率论
与日本相关的日记
与中国相关的日记
-- 与北京相关的日记
-- 与香港相关的日记
-- 与澳门相关的日记
与台湾相关的日记
与北欧相关的日记
与其他国家相关的日记
qiita
其他日志

按类别分日志



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  查看日志

  推荐日志

ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ