smp_unet_resnet34_binary_segmentation_quickstart.py
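This quickstart demonstrates how to initialize a U-Net model with an ImageNet-pretrained ResNet-34 encoder using `segmentation_models_pytorch`, then run a forward pass for binary segmentation.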
```python
import torch
import segmentation_models_pytorch as smp

# 1. Create the model
model = smp.Unet(
    encoder_name="resnet34",     # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,               # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=1,                   # model output channels (number of classes; 1 for binary segmentation)
)

# 2. Put the model in evaluation mode for inference
# (call model.train() instead before a training loop)
model.eval()

# 3. Create dummy input data
# The shape should be [batch_size, channels, height, width]
# For Unet, height and width should be divisible by 32
dummy_input = torch.randn(1, 3, 256, 256)

# 4. Perform a forward pass without tracking gradients
with torch.no_grad():
    logits = model(dummy_input)  # raw scores (no activation applied by default)

# 5. Convert logits to a binary mask: sigmoid, then threshold at 0.5
mask = (logits.sigmoid() > 0.5).float()

print(f"Input shape: {dummy_input.shape}")
print(f"Output mask shape: {mask.shape}")  # torch.Size([1, 1, 256, 256])
```
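The quickstart uses a 256×256 input because U-Net's encoder (five downsampling stages by default, so 2**5 = 32) expects height and width divisible by 32. Real images often are not. A minimal sketch, assuming you pad symmetrically and crop the prediction back afterwards, is to compute the padding amounts like this. The helper name `pad_amounts` is hypothetical (not part of `segmentation_models_pytorch`); its return order `(left, right, top, bottom)` is chosen to match what `torch.nn.functional.pad` expects for the last two dimensions of an `[N, C, H, W]` tensor.

```python
from typing import Tuple


def pad_amounts(height: int, width: int, divisor: int = 32) -> Tuple[int, int, int, int]:
    """Hypothetical helper: padding that rounds height/width up to a multiple of divisor.

    Returns (left, right, top, bottom), the order torch.nn.functional.pad uses
    for the last two dimensions (width first, then height).
    """
    pad_h = (-height) % divisor  # rows needed to reach the next multiple
    pad_w = (-width) % divisor   # columns needed to reach the next multiple
    top = pad_h // 2
    bottom = pad_h - top          # any odd leftover row goes to the bottom
    left = pad_w // 2
    right = pad_w - left          # any odd leftover column goes to the right
    return left, right, top, bottom


if __name__ == "__main__":
    h, w = 250, 300
    left, right, top, bottom = pad_amounts(h, w)
    print((left, right, top, bottom))                  # (10, 10, 3, 3)
    print((h + top + bottom, w + left + right))        # (256, 320)
    # Assumed usage with the quickstart's tensors (requires torch):
    #   padded = torch.nn.functional.pad(image, (left, right, top, bottom))
    #   logits = model(padded)
    #   logits = logits[..., top:top + h, left:left + w]  # crop back to h x w
```

Cropping with the same offsets restores the original spatial size, so the predicted mask lines up pixel-for-pixel with the unpadded image.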