from torchvision import models as md
def nap_param(khrongkhai):
    """Count a module's parameter tensors and their total number of elements.

    Args:
        khrongkhai: a ``torch.nn.Module`` (anything whose ``parameters()``
            yields tensors).

    Returns:
        A tuple ``(n_tensors, n_elements)``: how many parameter tensors the
        module has, and how many scalar parameters they hold in total.
    """
    param = list(khrongkhai.parameters())
    # numel() works for tensors on any device (CPU, CUDA, meta); going through
    # p.data.numpy() only worked for CPU tensors.
    return len(param), sum(p.numel() for p in param)
import inspect

print(nap_param(md.resnet18()))

# Every torchvision model is built by a plain function, so walk the names in
# torchvision.models, call each function with no arguments, and report the
# parameter counts of the ones that successfully build a model.
for d in dir(md):
    f = getattr(md, d)
    if not inspect.isfunction(f):
        continue
    try:
        jamnuan = nap_param(f())
    except Exception:
        # Some functions in the module are not zero-argument model builders
        # (they need arguments or return non-modules); skip them. Catching
        # Exception instead of a bare except keeps Ctrl+C working.
        continue
    print('%s: %s' % (d, jamnuan))
alexnet: (16, 61100840)
convnext_base: (344, 88591464)
convnext_large: (344, 197767336)
convnext_small: (344, 50223688)
convnext_tiny: (182, 28589128)
densenet121: (364, 7978856)
densenet161: (484, 28681000)
densenet169: (508, 14149480)
densenet201: (604, 20013928)
efficientnet_b0: (213, 5288548)
efficientnet_b1: (301, 7794184)
efficientnet_b2: (301, 9109994)
efficientnet_b3: (340, 12233232)
efficientnet_b4: (418, 19341616)
efficientnet_b5: (506, 30389784)
efficientnet_b6: (584, 43040704)
efficientnet_b7: (711, 66347960)
efficientnet_v2_l: (897, 118515272)
efficientnet_v2_m: (649, 54139356)
efficientnet_v2_s: (452, 21458488)
googlenet: (187, 13004888)
inception_v3: (292, 27161264)
mnasnet0_5: (158, 2218512)
mnasnet0_75: (158, 3170208)
mnasnet1_0: (158, 4383312)
mnasnet1_3: (158, 6282256)
mobilenet_v2: (158, 3504872)
mobilenet_v3_large: (174, 5483032)
mobilenet_v3_small: (142, 2542856)
regnet_x_16gf: (215, 54278536)
regnet_x_1_6gf: (179, 9190136)
regnet_x_32gf: (224, 107811560)
regnet_x_3_2gf: (242, 15296552)
regnet_x_400mf: (215, 5495976)
regnet_x_800mf: (161, 7259656)
regnet_x_8gf: (224, 39572648)
regnet_y_128gf: (368, 644812894)
regnet_y_16gf: (251, 83590140)
regnet_y_1_6gf: (368, 11202430)
regnet_y_32gf: (277, 145046770)
regnet_y_3_2gf: (290, 19436338)
regnet_y_400mf: (225, 4344144)
regnet_y_800mf: (199, 6432512)
regnet_y_8gf: (238, 39381472)
resnet101: (314, 44549160)
resnet152: (467, 60192808)
resnet18: (62, 11689512)
resnet34: (110, 21797672)
resnet50: (161, 25557032)
resnext101_32x8d: (314, 88791336)
resnext101_64x4d: (314, 83455272)
resnext50_32x4d: (161, 25028904)
shufflenet_v2_x0_5: (170, 1366792)
shufflenet_v2_x1_0: (170, 2278604)
shufflenet_v2_x1_5: (170, 3503624)
shufflenet_v2_x2_0: (170, 7393996)
squeezenet1_0: (52, 1248424)
squeezenet1_1: (52, 1235496)
swin_b: (329, 87768224)
swin_s: (329, 49606258)
swin_t: (173, 28288354)
vgg11: (22, 132863336)
vgg11_bn: (38, 132868840)
vgg13: (26, 133047848)
vgg13_bn: (46, 133053736)
vgg16: (32, 138357544)
vgg16_bn: (58, 138365992)
vgg19: (38, 143667240)
vgg19_bn: (70, 143678248)
vit_b_16: (152, 86567656)
vit_b_32: (152, 88224232)
vit_h_14: (392, 632045800)
vit_l_16: (296, 304326632)
vit_l_32: (296, 306535400)
wide_resnet101_2: (314, 126886696)
wide_resnet50_2: (161, 68883240)
from torchvision.models import resnet18
# Show the module tree of ResNet-18; nn.Module's repr lists each submodule.
print(repr(resnet18()))
...
...(ด้านบนยาวมากขอละไว้)
...
(avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
import torch
from torchvision.models import resnet18
import time
# Cross-entropy loss shared by the training loop below.
ha_entropy = torch.nn.CrossEntropyLoss()


class PrasatResnet18(torch.nn.Sequential):
    """ResNet-18 with a replaced classification layer, plus a training loop.

    The network is torchvision's ``resnet18`` whose ``fc`` layer is swapped for
    a fresh ``torch.nn.Linear(512, n_klum)``. It is stored as the single child
    of this ``torch.nn.Sequential``, so calling the object runs the ResNet.

    Args:
        n_klum: number of output classes.
        eta: Adam learning rate.
        aothifueklaeo: truthy to start from pretrained weights (forwarded to
            torchvision's ``resnet18`` as ``pretrained``).
        fuekthangmot: truthy to train every layer; falsy to freeze the backbone
            so only the new ``fc`` layer is trained, with the network kept in
            eval mode throughout training.
        gpu: truthy to run on CUDA; falls back to CPU when CUDA is unavailable.
    """

    def __init__(self, n_klum, eta=0.001, aothifueklaeo=1, fuekthangmot=1, gpu=1):
        rn = resnet18(pretrained=bool(aothifueklaeo))
        if not fuekthangmot:
            # Freeze before attaching the new head, so only the freshly
            # created fc layer keeps requires_grad=True.
            for p in rn.parameters():
                p.requires_grad = False
            rn.eval()
        rn.fc = torch.nn.Linear(512, n_klum)
        super().__init__(rn)
        # Plain attributes are set only after nn.Module.__init__ has run.
        self.fuekthangmot = fuekthangmot
        if gpu and torch.cuda.is_available():
            self.dev = torch.device('cuda')
        else:
            self.dev = torch.device('cpu')
        # Move the model first so the optimizer holds device-resident tensors.
        self.to(self.dev)
        self.opt = torch.optim.Adam(
            [p for p in rn.parameters() if p.requires_grad], lr=eta)

    def rianru(self, rup_fuek, rup_truat, n_thamsam=500, ro=10):
        """Train with cross-entropy and measure validation accuracy every epoch.

        Args:
            rup_fuek: iterable of ``(X, z)`` training batches.
            rup_truat: iterable of ``(X, z)`` validation batches.
            n_thamsam: maximum number of epochs.
            ro: early-stopping patience; training stops once validation
                accuracy has failed to exceed its best value for ``ro``
                consecutive epochs. ``ro <= 0`` disables early stopping.

        Per-epoch validation accuracies are stored in ``self.khanaen``.
        """
        self.khanaen = []
        khanaen_sungsut = 0
        # Bound up front: previously an epoch that did not beat the initial
        # best of 0 on the first pass raised UnboundLocalError below.
        maiphoem = 0
        t_roem = time.time()
        for o in range(n_thamsam):
            if self.fuekthangmot:
                self.train()
            for Xb, zb in rup_fuek:
                a = self(Xb.to(self.dev))
                J = ha_entropy(a, zb.to(self.dev))
                J.backward()
                self.opt.step()
                self.opt.zero_grad()
            self.eval()
            khanaen = []
            for Xb, zb in rup_truat:
                khanaen.append(self.thamnai_(Xb.to(self.dev)).cpu() == zb)
            khanaen = torch.cat(khanaen).numpy().mean()
            self.khanaen.append(khanaen)
            print('%d ครั้งผ่านไป ใช้เวลาไป %.1f นาที ทำนายแม่น %.4f' % (o + 1, (time.time() - t_roem) / 60, khanaen))
            if khanaen > khanaen_sungsut:
                khanaen_sungsut = khanaen
                maiphoem = 0
            else:
                maiphoem += 1
            if ro > 0 and maiphoem >= ro:
                break

    def thamnai_(self, X):
        """Return the argmax over dimension 1 of the model output for ``X``."""
        # Predictions never need gradients; disabling autograd avoids building
        # a graph during the per-epoch evaluation in rianru.
        with torch.no_grad():
            return self(X).argmax(1)
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader as Dalo
import torchvision.datasets as ds
import torchvision.transforms as tf
folder_cifar10 = 'pytorchdata/cifar'
# Training preprocessing: Resize(224) scales the small CIFAR-10 images up
# (shorter side to 224 px), random horizontal flips add augmentation, and the
# images are normalised with the commonly used ImageNet mean/std values.
tran = tf.Compose([
    tf.Resize(224),
    tf.RandomHorizontalFlip(),
    tf.ToTensor(),
    tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# Pass real booleans: torchvision's CIFAR10 reports its split by checking
# `self.train is True`, so train=1 would label the training set "Test".
rup_fuek = ds.CIFAR10(folder_cifar10, transform=tran, train=True, download=True)
rup_fuek = Dalo(rup_fuek, batch_size=32, shuffle=True)
# Validation preprocessing: identical, minus the random flip.
tran = tf.Compose([
    tf.Resize(224),
    tf.ToTensor(),
    tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
rup_truat = ds.CIFAR10(folder_cifar10, transform=tran, train=False)
rup_truat = Dalo(rup_truat, batch_size=32)
# Pretrained backbone, frozen; only the new 10-class head is trained.
prasat = PrasatResnet18(10, eta=0.001, aothifueklaeo=1, fuekthangmot=0, gpu=1)
prasat.rianru(rup_fuek, rup_truat, ro=10)
torch.save(prasat.state_dict(), 'resnet18_cifar_param_.pkl')
# map_location loads the saved tensors onto the device the model lives on.
prasat.load_state_dict(torch.load('resnet18_cifar_param_.pkl', map_location=prasat.dev))
plt.plot(prasat.khanaen)
plt.savefig('resnet18_cifar_.png')
plt.show()
# Alternative setup 1: start from pretrained weights and train every layer.
prasat = PrasatResnet18(n_klum=10, eta=1e-3, aothifueklaeo=1, fuekthangmot=1, gpu=1)
# Alternative setup 2: start from random weights and train every layer.
prasat = PrasatResnet18(n_klum=10, eta=1e-3, aothifueklaeo=0, fuekthangmot=1, gpu=1)
ติดตามอัปเดตของบล็อกได้ที่แฟนเพจ