
smp_unet_resnet34_binary_segmentation_quickstart.py

python

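This quickstart demonstrates how to initialize a U-Net model with an ImageNet-pre-trained ResNet-34 encoder for binary segmentation and run a forward pass on a dummy input.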

```python
import torch
import segmentation_models_pytorch as smp

# 1. Create the model
model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use ImageNet pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes; 1 for binary segmentation)
)

# 2. Switch the model to evaluation (inference) mode
# (call model.train() instead before training)
model.eval()

# 3. Create dummy input data
# The shape should be [batch_size, channels, height, width]
# For Unet, height and width must be divisible by 32 (with the default 5-stage encoder)
dummy_input = torch.randn(1, 3, 256, 256)

# 4. Perform a forward pass without tracking gradients
with torch.no_grad():
    logits = model(dummy_input)  # raw, unnormalized scores (no activation applied)

print(f"Input shape: {dummy_input.shape}")     # torch.Size([1, 3, 256, 256])
print(f"Output logits shape: {logits.shape}")  # torch.Size([1, 1, 256, 256])
```
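The requirement that height and width be divisible by 32 follows from the encoder halving the spatial resolution five times (2^5 = 32), assuming the default `encoder_depth=5`. Real images rarely come in such sizes, so one common approach is to pad the input up to the next multiple of 32 and crop the prediction back afterwards. This is a minimal plain-PyTorch sketch; the helper names `pad_to_multiple` and `crop_to` are hypothetical, and zero padding on the bottom and right edges is just one choice (reflection padding is another).

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 32):
    """Zero-pad the bottom/right of an NCHW tensor so H and W are multiples of `multiple`.

    Returns the padded tensor and the original (H, W) so the output can be cropped back.
    """
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions (W, then H)
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)


def crop_to(y: torch.Tensor, size):
    """Crop an NCHW tensor back to the original (H, W) from its top-left corner."""
    h, w = size
    return y[..., :h, :w]


x = torch.randn(1, 3, 250, 300)            # hypothetical odd-sized image batch
x_padded, orig_size = pad_to_multiple(x)
print(x_padded.shape)                      # torch.Size([1, 3, 256, 320])

# Stand-in for model output on the padded input ([N, classes, H, W])
y = torch.zeros(1, 1, 256, 320)
print(crop_to(y, orig_size).shape)         # torch.Size([1, 1, 250, 300])
```

With the quickstart's model, the padded tensor would be passed to `model(...)` and the result cropped with `crop_to` so the mask lines up with the original image.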
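The forward pass returns raw logits rather than a mask: `smp.Unet` applies no activation unless one is passed through its `activation` argument, which defaults to `None`. For a single-channel binary model, a sigmoid followed by a threshold turns those logits into a 0/1 mask. A minimal sketch, assuming a 0.5 probability threshold; `logits_to_mask` is a hypothetical helper, and the hand-written logits stand in for real model output.

```python
import torch


def logits_to_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Convert [N, 1, H, W] logits into a binary {0, 1} mask of the same shape."""
    probs = torch.sigmoid(logits)               # map logits to probabilities in (0, 1)
    return (probs > threshold).to(torch.uint8)  # 1 where probability exceeds the threshold


# Hand-written logits standing in for model output, shape [1, 1, 2, 2]
logits = torch.tensor([[[[-2.0, -0.5], [0.5, 3.0]]]])
mask = logits_to_mask(logits)
print(mask.squeeze().tolist())  # [[0, 0], [1, 1]]
```

In the quickstart, the helper would be applied to the tensor returned by `model(dummy_input)`. The threshold is a tunable choice; 0.5 is only a common default.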
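The snippet itself only covers inference. For training, the model is switched to training mode with `model.train()` and optimized against a loss that accepts logits. The sketch below shows one optimization step with `torch.nn.BCEWithLogitsLoss` and Adam. To keep it self-contained, it uses a single `nn.Conv2d` as a hypothetical stand-in for the U-Net; the random batch, the learning rate, and the `train_step` helper are illustrative assumptions. smp also ships segmentation losses such as `smp.losses.DiceLoss(mode="binary")`, which could replace or be combined with BCE.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the smp.Unet above (3 input channels, 1 logit channel out);
# in practice, use the `model` created in the quickstart instead.
net = nn.Conv2d(3, 1, kernel_size=3, padding=1)
criterion = nn.BCEWithLogitsLoss()  # takes raw logits and applies the sigmoid internally
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

images = torch.randn(2, 3, 64, 64)                     # hypothetical image batch
targets = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary masks as float {0, 1}


def train_step(images: torch.Tensor, targets: torch.Tensor) -> float:
    """Run one forward/backward/update step and return the scalar loss."""
    net.train()            # training-mode behaviour (e.g. BatchNorm statistics, dropout)
    optimizer.zero_grad()  # clear gradients from any previous step
    logits = net(images)
    loss = criterion(logits, targets)
    loss.backward()        # compute gradients
    optimizer.step()       # update parameters
    return loss.item()


loss_value = train_step(images, targets)
print(f"loss: {loss_value:.4f}")
```

Because `BCEWithLogitsLoss` applies the sigmoid itself, the model's raw logits are passed to it directly; no activation is added to the network.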