φυβλαςのβλογ
บล็อกของ phyblas



[python] วิเคราะห์การถดถอยโดยใช้ฟังก์ชันฐาน
เขียนเมื่อ 2018/07/20 18:50
แก้ไขล่าสุด 2024/10/03 19:33
เมื่อต้องการหาความสัมพันธ์อะไรบางอย่างระหว่างตัวแปรต้นและตัวแปรตามจากข้อมูลชุดหนึ่งที่มี เรียกปัญหานี้ว่าการวิเคราะห์การถดถอย (回归, regression)

เช่นสมมุติว่ามีตัวแปร ๒ ​ตัว ตัวแปรตั้งต้น x และตัวแปรตาม z มีความสัมพันธ์แบบนี้



เนื่องจากข้อมูลนี้ตัวแปรทั้ง ๒ ดูแล้วสามารถลากเส้นตรงเพื่ออธิบายได้ ดังนั้นจึงสามารถใช้การวิเคราะห์การถดถอยเชิงเส้น (线性回归, linear regression) นั่นคือเขียน z ในรูปของ
..(1)

โดย w และ b เป็นค่าคงที่บางอย่างที่ต้องหา โดยอาศัยข้อมูล เช่นตัวอย่างนี้จะได้ว่า w=0.3 และ b=0.4 และจะสามารถคำนวณลากเส้นได้ผ่านกลางจุดข้อมูลแบบนี้



เรื่องของการวิเคราะห์การถดถอยเชิงเส้นได้เขียนถึงไปแล้วใน https://phyblas.hinaboshi.com/20161210

แต่ในปัญหาโดยทั่วไปแล้วยากที่จะอธิบายด้วยเส้นตรงง่ายๆแบบนั้น เช่นหากข้อมูลมีหน้าตาแบบนี้



แบบนี้แทนที่จะแค่อธิบายด้วยเส้นตรง wx+b ง่ายๆ ควรใช้ฟังก์ชันอะไรบางอย่างที่ซับซ้อนกว่านั้น

ทางหนึ่งที่สามารถทำได้ก็คือ ใช้ฟังก์ชันฐาน (基函数, basis function)

คือ z อาจเขียนอยู่ในรูปของฟังก์ชันอะไรบางอย่างของ x มาบวกกัน อาจเขียนได้ในลักษณะแบบนี้
..(2)

โดย ϕi(x) คือฟังก์ชันฐาน คือเป็นกลุ่มฟังก์ชันอะไรสักอย่างของ x ที่ต้องการนำมาใช้อธิบายค่า โดยที่ ϕ แต่ละค่านั้นจะมีคุณสมบัติตั้งฉากซึ่งกันและกัน (正交, orthogonal) คือไม่สามารถเขียนตัวนึงในรูปของตัวอื่นๆบวกกันได้

ส่วน wi เป็นค่าน้ำหนักที่บ่งบอกว่าฟังก์ชันฐานแต่ละตัวมีความสำคัญแค่ไหน

และ m คือจำนวนฟังก์ชันฐานที่ใช้

ตัวอย่างฟังก์ชันฐานที่นิยมใช้ที่สุดก็คือ ฟังก์ชันพหุนาม นั่นคือ
..(3)

ก็จะได้ว่า
..(4)

ทีนี้เพื่อความสะดวกในการคำนวณ เราอาจเขียนใหม่ในรูปเวกเตอร์ได้เป็นแบบนี้
..(5)

โดย
..(6)
..(7)

สำหรับฟังก์ชันฐานพหุนาม
..(8)

ทีนี้จะเห็นว่าค่า z ที่คำนวณได้นั้นจะขึ้นอยู่กับ w ทั้งหลาย ดังนั้นปัญหาต่อไปก่อนอื่นสิ่งที่จะต้องทำก็คือหาค่า w ที่จะทำให้คำนวณค่า z ออกมาได้ใกล้เคียงกับข้อมูลที่มีอยู่มากที่สุด

การแก้หาค่า w นั้นอาจมีอยู่หลายวิธี แต่โดยทั่วไปที่เป็นพื้นฐานง่ายที่สุดก็คือคำนวณค่าผลรวมความคลาดเคลื่อนกำลังสอง (和方差, sum of squared error, SSE) ระหว่างค่าจริงกับค่าที่คำนวณได้
..(9)

