from torchvision import models as md
# Function that counts a model's parameters.
def nap_param(khrongkhai):
    """Count the parameters of a torch module.

    Returns a tuple ``(number of parameter tensors, total number of scalar values)``.
    ``Tensor.numel()`` is used rather than ``p.data.numpy().size`` because ``.numpy()``
    only accepts CPU tensors, so the old form failed for a model already moved to CUDA.
    """
    param = list(khrongkhai.parameters())
    return len(param), sum(p.numel() for p in param)
# (number of parameter tensors, total number of scalars) for ResNet-18. Printed
# explicitly because a bare expression statement is discarded when run as a script.
print(nap_param(md.resnet18()))

# Every architecture in torchvision.models is exposed as a builder function, so walk
# the module namespace, build each model with its default arguments and report its size.
import inspect


def _rap_plao_dai(fn):
    """Return True if ``fn`` can be called without any arguments."""
    for p in inspect.signature(fn).parameters.values():
        if p.default is inspect.Parameter.empty and p.kind not in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            return False
    return True


for chue in dir(md):
    sang = getattr(md, chue)
    # Only public plain functions that need no arguments can be zero-arg model builders.
    # NOTE(review): newer torchvision releases also export helper functions from this
    # namespace; without this filter the loop would call them with no arguments.
    if chue.startswith('_') or not inspect.isfunction(sang) or not _rap_plao_dai(sang):
        continue
    khrongkhai = sang()
    # Skip helpers whose result is not a module (nothing to count).
    if not hasattr(khrongkhai, 'parameters'):
        continue
    print('%s: %s' % (chue, nap_param(khrongkhai)))
alexnet: (16, 61100840)
densenet121: (364, 7978856)
densenet161: (484, 28681000)
densenet169: (508, 14149480)
densenet201: (604, 20013928)
inception_v3: (292, 27161264)
resnet101: (314, 44549160)
resnet152: (467, 60192808)
resnet18: (62, 11689512)
resnet34: (110, 21797672)
resnet50: (161, 25557032)
squeezenet1_0: (52, 1248424)
squeezenet1_1: (52, 1235496)
vgg11: (22, 132863336)
vgg11_bn: (38, 132868840)
vgg13: (26, 133047848)
vgg13_bn: (46, 133053736)
vgg16: (32, 138357544)
vgg16_bn: (58, 138365992)
vgg19: (38, 143667240)
vgg19_bn: (70, 143678248)
from torchvision.models import resnet18
# Show the layer-by-layer structure of a ResNet-18 built with default arguments.
print(str(resnet18()))
...
...(ด้านบนยาวมากขอละไว้)
...
(avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
import torch
from torchvision.models import resnet18
import time
# Loss shared by all PrasatResnet18 instances; it is applied directly to the network's
# raw outputs (the replacement head below is a plain Linear layer).
ha_entropy = torch.nn.CrossEntropyLoss()


class PrasatResnet18(torch.nn.Sequential):
    """ResNet-18 classifier whose final ``fc`` layer is replaced for ``n_klum`` classes.

    Parameters
    ----------
    n_klum : int
        Number of output classes; ``fc`` becomes ``Linear(512, n_klum)``.
    eta : float
        Adam learning rate.
    aothifueklaeo : bool-like
        Forwarded to torchvision's ``resnet18`` as ``pretrained``.
    fuekthangmot : bool-like
        Truthy: every layer is trained. Falsy: the backbone's parameters are frozen,
        the backbone is put in eval mode and ``rianru`` never switches it to train mode,
        so only the new ``fc`` layer learns.
    gpu : bool-like
        Truthy: move the model to CUDA; falsy: keep it on the CPU.

    Attributes set here: ``fuekthangmot``, ``dev`` (device used for batches) and
    ``opt`` (the Adam optimiser). ``rianru`` additionally sets ``khanaen``.
    """

    def __init__(self, n_klum, eta=0.001, aothifueklaeo=1, fuekthangmot=1, gpu=1):
        # Pass ``pretrained`` by keyword: newer torchvision releases make builder
        # arguments keyword-only (a positional first argument is no longer `pretrained`),
        # while the keyword is accepted by both old and new releases.
        rn = resnet18(pretrained=bool(aothifueklaeo))
        if not fuekthangmot:
            # Freeze the backbone and keep its BatchNorm layers on running statistics.
            for p in rn.parameters():
                p.requires_grad = False
            rn.eval()
        # Created after freezing, so the new head is always trainable.
        rn.fc = torch.nn.Linear(512, n_klum)
        # Initialise nn.Module state first, then attach the plain attributes.
        super().__init__(rn)
        self.fuekthangmot = fuekthangmot
        if gpu:
            self.dev = torch.device('cuda')
            self.cuda()
        else:
            self.dev = torch.device('cpu')
        # Frozen tensors never receive gradients, so only trainable ones go to Adam.
        # The optimiser is built after .cuda() so it refers to the final parameters.
        self.opt = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad], lr=eta
        )

    def rianru(self, rup_fuek, rup_truat, n_thamsam=500, ro=10):
        """Train with early stopping on validation accuracy.

        rup_fuek, rup_truat : iterables of ``(X, z)`` batches used for training and
            validation respectively. Assumes ``z`` holds class indices on the CPU
            (it is compared with CPU predictions).
        n_thamsam : maximum number of epochs.
        ro : patience. Stop after ``ro`` consecutive epochs without a new best validation
            accuracy; ``ro <= 0`` disables early stopping.

        Per-epoch validation accuracy is appended to ``self.khanaen`` (reset on each call).
        """
        self.khanaen = []
        khanaen_sungsut = 0
        # Initialise before the loop: if the first epoch's accuracy is not above 0, the
        # ``else`` branch below would otherwise read it before assignment.
        maiphoem = 0
        t_roem = time.time()
        for o in range(n_thamsam):
            if self.fuekthangmot:
                # With a frozen backbone the model is deliberately left in eval mode.
                self.train()
            for Xb, zb in rup_fuek:
                a = self(Xb.to(self.dev))
                J = ha_entropy(a, zb.to(self.dev))
                J.backward()
                self.opt.step()
                self.opt.zero_grad()
            self.eval()
            khanaen = []
            # Validation needs no autograd graph; skipping it saves time and memory.
            with torch.no_grad():
                for Xb, zb in rup_truat:
                    khanaen.append(self.thamnai_(Xb.to(self.dev)).cpu() == zb)
            khanaen = torch.cat(khanaen).numpy().mean()
            self.khanaen.append(khanaen)
            print('%d ครั้งผ่านไป ใช้เวลาไป %.1f นาที ทำนายแม่น %.4f'%(o+1,(time.time()-t_roem)/60,khanaen))
            if khanaen > khanaen_sungsut:
                khanaen_sungsut = khanaen
                maiphoem = 0
            else:
                maiphoem += 1
            if ro > 0 and maiphoem >= ro:
                break

    def thamnai_(self, X):
        """Return the predicted class for each sample: argmax of the outputs along dim 1."""
        return self(X).argmax(1)
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader as Dalo
import torchvision.datasets as ds
import torchvision.transforms as tf
# Folder where CIFAR-10 is stored / downloaded.
folder_cifar10 = 'pytorchdata/cifar'

