181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
import torch.nn as nn
|
|
import torch
|
|
import numpy as np
|
|
import math
|
|
|
|
|
|
def compute_flops(module, inp, out):
    """Route ``module`` to the FLOPs estimator that handles its layer type.

    Layers are recognised either with ``isinstance`` against ``torch.nn``
    classes or, for ``ConvFunction``, ``SplitKernelConvFunction`` and
    ``MatMul``, by comparing the class name.

    Args:
        module: Layer whose cost is being estimated.
        inp: Indexable of inputs; ``inp[0]`` is forwarded to the estimator
            (the ``MatMul`` estimator receives ``inp`` unchanged).
        out: Indexable of outputs; ``out[0]`` is forwarded to the estimator
            (the ``MatMul`` estimator receives ``out`` unchanged).

    Returns:
        Whatever the selected estimator returns, or ``0`` when the layer
        type is not supported.
    """
    kind = type(module).__name__

    # Name-matched conv wrappers share the plain Conv2d formula.
    if isinstance(module, nn.Conv2d) or kind in ('ConvFunction',
                                                 'SplitKernelConvFunction'):
        return compute_Conv2d_flops(module, inp[0], out[0])
    if isinstance(module, nn.ConvTranspose2d):
        return compute_ConvTranspose2d_flops(module, inp[0], out[0])
    if isinstance(module, nn.BatchNorm2d):
        return compute_BatchNorm2d_flops(module, inp[0], out[0])
    if isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
        return compute_Pool2d_flops(module, inp[0], out[0])
    if isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)):
        return compute_adaptivepool_flops(module, inp[0], out[0])

    activations = (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)
    if isinstance(module, activations):
        return compute_ReLU_flops(module, inp[0], out[0])
    if isinstance(module, nn.Upsample):
        return compute_Upsample_flops(module, inp[0], out[0])
    if isinstance(module, nn.Linear):
        return compute_Linear_flops(module, inp[0], out[0])
    if kind == 'MatMul':
        # MatMul takes two operands, so the whole input collection is passed.
        return compute_matmul_flops(module, inp, out)

    # Unsupported layers contribute nothing to the total.
    return 0
|
|
|
|
|
|
def compute_matmul_flops(moudle, inp, out):
    """Estimate the FLOPs of a (batched) matrix product ``x @ y``.

    Each output element costs ``m`` multiplies and ``m`` adds, hence the
    factor of two. Any number of leading batch dimensions is accepted, so
    both ``(B, L, M) @ (B, M, N)`` and per-head attention shapes such as
    ``(B, H, L, M) @ (B, H, M, N)`` are handled.

    Args:
        moudle: Unused. The misspelled name is kept so existing keyword
            callers keep working.
        inp: Pair ``(x, y)`` of operands exposing ``.size()``.
        out: Unused.

    Returns:
        ``prod(batch dims of x) * 2 * L * M * N`` as a Python ``int``.

    Raises:
        ValueError: If either operand has fewer than two dimensions.
    """
    x, y = inp
    x_shape = tuple(x.size())
    y_shape = tuple(y.size())
    if len(x_shape) < 2 or len(y_shape) < 2:
        raise ValueError(
            'matmul operands must have at least 2 dimensions, got '
            '{} and {}'.format(x_shape, y_shape))

    rows, inner = x_shape[-2:]
    cols = y_shape[-1]

    # All leading dimensions of x index independent matrix products.
    batch = 1
    for extent in x_shape[:-2]:
        batch *= extent

    return batch * 2 * rows * inner * cols