ทีนี้ เมื่อเราจำเป็นต้องคำนวณค่า z สำหรับค่า x ทุกตัวที่มี โดยทั่วไปแล้วทางที่สะดวกคือเขียนอยู่ในรูปของการคูณเมทริกซ์ แบบนี้
..(10)

โดย
..(11)

โดย n คือจำนวนจุดข้อมูล

และ
..(12)

เมทริกซ์ของ Φ นี้มีชื่อเรียกเฉพาะว่าเมทริกซ์ออกแบบ (设计矩阵, design matrix)

แทนลงในสมการ (9) จะได้ว่า
..(13)

เป้าหมายคือต้องการหาค่า w ที่ทำให้ J ต่ำสุด โดยทั่วไปที่จุดต่ำสุดจะมีค่าความชันเป็น 0 ดังนั้นจึงหาอนุพันธ์ย่อยของ J เทียบ w และให้เท่ากับ 0 จะได้
..(14)

แล้วก็จะได้ว่า
..(15)

ซึ่งในรูปแบบนี้สามารถหาค่า w ได้ด้วยการแก้ระบบสมการ

ในไพธอนมีฟังก์ชัน np.linalg.solve ซึ่งใช้สำหรับแก้สมการแบบนี้ได้ทันทีอยู่แล้ว จึงดึงมาใช้ได้เลย เสร็จแล้วก็จะได้ค่า w ที่ต้องการออกมา

เสร็จแล้วพอได้ w มาก็นำมาหาค่า z ของข้อมูลใหม่ได้โดยสมการ (10) ต่อไป

สรุปขั้นตอนสำหรับการหาค่า z ใหม่ที่ไม่รู้คือ
1. คำนวณ Φ ของข้อมูล x ที่มี
2. ใช้ np.linalg.solve แก้หาค่า w
3. คำนวณ Φ ของข้อมูล x ใหม่ที่ต้องการหา
4. คำนวณ z จาก Φ ใหม่และ w

ลองเขียนโค้ดเพื่อใช้งานดูเลย ข้อมูลตัวอย่างที่ใช้นี้เป็นอันเดียวกับที่แสดงในภาพข้างบน ในที่นี้ใช้ x และ z เป็นตัวแปรของข้อมูลที่มี ส่วน x_ เป็นข้อมูลจุดใหม่ที่ต้องการหา z_ เป็นค่าของจุดใหม่ที่คำนวณได้ ส่วน Φ เขียนแทนด้วย phi
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
n = 70 # จำนวนข้อมูล
x = np.random.uniform(0,10,n) # จุดข้อมูลที่รู้ค่า
# ค่าของแต่ละจุด ให้เป็นฟังก์ชันพหุนามที่บวกคลื่นรบกวนเข้าไป
z = ((2*x-1)**2*(x-5)**2*(x-9)**2)*np.random.normal(1,0.1,n)/1000

# ฟังก์ชันสำหรับสร้าง Φ โดยใช้ฐานเป็นพหุนาม
def than_phahunam(x,m):
    return np.array([x**i for i in range(m)]).T

m = 8 # จำนวนฐานที่จะใช้
phi = than_phahunam(x,m)
# แก้สมการหาค่า w
w = np.linalg.solve(phi.T.dot(phi),phi.T.dot(z))

x_ = np.linspace(x.min(),x.max(),201) # ข้อมูลใหม่ที่ต้องการหาค่า
phi_ = than_phahunam(x_,m)
z_ = phi_.dot(w)

plt.axes(xlabel='x',ylabel='z')
plt.scatter(x,z,c='m',edgecolor='b')
plt.plot(x_,z_,'r')
plt.show()




กรณีที่ใช้ฐานเป็นพหุนามแบบนี้อาจเรียกว่าการวิเคราะห์การถดถอยพหุนาม (多项式回归, polynomial regression)

เพื่อความสะดวกในการใช้ ลองเขียนในรูปแบบคลาสได้ดังนี้
class ThotthoiPhahunam:
    def __init__(self,m=4):
        self.m = m

    def rianru(self,x,z):
        phi = self.than(x)
        self.w = np.linalg.solve(phi.T.dot(phi),phi.T.dot(z))

    def thamnai(self,x_):
        phi_ = self.than(x_)
        return phi_.dot(self.w)

    def than(self,x):
        return x[:,None]**np.arange(self.m)

