-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflops_counter.py
88 lines (73 loc) · 2.37 KB
/
flops_counter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from ptflops import get_model_complexity_info
from calflops import calculate_flops
from deepspeed.profiling.flops_profiler import get_model_profile
from src.har_project.models.models_factory import get_model
def calculate_flops_cal(model, input_shape):
    """Get model complexity using calflops.

    Args:
        model (torch.nn.Module): Model to evaluate.
        input_shape (tuple): Input shape, forwarded unchanged to calflops'
            ``input_shape`` argument.

    Returns:
        tuple: ``(flops, macs, params)`` as produced by calflops with
        ``output_as_string=True`` and ``output_precision=4``.
    """
    # Fixed calflops reporting options: formatted strings, no per-layer dump.
    options = dict(
        output_as_string=True,
        print_detailed=False,
        output_precision=4,
    )
    flops, macs, params = calculate_flops(
        model=model, input_shape=input_shape, **options
    )
    return flops, macs, params
def get_model_complexity_ptflops(model, input_shape) -> tuple:
    """Get model complexity (MACs and parameter count) using ptflops.

    ptflops builds its own input tensor of shape ``(1, *input_res)``. It
    always prepends a batch dimension of one. ``input_shape`` follows the same
    convention as the calflops and DeepSpeed helpers in this module (leading
    batch dimension included), so that batch dimension is dropped before
    calling ptflops. The reported numbers are therefore for a single sample.

    Args:
        model (torch.nn.Module): Model to evaluate.
        input_shape (tuple): Full input shape including the leading batch
            dimension, e.g. ``(B, C, T, H, W)``.

    Returns:
        tuple: ``(macs, params)`` as strings (``as_strings=True``). ptflops
        may return ``(None, None)`` instead of raising if its internal
        forward pass fails.
    """
    import contextlib

    # ptflops prepends its own batch dimension of 1. Passing the full shape
    # would feed the model a tensor with one extra leading dimension, so pass
    # only the per-sample shape.
    per_sample_shape = tuple(input_shape[1:])

    # Select GPU 0 only when CUDA is usable. torch.cuda.device(0) raises on
    # CPU-only installs, which made this helper unusable there.
    device_ctx = (
        torch.cuda.device(0)
        if torch.cuda.is_available()
        else contextlib.nullcontext()
    )
    with device_ctx:
        macs, params = get_model_complexity_info(
            model,
            per_sample_shape,
            as_strings=True,
            print_per_layer_stat=False,
            verbose=False,
        )
    return macs, params
def get_model_complexity_deepspeed(model, input_shape) -> tuple[str, str, str]:
    """Get model complexity using the DeepSpeed flops profiler.

    Args:
        model (torch.nn.Module): Model to evaluate.
        input_shape (tuple): Input shape, forwarded unchanged to DeepSpeed's
            ``input_shape`` argument (the model is fed a single tensor of
            this shape).

    Returns:
        tuple[str, str, str]: ``(flops, macs, params)`` as formatted strings
        (``as_string=True``).
    """
    profiler_kwargs = {
        "print_profile": True,  # print DeepSpeed's profile summary
        "detailed": False,  # summary only, no per-module breakdown
        "module_depth": -1,
        "top_modules": 1,
        "warm_up": 10,  # warm-up forward passes before measuring
        "as_string": True,
    }
    return get_model_profile(
        model=model, input_shape=input_shape, **profiler_kwargs
    )
if __name__ == "__main__":
    # Settings
    model_name = "cnn-rnn"  # "MoViNetA0", "x3d_xs", "cnn-rnn"
    nr_classes = 400
    # ptflops always profiles a batch of one sample, so keep nr_videos = 1
    # for the three tools' numbers to be directly comparable.
    nr_videos = 1
    frames = 50
    res = 172

    # Build the model in eval mode and describe the input batch.
    model = get_model(model_name, nr_classes)
    model.eval()
    input_shape = (nr_videos, 3, frames, res, res)  # (B, C, T, H, W)

    # ptflops: MACs and parameter count.
    pt_macs, pt_params = get_model_complexity_ptflops(model, input_shape)
    # calflops: FLOPs, MACs and parameter count.
    cal_flops, cal_macs, cal_params = calculate_flops_cal(model, input_shape)
    # DeepSpeed profiler: FLOPs, MACs and parameter count.
    ds_flops, ds_macs, ds_params = get_model_complexity_deepspeed(
        model, input_shape
    )

    # One line per tool; each line must report that tool's own results.
    print(f"PTFlops | MACs:{pt_macs} Params:{pt_params} \n")
    print(f"calflops | FLOPS:{cal_flops} MACs:{cal_macs} Params:{cal_params} \n")
    print(f"DeepSpeed| FLOPS:{ds_flops} MACs:{ds_macs} Params:{ds_params} \n")