# Per-channel mean/std shared by both pipelines. NOTE(review): presumably the ImageNet
# statistics expected by the pretrained ResNet-18 weights.
_normalize = tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training pipeline: resize to 224, random horizontal flips for augmentation.
tran = tf.Compose([tf.Resize(224), tf.RandomHorizontalFlip(), tf.ToTensor(), _normalize])
# Real booleans: CIFAR10's repr checks `train is True`, so `train=1` was labelled "Test".
rup_fuek = Dalo(
    ds.CIFAR10(folder_cifar10, transform=tran, train=True, download=True),
    batch_size=32,
    shuffle=True,
)

# Validation pipeline: same resizing and normalisation, no random augmentation.
tran = tf.Compose([tf.Resize(224), tf.ToTensor(), _normalize])
rup_truat = Dalo(ds.CIFAR10(folder_cifar10, transform=tran, train=False), batch_size=32)
# Stage 1: pretrained weights, frozen backbone, train only the new 10-class head
# (CIFAR-10), stopping after 10 epochs without improvement.
prasat = PrasatResnet18(10, eta=0.001, aothifueklaeo=1, fuekthangmot=0, gpu=1)
prasat.rianru(rup_fuek, rup_truat, ro=10)

# Save the trained parameters and load them back. map_location puts the tensors on the
# model's own device regardless of which device the checkpoint was written from.
torch.save(prasat.state_dict(), 'resnet18_cifar_param_.pkl')
prasat.load_state_dict(torch.load('resnet18_cifar_param_.pkl', map_location=prasat.dev))

# Plot validation accuracy per epoch.
plt.plot(prasat.khanaen)
plt.savefig('resnet18_cifar_.png')
plt.show()

# Alternative configurations. As written they are only constructed, not trained.
# NOTE(review): presumably each is meant to be followed by its own prasat.rianru(...).
# The previous model is released first, so two models (plus optimiser state) are
# never held on the GPU at the same time during construction.
del prasat
# Pretrained weights, every layer trainable.
prasat = PrasatResnet18(10, eta=0.001, aothifueklaeo=1, fuekthangmot=1, gpu=1)
del prasat
# No pretrained weights, every layer trainable.
prasat = PrasatResnet18(10, eta=0.001, aothifueklaeo=0, fuekthangmot=1, gpu=1)
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