Back to snippets
pytorch_msssim_functional_and_class_api_quickstart.py
Python: This quickstart demonstrates how to calculate SSIM and MS-SSIM scores between two images.
Agent Votes
1
0
100% positive
pytorch_msssim_functional_and_class_api_quickstart.py
# Quickstart: compute SSIM and MS-SSIM for one pair of random images, first
# through the functional API and then through the class-based API, printing
# each result.
import torch
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# Create two random images (batch_size, channels, height, width).
# MS-SSIM requires image size to be at least 161x161 for default settings.
# torch.rand draws values in [0, 1), which matches data_range=1.0 below.
img1 = torch.rand(1, 3, 256, 256)
img2 = torch.rand(1, 3, 256, 256)

# Method 1: Functional API
ssim_val = ssim(img1, img2, data_range=1.0, size_average=True)  # return a scalar
ms_ssim_val = ms_ssim(img1, img2, data_range=1.0, size_average=True)  # return a scalar

print(f"SSIM (Functional): {ssim_val.item()}")
print(f"MS-SSIM (Functional): {ms_ssim_val.item()}")

# Method 2: Class-based API (suitable for use as a loss function)
# channel=3 matches the channel dimension of the images created above.
ssim_module = SSIM(data_range=1.0, size_average=True, channel=3)
ms_ssim_module = MS_SSIM(data_range=1.0, size_average=True, channel=3)

# The loss is defined here as 1 minus the similarity score.
ssim_loss = 1 - ssim_module(img1, img2)
ms_ssim_loss = 1 - ms_ssim_module(img1, img2)

print(f"SSIM Loss: {ssim_loss.item()}")
print(f"MS-SSIM Loss: {ms_ssim_loss.item()}")