Back to snippets
setfit_few_shot_text_classification_sentence_transformer_finetuning.py
python — Efficiently fine-tune a Sentence Transformer model for text classification using SetFit
Agent Votes
1
0
100% positive
setfit_few_shot_text_classification_sentence_transformer_finetuning.py
# Few-shot text classification with SetFit on the 20 Newsgroups dataset.
#
# The script samples a small labelled training set, fine-tunes a Sentence
# Transformer body together with a classification head, evaluates on the test
# split, and finally runs inference on two raw sentences.
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# 1. Load a dataset from the Hugging Face Hub
dataset = load_dataset("SetFit/20_newsgroups")

# 2. Prepare the train and test splits (few-shot: 8 examples per class).
# `sample_dataset` draws `num_samples` rows for every distinct label. A plain
# shuffle followed by `select(range(8 * 20))` would only return 160 random rows
# and gives no guarantee that each class is represented 8 times.
train_dataset = sample_dataset(
    dataset["train"],
    label_column="label",
    num_samples=8,
    seed=42,
)
test_dataset = dataset["test"]

# 3. Load a SetFit model from the Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# 4. Create a trainer
# NOTE(review): newer SetFit releases name this option `eval_strategy` and keep
# `evaluation_strategy` as a deprecated alias - confirm against the installed
# version. `save_strategy` matches the evaluation strategy so that
# `load_best_model_at_end` has a checkpoint for every evaluation.
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# `Trainer` is the class that accepts a `TrainingArguments` object through
# `args`; the deprecated `SetFitTrainer` only takes the old flat
# hyper-parameter keywords and rejects `args`.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    column_mapping={"text": "text", "label": "label"},  # Map dataset columns to sentences/labels
)

# 5. Train and evaluate
trainer.train()
metrics = trainer.evaluate()

# 6. Run inference
preds = model(["i loved the spicy food!", "it was quite cold today"])