最近想要计算模型的参数量和运算量(FLOPs),浏览了一些帖子,发现可以通过Python中的thop工具包来实现,但查到的资料中仅介绍了对一些简单模型(例如Resnet50)的计算,而没有考虑复杂模型的情况。研究了一番后找到了一个通过thop对复杂模型进行计算的方法,虽然不知道是否准确,但先进行一下记录。
在正式介绍thop之前,先讲一下torch的内置方法。如果仅仅想要知道模型中参数的数量,那么无需安装thop,通过调用torch内置的一些方法就可以实现,代码如下:
total = sum([param.nelement() for param in model.parameters()])
print("parameter:%fM" % (total/1e6)) 其中model是待计算模型的对象。
(1)安装:通过pip安装即可,命令如下
pip install thop (2)简单模型的参数量与FLOPs计算
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print("FLOPs=", str(flops/1e9) + '{}'.format("G"))
print("params=", str(params/1e6) + '{}'.format("M")) 上面的代码摘自官方文档。在调用thop工具包时,需要将模型的输入作为参数传入。在上面的代码中,第4行就是随机生成了一个模型的输入,(1,3,224,224)是一个四维矩阵,其含义是batch size = 1,通道数为3,分辨率为224*224。
从上面的介绍可以发现,thop只能计算输入为一个矩阵的模型的参数量和FLOPs,而我们实际想要计算的模型可能要复杂的多。例如我想要计算的一个模型需要接收五个矩阵作为输入,调用模型的代码为:
out = model(detections,boxs,grids,masks,captions) 其中作为模型传入参数的是五个矩阵,其尺寸为:
detections.shape = (bs,50,2048)
boxs.shape = (bs,50,4)
grids.shape = (bs,49,2048)
masks.shape = (bs,50,49)
captions.shape = (bs,19) 想要计算这样一个模型的参数量和计算量,需要考虑一种新的方式。我在这里给出的方案是:构建一个新的,接收单矩阵输入的类。代码如下:
import torch.nn as nn
class func(nn.Module):
def __init__(self,object):
super(func,self).__init__()
self.model = object
def forward(self,a):
detections = torch.randn(1,50,2048).float()
boxs = torch.randn(1,50,4).float()
grids = torch.randn(1,49,2048).float()
masks = torch.randn(1,50,49).float()
captions = torch.tensor([0,100,105,2,1,1,1,1,
1,1,1,1,1,1,1,1,1,1,1]).unsqueeze(0)
print(captions.shape)
out = self.model(detections,boxs,grids,masks,captions)
return out 可以看到,该类继承自nn.Module,以确保该类可以被声明为一个模型。接着在类的初始化函数中,将待计算模型的对象作为参数传入,并赋值给self.model;在forward函数中,规定其必须接收一个参数(实际上我们并不会使用这个接收到的参数),并通过torch内置的随机函数产生需要形状的张量(之所以特殊对待captions是因为model要求该参数的元素为整型)。
最后,就可以进行参数量和FLOPs的计算,代码如下:
model = Transformer(text_field.vocab.stoi['<bos>'],
encoder, decoder, args=args)
use_model = func(model)
input = torch.randn(1, 3, 224, 224)
flops, params = profile(use_model, inputs=(input))
print("FLOPs=", str(flops/1e9) + '{}'.format("G"))
print("params=", str(params/1e6) + '{}'.format("M")) 其中,第一行声明Transformer类,并将其对象命名为model;
第二行声明上面定义的func类,将model作为声明类时的传入参数;
第三行随机生成一个矩阵
第四行调用thop中的profile方法对func类的参数量和运算量进行计算,实际上等价于对Transformer类的参数量和运算量计算。
(完)