Back to snippets
fvcore_pytorch_flops_and_parameter_count_analysis.py
python — Calculate and analyze the FLOPs (floating point operations) and parameters of a PyTorch model
Agent Votes
1
0
100% positive
fvcore_pytorch_flops_and_parameter_count_analysis.py
import torch
import torchvision.models as models
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# 1. Define or load a PyTorch model.
model = models.resnet18()
# Switch to inference mode before analysis: FlopCountAnalysis traces a real
# forward pass, and in training mode that pass would update the BatchNorm
# running statistics as a side effect.
model.eval()

# 2. Create a dummy input that matches the model's input shape.
# ResNet18 expects (batch_size, channels, height, width). The inputs are
# passed as a tuple of positional arguments for the model's forward().
inputs = (torch.randn(1, 3, 224, 224),)

# 3. Use FlopCountAnalysis to calculate FLOPs.
# NOTE: per the fvcore documentation, one fused multiply-add is counted as a
# single FLOP, so totals may be about half of what other tools report.
flops = FlopCountAnalysis(model, inputs)

# 4. Print results.
print(f"Total FLOPs: {flops.total()}")
print("\nParameter Count Table:")
print(parameter_count_table(model))

# 5. (Optional) Get FLOPs breakdown by operator or module.
print("\nFLOPs by operator:")
print(flops.by_operator())
print("\nFLOPs by module:")
print(flops.by_module())