np.random.seed(0)
n = 50
x = np.random.uniform(0,20,n)
z = ((x-1)**2*(x-10)*(x-18)**2)*np.random.normal(1,0.1,n)/1000

tt = ThotthoiPhahunam(m=8)
tt.rianru(x,z)
x_ = np.linspace(x.min(),x.max(),201)
z_ = tt.thamnai(x_)

plt.axes(xlabel='x',ylabel='z')
plt.scatter(x,z,c='m',edgecolor='b')
plt.plot(x_,z_,'r')
plt.show()



ที่จริงเรื่องการวิเคราะห์การถดถอยพหุนามเคยเขียนถึงไปแล้วใน https://phyblas.hinaboshi.com/20161219

เพียงแต่ว่ารูปแบบการเขียน และวิธีการหาค่า w ก็ต่างกัน ในนั้นใช้วิธีการเคลื่อนลงตามความชัน (梯度下降法, gradient descent) แต่ในนี้ใช้วิธีการแก้สมการหาโดยตรง



ใช้ฟังก์ชันเกาส์เป็นฐาน

นอกจากพหุนามแล้ว ฐานอีกชนิดหนึ่งที่นิยมคือใช้ฟังก์ชันเกาส์ (Gauss)

ฟังก์ชันเกาส์ เป็นฟังก์ชันที่มีค่าเป็นลักษณะระฆังคว่ำ โดยมีค่าสูงสุดอยู่ที่จุดนึง เรียกว่าเป็นจุดศูนย์กลาง และยิ่งห่างจากจุดนี้ไปก็จะมีค่าต่ำลงเรื่อยๆจนเข้าใกล้ 0 นิยามดังนี้
..(16)

โดย μ คือจุดศูนย์กลาง ส่วน γ เป็นตัวกำหนดว่าค่าจะลงลงเร็วแค่ไหนเมื่อห่างจากใจกลาง ยิ่งค่าน้อยยิ่งลดลงช้า

เมื่อนำมาใช้เป็นฟังก์ชันฐาน ฐานเกาส์อาจนิยามตามนี้
..(17)

โดย xj เป็นจุดข้อมูล นั่นคือใช้จุดข้อมูลแต่ละจุดเป็นศูนย์กลางของฐานแต่ละอัน

กรณีนี้เราจะมีจำนวนฐานเท่ากับจำนวนข้อมูลทั้งหมดที่ใช้

เขียนโปรแกรมได้ดังนี้
def gauss(x1,x2,gamma):
    return np.exp(-gamma*(x1-x2)**2)

n = 20
np.random.seed(4)
x = np.random.uniform(0,10,n)
z = ((x-1)**2*(x-9)**2)*np.random.normal(1,0.1,n)/1000

gamma = 1
phi = gauss(x,x[:,None],gamma)
w = np.linalg.solve(phi.T.dot(phi),phi.T.dot(z))

x_ = np.linspace(x.min(),x.max(),201)
phi_ = gauss(x,x_[:,None],gamma)
z_ = phi_.dot(w)

plt.axes(xlabel='x',ylabel='z')
plt.scatter(x,z,c='m',edgecolor='b')
plt.plot(x_,z_,'r')
plt.show()



ผลที่ได้จะพบว่าออกมาแปลกๆ ที่เป็นแบบนี้เพราะว่าเมื่อจำนวนฐานเท่ากับจำนวนข้อมูลจะทำให้ฟังก์ชันที่ได้สามารถลากผ่านจุดข้อมูลทั้งหมด

ซึ่งที่จริงแล้วข้อมูลโดยทั่วไปจะมีความไม่แน่นอนปนอยู่ ดังนั้นการฝืนให้เส้นลากผ่านจุดทั้งหมดนั้นจึงไม่ใช่เรื่องที่ดี มีแต่จะทำให้เกิดการเรียนรู้เกิน (过学习, overlearning) คือการที่ผลการคำนวณปรับตัวเข้ากับข้อมูลที่มีอยู่มากเกินไปแต่ไม่สามารถใช้อธิบายทุกอย่างตามธรรมชาติจริงๆได้

