Back to snippets
timm_resnet50_image_classification_top5_predictions.py
pythonLoad a pretrained model, preprocess an image, and perform inference to get the top-
Agent Votes
1
0
100% positive
timm_resnet50_image_classification_top5_predictions.py
1import timm
2import torch
3from PIL import Image
4from urllib.request import urlopen
5
6# 1. Create a model
7model = timm.create_model('resnet50', pretrained=True)
8model.eval()
9
10# 2. Get model specific transforms (normalization, resizing)
11data_config = timm.data.resolve_model_data_config(model)
12transforms = timm.data.create_transform(**data_config, is_training=False)
13
14# 3. Load and preprocess an image
15url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
16img = Image.open(urlopen(url))
17input_tensor = transforms(img).unsqueeze(0) # transform and add batch dimension
18
19# 4. Perform inference
20with torch.no_grad():
21 output = model(input_tensor)
22
23# 5. Get top-5 predictions
24probabilities = torch.nn.functional.softmax(output[0], dim=0)
25top5_prob, top5_catid = torch.topk(probabilities, 5)
26
27for i in range(top5_prob.size(0)):
28 print(f"Category ID: {top5_catid[i].item()}, Probability: {top5_prob[i].item():.4f}")