163 lines
4.9 KiB
Python
163 lines
4.9 KiB
Python
"""
|
|
Compute the multiply-add count (MAdd) of each supported leaf module.
|
|
"""
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
def compute_Conv2d_madd(module, inp, out):
    """Return the multiply-add count of an ``nn.Conv2d`` forward pass.

    Every output element is a dot product over a ``k_h x k_w`` window of
    ``in_channels // groups`` input channels. That costs ``k_h * k_w * (in_c //
    groups)`` multiplications and one fewer addition, plus one extra addition
    when the layer has a bias. The batch dimension (dim 0) is not counted.

    Args:
        module: the ``nn.Conv2d`` layer.
        inp: 4-D input tensor.
        out: 4-D output tensor.

    Returns:
        int: multiplications plus additions for one sample.
    """
    assert isinstance(module, nn.Conv2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    groups = module.groups
    kh, kw = module.kernel_size
    channels_in = inp.size()[1]
    channels_out, height_out, width_out = out.size()[1:]

    # Cost of producing a single output element.
    muls_per_output = kh * kw * (channels_in // groups)
    bias_add = 0 if module.bias is None else 1
    adds_per_output = muls_per_output - 1 + bias_add

    # Output elements produced by each group, then scaled by the group count.
    outputs_per_group = height_out * width_out * (channels_out // groups)
    return (muls_per_output + adds_per_output) * outputs_per_group * groups
|
|
|
|
|
|
def compute_ConvTranspose2d_madd(module, inp, out):
    """Return the multiply-add count of an ``nn.ConvTranspose2d`` forward pass.

    The count is taken over the *input* spatial positions (``in_h * in_w``),
    with ``out_channels // groups`` kernels per group. Each kernel application
    costs ``k_h * k_w * (in_c // groups)`` multiplications and one fewer
    addition, plus one extra addition when the layer has a bias. The batch
    dimension (dim 0) is not counted.

    Args:
        module: the ``nn.ConvTranspose2d`` layer.
        inp: 4-D input tensor.
        out: 4-D output tensor (only its channel count is used).

    Returns:
        int: multiplications plus additions for one sample.
    """
    assert isinstance(module, nn.ConvTranspose2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    groups = module.groups
    kh, kw = module.kernel_size
    c_in, h_in, w_in = inp.size()[1:]
    c_out = out.size()[1]

    muls_per_position = kh * kw * (c_in // groups)
    adds_per_position = muls_per_position - 1 + (1 if module.bias is not None else 0)

    # Input positions times per-group output channels, summed over all groups.
    positions = h_in * w_in * (c_out // groups) * groups
    return (muls_per_position + adds_per_position) * positions
|
|
|
|
|
|
def compute_BatchNorm2d_madd(module, inp, out):
    """Return the operation count of an ``nn.BatchNorm2d`` forward pass.

    Four element-wise operations are charged per input element (subtract
    mean, divide by standard deviation, multiply by the scale, add the
    shift). The count does not consult ``module.affine``. The batch dimension
    (dim 0) is not counted.

    Args:
        module: the ``nn.BatchNorm2d`` layer.
        inp: 4-D input tensor.
        out: 4-D output tensor (only its rank is checked).

    Returns:
        int: ``4 * C * H * W``.
    """
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    channels, height, width = inp.size()[1:]
    elements_per_sample = channels * height * width

    # sub mean, div std, mul alpha, add beta
    ops_per_element = 4
    return ops_per_element * elements_per_sample
|
|
|
|
|
|
def compute_MaxPool2d_madd(module, inp, out):
    """Return the comparison count of an ``nn.MaxPool2d`` forward pass.

    Selecting the maximum of a ``k_h x k_w`` window takes ``k_h * k_w - 1``
    comparisons, charged once per output element. ``module.kernel_size`` may
    be a single int (square window) or a 2-element tuple/list. The batch
    dimension (dim 0) is not counted.

    Args:
        module: the ``nn.MaxPool2d`` layer.
        inp: 4-D input tensor (only its rank is checked).
        out: 4-D output tensor.

    Returns:
        int: comparisons for one sample.
    """
    assert isinstance(module, nn.MaxPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    kernel = module.kernel_size
    k_rows, k_cols = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
    channels, rows, cols = out.size()[1:]

    comparisons_per_window = k_rows * k_cols - 1
    return comparisons_per_window * rows * cols * channels
|
|
|
|
|
|
def compute_AvgPool2d_madd(module, inp, out):
    """Return the operation count of an ``nn.AvgPool2d`` forward pass.

    Each output element sums a ``k_h x k_w`` window (``k_h * k_w - 1``
    additions) and divides once. ``module.kernel_size`` may be a single int
    (square window) or a 2-element tuple/list. The batch dimension (dim 0) is
    not counted.

    Args:
        module: the ``nn.AvgPool2d`` layer.
        inp: 4-D input tensor (only its rank is checked).
        out: 4-D output tensor.

    Returns:
        int: additions plus divisions for one sample.
    """
    assert isinstance(module, nn.AvgPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    kernel = module.kernel_size
    k_rows, k_cols = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
    channels, rows, cols = out.size()[1:]

    window_adds = k_rows * k_cols - 1
    divisions = 1
    outputs = rows * cols * channels
    return (window_adds + divisions) * outputs
|
|
|
|
|
|
def compute_ReLU_madd(module, inp, out):
    """Return the operation count of an ``nn.ReLU`` / ``nn.ReLU6`` forward pass.

    One operation is charged per input element, excluding the batch
    dimension (dim 0). ``out`` is not used, and the input rank is not checked.

    Args:
        module: the ``nn.ReLU`` or ``nn.ReLU6`` layer.
        inp: input tensor of any rank.
        out: output tensor (unused).

    Returns:
        int: product of all input dimensions after the first.
    """
    assert isinstance(module, (nn.ReLU, nn.ReLU6))

    per_sample = 1
    remaining = list(inp.size())
    # Multiply every extent except the leading (batch) one.
    while len(remaining) > 1:
        per_sample *= remaining.pop()
    return per_sample
|
|
|
|
|
|
def compute_Softmax_madd(module, inp, out):
    """Return the operation count of an ``nn.Softmax`` forward pass.

    With ``n`` = number of input elements per sample (batch dimension
    excluded), the count charges ``n`` exponentials, ``n - 1`` additions for
    the normalizer and ``n`` divisions. It treats all non-batch elements as
    one softmax and does not look at ``module.dim``.

    Args:
        module: the ``nn.Softmax`` layer.
        inp: input tensor with at least 2 dimensions.
        out: output tensor (unused).

    Returns:
        int: ``3 * n - 1``.
    """
    assert isinstance(module, nn.Softmax)
    assert len(inp.size()) > 1

    elements = 1
    for extent in reversed(inp.size()[1:]):
        elements *= extent

    exponentials = elements
    normalizer_adds = elements - 1
    divisions = elements
    return exponentials + normalizer_adds + divisions
|
|
|
|
|
|
def compute_Linear_madd(module, inp, out):
    """Return the multiply-add count of an ``nn.Linear`` forward pass.

    Each output feature is a dot product over ``in_features`` inputs:
    ``in_features`` multiplications and ``in_features - 1`` additions, plus
    one more addition when the layer has a bias. This matches how
    ``compute_Conv2d_madd`` charges its bias.

    ``nn.Linear`` accepts inputs shaped ``(N, *, in_features)``. Every
    position in the ``*`` dimensions is counted, while the leading batch
    dimension is excluded, as in the other counters here. For a plain 2-D
    input the ``*`` part is empty and exactly one feature vector is counted.

    Args:
        module: the ``nn.Linear`` layer.
        inp: input tensor with at least 2 dimensions.
        out: output tensor with the same rank as ``inp``.

    Returns:
        int: multiplications plus additions for one sample.
    """
    assert isinstance(module, nn.Linear)
    assert len(inp.size()) >= 2 and len(out.size()) == len(inp.size())

    num_in_features = inp.size()[-1]
    num_out_features = out.size()[-1]

    # Feature vectors transformed per sample (1 for 2-D input).
    positions = 1
    for s in inp.size()[1:-1]:
        positions *= s

    mul = num_in_features
    # The bias add is charged the same way compute_Conv2d_madd does it.
    add = num_in_features - 1 + (0 if module.bias is None else 1)
    return positions * num_out_features * (mul + add)
|
|
|
|
|
|
def compute_Bilinear_madd(module, inp1, inp2, out):
    """Return the multiply-add count of an ``nn.Bilinear`` forward pass.

    Each output feature is ``y_k = x1^T A_k x2 (+ b_k)``, where ``in1`` and
    ``in2`` are the last dimensions of ``inp1`` and ``inp2``. The cost is:

    - ``t = x1^T A_k``: ``in1 * in2`` multiplications and
      ``(in1 - 1) * in2`` additions.
    - The dot product ``t . x2``: another ``in2`` multiplications and
      ``in2 - 1`` additions.

    In total that is ``in1 * in2 + in2`` multiplications and
    ``in1 * in2 - 1`` additions, plus one addition for the bias when present.
    That bias add matches ``compute_Conv2d_madd``.

    Inputs may be shaped ``(N, *, features)``. Every position in ``*`` is
    counted and the batch dimension is excluded. 2-D inputs count a single
    position.

    Args:
        module: the ``nn.Bilinear`` layer.
        inp1: first input tensor.
        inp2: second input tensor, same rank as ``inp1``.
        out: output tensor, same rank as ``inp1``.

    Returns:
        int: multiplications plus additions for one sample.
    """
    assert isinstance(module, nn.Bilinear)
    assert (len(inp1.size()) >= 2
            and len(inp2.size()) == len(inp1.size())
            and len(out.size()) == len(inp1.size()))

    num_in_features_1 = inp1.size()[-1]
    num_in_features_2 = inp2.size()[-1]
    num_out_features = out.size()[-1]

    # Feature-vector pairs transformed per sample (1 for 2-D input).
    positions = 1
    for s in out.size()[1:-1]:
        positions *= s

    mul = num_in_features_1 * num_in_features_2 + num_in_features_2
    # The previous formula added an extra in2 here; summing the
    # in1 * in2 products of the quadratic form needs in1 * in2 - 1 adds.
    add = (num_in_features_1 * num_in_features_2 - 1
           + (0 if module.bias is None else 1))
    return positions * num_out_features * (mul + add)
|
|
|
|
|
|
def compute_madd(module, inp, out):
    """Return the MAdd count of ``module`` for one forward call, or 0 if unsupported.

    Dispatches on the module type to the matching ``compute_*_madd`` helper.

    ``inp`` and ``out`` are indexed with ``[0]`` to reach the tensors. They are
    therefore expected to be sequences (presumably the positional-input tuple
    and a wrapped output — TODO confirm against the caller). ``nn.Bilinear``
    consumes the first two entries of ``inp``.
    """
    if isinstance(module, nn.Conv2d):
        return compute_Conv2d_madd(module, inp[0], out[0])
    elif isinstance(module, nn.ConvTranspose2d):
        return compute_ConvTranspose2d_madd(module, inp[0], out[0])
    elif isinstance(module, nn.BatchNorm2d):
        return compute_BatchNorm2d_madd(module, inp[0], out[0])
    elif isinstance(module, nn.MaxPool2d):
        return compute_MaxPool2d_madd(module, inp[0], out[0])
    elif isinstance(module, nn.AvgPool2d):
        return compute_AvgPool2d_madd(module, inp[0], out[0])
    elif isinstance(module, (nn.ReLU, nn.ReLU6)):
        return compute_ReLU_madd(module, inp[0], out[0])
    elif isinstance(module, nn.Softmax):
        return compute_Softmax_madd(module, inp[0], out[0])
    elif isinstance(module, nn.Linear):
        return compute_Linear_madd(module, inp[0], out[0])
    elif isinstance(module, nn.Bilinear):
        # Unwrap ``out`` like every other branch does. A bare tensor is still
        # accepted so the previous call shape keeps working.
        out_tensor = out[0] if isinstance(out, (tuple, list)) else out
        return compute_Bilinear_madd(module, inp[0], inp[1], out_tensor)
    else:
        # Unsupported module types contribute no MAdd to the total.
        return 0
|