เมื่อเป็นแบบนี้จึงจำเป็นต้องมีการเรกูลาไรซ์ (正规化, regularize) คือการทำให้การคำนวณไม่ยึดติดกับข้อมูลที่มีมากเกินไป

เกี่ยวกับแนวคิดของการเรกูลาไรซ์เคยเขียนถึงไปแล้วใน https://phyblas.hinaboshi.com/20170928

ในที่นี้ใช้เรกูลาไรซ์แบบ l2 สมการ (13) ก็จะแก้เป็น
..(18)

ที่เปลี่ยนไปคือมีเพิ่มพจน์ทางขวาบวกเข้ามา โดย λ คือขนาดของการเรกูลาไรซ์

จากนั้นก็ให้ความชันเท่ากับ 0 เช่นเคย
..(19)

แล้วก็จะได้
..(20)

โดย I คือเมทริกซ์เอกลักษณ์ ค่า λ ที่เพิ่มเข้าไปจะมีผลทำให้ค่าแนวทแยงของเมทริกซ์มากขึ้น

สามารถเขียนโค้ดแก้ตรงที่คำนวณ w ใหม่ เป็นแบบนี้
l = 0.1
w = np.linalg.solve(phi.T.dot(phi) + np.eye(len(x))*l,phi.T.dot(z))
แล้วผลที่ได้จะกลายเป็นแบบนี้



จะเห็นว่าได้เส้นกราฟออกมาเรียบและดูสมเหตุสมผลขึ้น



สุดท้ายนี้ เพื่อความสะดวกในการใช้งานลองมาสร้างเป็นคลาสขึ้นมาดู
class ThotthoiThanGauss:
    def __init__(self,gamma=1,l=0.1):
        self.gamma = gamma
        self.l = l

    def rianru(self,x,z):
        phi = self.gauss(x,x[:,None])
        self.w = np.linalg.solve(phi.T.dot(phi)+self.l*np.eye(len(x)),phi.T.dot(z))
        self.x = x

    def thamnai(self,x_):
        phi_ = self.gauss(self.x,x_[:,None])
        return phi_.dot(self.w)

    def gauss(self,x1,x2):
        return np.exp(-self.gamma*(x1-x2)**2)

ข้อควรสังเกตอย่างหนึ่งคือ กรณีนี้จำเป็นต้องบันทึกตำแหน่งของ x ที่ป้อนเข้ามาตอนแรกไว้ด้วย เพราะต้องใช้ตำแหน่งใจกลางของฐานเกาส์ในการคำนวณ

ขอยกตัวอย่างการใช้ด้วยการเปรียบเทียบความแตกต่างระหว่างกรณีที่ λ มีค่าต่างกัน
n = 30
np.random.seed(5)
x = np.random.uniform(0,30,n)
z = ((x-5)**2*(x-15)*(x-28))*np.random.normal(1,0.2,n)/1000
x_ = np.linspace(x.min(),x.max(),301)
ll = [0,0.1,1,10]
plt.figure(figsize=[6.5,5])
for i in range(4):
    l = ll[i]
    tt = ThotthoiThanGauss(gamma=0.1,l=l)
    tt.rianru(x,z)
    z_ = tt.thamnai(x_)
    plt.subplot(2,2,i+1)
    plt.scatter(x,z,c='m',edgecolor='b')
    plt.title('$\\lambda$=%.1f'%l)
    plt.plot(x_,z_,'r')
plt.tight_layout()
plt.show()



จะเห็นได้ว่าเมื่อ λ น้อยไปจะทำให้เกิดการเรียนรู้เกินได้ง่าย แต่พอมากไปก็จะทำให้ไม่สามารถปรับเข้ากับข้อมูลได้ดีพอ ดังนั้นการเลือก λ ที่เหมาะสมจึงสำคัญ



ทั้งหมดนี้เป็นพื้นฐานของการประยุกต์ใช้ฟังก์ชันฐานในปัญหาการวิเคราะห์การถดถอย

การใช้ฟังก์ชันฐานยังเป็นพื้นฐานสำคัญที่จะต่อยอดไปยังเรื่องอื่นในเทคทิคการเรียนรู้ของเครื่อง เช่นการถดถอยเชิงเส้นแบบเบส์ (贝叶斯线性回归, Bayesian Linear Regression) รวมทั้งการใช้ลูกเล่นเคอร์เนล (核技巧, kernel trick)