|
|
|
|
|
|
def compute_Conv2d_flops(module, inp, out):
    """Estimate the cost of a 2-D convolution from its input/output shapes.

    One operation is counted per kernel weight applied at every output
    position, plus one per output element when the layer has a bias.
    ``module`` is not type-checked, so any object exposing
    ``kernel_size``, ``groups`` and ``bias`` is accepted.

    Args:
        module: Convolution-like layer.
        inp: 4-D input tensor; dimensions 0 and 1 are read as batch and
            input channels.
        out: 4-D output tensor; dimensions 1-3 are read as output
            channels, height and width.

    Returns:
        Total operation count as an ``int``.
    """
    in_shape = inp.size()
    out_shape = out.size()
    assert len(in_shape) == 4 and len(in_shape) == len(out_shape)

    batch, in_channels = in_shape[0], in_shape[1]
    out_channels, out_h, out_w = out_shape[1:]
    kernel_h, kernel_w = module.kernel_size

    # Number of output positions the kernel is evaluated at.
    positions = batch * out_h * out_w
    per_position = (kernel_h * kernel_w * in_channels *
                    (out_channels // module.groups))

    flops = per_position * positions
    if module.bias is not None:
        # One addition per output element.
        flops += out_channels * positions
    return flops
|
|
|
|
|
|
def compute_ConvTranspose2d_flops(module, inp, out):
    """Estimate the cost of a 2-D transposed convolution.

    The kernel cost is charged once per *input* position; the optional
    bias is charged once per *output* element.

    Args:
        module: ``nn.ConvTranspose2d`` instance.
        inp: 4-D input tensor; dimension 0 is read as batch and
            dimensions 2-3 as height and width.
        out: 4-D output tensor; dimensions 2-3 are read as height and
            width.

    Returns:
        Total operation count as an ``int``.
    """
    assert isinstance(module, nn.ConvTranspose2d)
    in_shape = inp.size()
    out_shape = out.size()
    assert len(in_shape) == 4 and len(in_shape) == len(out_shape)

    batch = in_shape[0]
    in_h, in_w = in_shape[2:]
    kernel_h, kernel_w = module.kernel_size
    out_channels = module.out_channels

    per_position = (kernel_h * kernel_w * module.in_channels *
                    (out_channels // module.groups))
    flops = per_position * (batch * in_h * in_w)

    if module.bias is not None:
        out_h, out_w = out_shape[2:]
        flops += out_channels * batch * out_h * out_w

    return flops
|
|
|
|
|
|
def compute_adaptivepool_flops(module, input, output):
    """Estimate the cost of adaptive average/max pooling.

    Output cell ``k`` along an axis covers the half-open input range
    ``[floor(k * in / out), ceil((k + 1) * in / out))``. One operation is
    counted per input element inside each window.

    The earlier version counted ``end - start + 1`` elements per axis,
    which over-counts every window by one row and one column (a 7x7
    global pool was charged 8 * 8 instead of 7 * 7).

    Credits: https://github.com/xternalz/SDPoint/blob/master/utils/flops.py

    Args:
        module: Unused.
        input: 4-D input tensor; dimensions are read as batch, planes,
            height, width.
        output: 4-D output tensor; only dimensions 2 and 3 are read.

    Returns:
        Total operation count as an ``int``.
    """
    batch_size = input.size(0)
    input_planes = input.size(1)

    def total_window_extent(in_len, out_len):
        # Sum of window lengths along one axis, using exact integer
        # floor/ceil so large sizes are not subject to float rounding.
        total = 0
        for k in range(out_len):
            start = (k * in_len) // out_len
            end = -((-(k + 1) * in_len) // out_len)
            total += end - start
        return total

    # Window areas are separable: sum_ij(h_i * w_j) == sum(h_i) * sum(w_j).
    rows = total_window_extent(input.size(2), output.size(2))
    cols = total_window_extent(input.size(3), output.size(3))
    return batch_size * input_planes * rows * cols
|
|
|
|
|
|
def compute_BatchNorm2d_flops(module, inp, out):
    """Estimate the cost of 2-D batch normalisation.

    One operation is counted per input element, doubled when the layer
    has learnable affine parameters.

    The count is accumulated with Python integers, so the result is a
    plain ``int`` like the other estimators in this file rather than a
    fixed-width NumPy scalar.

    Args:
        module: ``nn.BatchNorm2d`` instance.
        inp: 4-D input tensor.
        out: 4-D output tensor (only its rank is checked).

    Returns:
        Operation count as an ``int``.
    """
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    element_count = 1
    for extent in inp.shape:
        element_count *= int(extent)

    if module.affine:
        # Scale and shift double the per-element count.
        element_count *= 2
    return element_count
|
|
|
|
|
|
def compute_ReLU_flops(module, inp, out):
    """Estimate the cost of an element-wise activation.

    One operation is counted per input element.

    Args:
        module: One of ``nn.ReLU``, ``nn.ReLU6``, ``nn.PReLU``,
            ``nn.ELU`` or ``nn.LeakyReLU``.
        inp: Input tensor with at least one dimension.
        out: Unused.

    Returns:
        Number of elements in ``inp`` as an ``int``.
    """
    supported = (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)
    assert isinstance(module, supported)

    shape = inp.size()
    # Start from the batch extent and fold in every remaining dimension.
    element_count = shape[0]
    for extent in shape[1:]:
        element_count *= extent
    return element_count
|
|
|
|
|
|
def compute_Pool2d_flops(module, input, out):
    """Estimate the cost of ``nn.AvgPool2d`` / ``nn.MaxPool2d``.

    The output size is recomputed from the layer's hyper-parameters and
    one operation is counted per kernel element per output position.

    PyTorch stores 2-D pooling hyper-parameters as ``(height, width)``
    pairs. The earlier version paired the width with index 0 and the
    height with index 1, and used ``stride[0]`` for both axes, so
    non-square kernels, strides or paddings gave wrong results. Square
    configurations are unaffected.

    ``ceil_mode`` and ``dilation`` are not taken into account.

    Args:
        module: Pooling layer exposing ``kernel_size``, ``stride`` and
            ``padding`` as either a scalar or an ``(h, w)`` pair.
        input: 4-D input tensor; dimensions are read as batch, planes,
            height, width.
        out: Unused.

    Returns:
        Operation count as an ``int``.
    """

    def as_pair(value):
        # Scalars apply to both spatial axes.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    batch_size = input.size(0)
    input_planes = input.size(1)
    input_height = input.size(2)
    input_width = input.size(3)

    kernel_h, kernel_w = as_pair(module.kernel_size)
    stride_h, stride_w = as_pair(module.stride)
    pad_h, pad_w = as_pair(module.padding)

    output_height = math.floor(
        (input_height + 2 * pad_h - kernel_h) / float(stride_h) + 1)
    output_width = math.floor(
        (input_width + 2 * pad_w - kernel_w) / float(stride_w) + 1)

    kernel_ops = kernel_h * kernel_w
    return (batch_size * input_planes * output_width * output_height *
            kernel_ops)
|
|
|
|
|
|
def compute_Linear_flops(module, inp, out):
    """Estimate the cost of a fully connected layer.

    One operation is counted per weight applied to every input row; the
    bias is not counted. ``nn.Linear`` acts on the last dimension only,
    so all leading dimensions of ``inp`` are treated as independent rows.
    This accepts token-shaped inputs such as ``(batch, seq, features)``
    as well as the plain ``(batch, features)`` case, whose result is
    unchanged.

    Args:
        module: ``nn.Linear`` instance.
        inp: Input tensor whose last dimension is the input feature size.
        out: Output tensor whose last dimension is the output feature
            size.

    Returns:
        ``prod(leading dims of inp) * in_features * out_features`` as an
        ``int``.
    """
    assert isinstance(module, nn.Linear)
    in_shape = tuple(inp.size())
    out_shape = tuple(out.size())
    assert len(in_shape) >= 1 and len(out_shape) >= 1

    rows = 1
    for extent in in_shape[:-1]:
        rows *= extent

    return rows * in_shape[-1] * out_shape[-1]
|
|
|
|
|
|
def compute_Upsample_flops(module, inp, out):
    """Estimate the cost of ``nn.Upsample`` as one operation per output element.

    The earlier version always indexed ``out[0]``. When ``out`` is the
    output tensor itself, as for the other estimators in this file, that
    strips the batch axis and the later ``shape[1:]`` then drops the
    channel axis, giving ``batch * H * W``. ``out`` is now unwrapped only
    when it is a tuple or list, so both call styles yield
    ``batch * C * H * W``.

    Args:
        module: ``nn.Upsample`` instance.
        inp: Input tensor; dimension 0 is read as the batch size.
        out: Output tensor, or a tuple/list whose first item is the
            output tensor.

    Returns:
        Number of output elements as an ``int``.
    """
    assert isinstance(module, nn.Upsample)
    result = out[0] if isinstance(out, (tuple, list)) else out

    output_elements_count = inp.size()[0]
    for extent in result.shape[1:]:
        output_elements_count *= extent

    return output_elements_count
|