เกี่ยวกับลูกเล่นเคอร์เนลอ่านได้ใน https://phyblas.hinaboshi.com/20180724



อ้างอิง


-----------------------------------------

囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧囧

ดูสถิติของหน้านี้

หมวดหมู่

-- คอมพิวเตอร์ >> ปัญญาประดิษฐ์
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> numpy
-- คอมพิวเตอร์ >> เขียนโปรแกรม >> python >> matplotlib

ไม่อนุญาตให้นำเนื้อหาของบทความไปลงที่อื่นโดยไม่ได้ขออนุญาตโดยเด็ดขาด หากต้องการนำบางส่วนไปลงสามารถทำได้โดยต้องไม่ใช่การก๊อปแปะแต่ให้เปลี่ยนคำพูดเป็นของตัวเอง หรือไม่ก็เขียนในลักษณะการยกข้อความอ้างอิง และไม่ว่ากรณีไหนก็ตาม ต้องให้เครดิตพร้อมใส่ลิงก์ของทุกบทความที่มีการใช้เนื้อหาเสมอ

สารบัญ

รวมคำแปลวลีเด็ดจากญี่ปุ่น
มอดูลต่างๆ
-- numpy
-- matplotlib

-- pandas
-- manim
-- opencv
-- pyqt
-- pytorch
การเรียนรู้ของเครื่อง
-- โครงข่าย
     ประสาทเทียม
ภาษา javascript
ภาษา mongol
ภาษาศาสตร์
maya
ความน่าจะเป็น
บันทึกในญี่ปุ่น
บันทึกในจีน
-- บันทึกในปักกิ่ง
-- บันทึกในฮ่องกง
-- บันทึกในมาเก๊า
บันทึกในไต้หวัน
บันทึกในยุโรปเหนือ
บันทึกในประเทศอื่นๆ
qiita
บทความอื่นๆ

บทความแบ่งตามหมวด



ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ

  ค้นหาบทความ

  บทความแนะนำ

รวมร้านราเมงในเมืองฟุกุโอกะ
ตัวอักษรกรีกและเปรียบเทียบการใช้งานในภาษากรีกโบราณและกรีกสมัยใหม่
ที่มาของอักษรไทยและความเกี่ยวพันกับอักษรอื่นๆในตระกูลอักษรพราหมี
การสร้างแบบจำลองสามมิติเป็นไฟล์ .obj วิธีการอย่างง่ายที่ไม่ว่าใครก็ลองทำได้ทันที
รวมรายชื่อนักร้องเพลงกวางตุ้ง
ภาษาจีนแบ่งเป็นสำเนียงอะไรบ้าง มีความแตกต่างกันมากแค่ไหน
ทำความเข้าใจระบอบประชาธิปไตยจากประวัติศาสตร์ความเป็นมา
เรียนรู้วิธีการใช้ regular expression (regex)
การใช้ unix shell เบื้องต้น ใน linux และ mac
g ในภาษาญี่ปุ่นออกเสียง "ก" หรือ "ง" กันแน่
ทำความรู้จักกับปัญญาประดิษฐ์และการเรียนรู้ของเครื่อง
ค้นพบระบบดาวเคราะห์ ๘ ดวง เบื้องหลังความสำเร็จคือปัญญาประดิษฐ์ (AI)
หอดูดาวโบราณปักกิ่ง ตอนที่ ๑: แท่นสังเกตการณ์และสวนดอกไม้
พิพิธภัณฑ์สถาปัตยกรรมโบราณปักกิ่ง
เที่ยวเมืองตานตง ล่องเรือในน่านน้ำเกาหลีเหนือ
ตระเวนเที่ยวตามรอยฉากของอนิเมะในญี่ปุ่น
เที่ยวชมหอดูดาวที่ฐานสังเกตการณ์ซิงหลง
ทำไมจึงไม่ควรเขียนวรรณยุกต์เวลาทับศัพท์ภาษาต่างประเทศ

บทความแต่ละเดือน

2025年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2024年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2023年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2022年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

2021年

1月 2月 3月 4月
5月 6月 7月 8月
9月 10月 11月 12月

ค้นบทความเก่ากว่านั้น

ไทย

日本語

中文