setfit_few_shot_text_classification_sentence_transformer_finetuning.py


Efficiently fine-tune a Sentence Transformer model for few-shot text classification using SetFit.

huggingface/setfit
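SetFit gets its efficiency with few labels from its first stage, which fine-tunes the Sentence Transformer contrastively on pairs of training sentences (by default with a cosine-similarity target of 1.0 for same-class pairs and 0.0 otherwise). A classification head is then trained on the resulting embeddings. Because n labeled examples yield n(n-1)/2 pairs, 8 examples for each of 20 classes already give thousands of training pairs. The standard-library sketch below illustrates only that pair-building idea: `make_contrastive_pairs` is a hypothetical helper, and SetFit itself controls how many pairs are drawn through its own sampling options, so treat this as an illustration rather than the library's code.

```python
from itertools import combinations
from typing import List, Tuple


def make_contrastive_pairs(
    texts: List[str], labels: List[int]
) -> List[Tuple[str, str, float]]:
    """Build every unordered pair of examples, labeled 1.0 if both share a
    class (positive pair) and 0.0 otherwise (negative pair).

    Simplified illustration only; not SetFit's own pair-sampling code.
    """
    pairs = []
    for (text_a, label_a), (text_b, label_b) in combinations(zip(texts, labels), 2):
        similarity = 1.0 if label_a == label_b else 0.0
        pairs.append((text_a, text_b, similarity))
    return pairs


if __name__ == "__main__":
    # Hypothetical toy data: 20 classes x 8 examples, mirroring the few-shot setup
    num_classes, shots = 20, 8
    texts = [f"example {c}-{i}" for c in range(num_classes) for i in range(shots)]
    labels = [c for c in range(num_classes) for _ in range(shots)]

    pairs = make_contrastive_pairs(texts, labels)
    positives = sum(1 for _, _, sim in pairs if sim == 1.0)
    print(len(pairs), positives, len(pairs) - positives)  # 12720 560 12160
```

With 160 examples, only 560 of the 12,720 pairs are positive, which is one reason the way pairs are sampled matters in practice.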
```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# 1. Load a dataset from the Hugging Face Hub
dataset = load_dataset("SetFit/20_newsgroups")

# 2. Prepare the train and test splits (few-shot: 8 labeled examples per class).
# sample_dataset draws num_samples examples for each label; a plain
# shuffle().select(range(8 * 20)) would only fix the total, not the per-class count.
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8, seed=42)
test_dataset = dataset["test"]

# 3. Load a pretrained Sentence Transformer as the SetFit model body
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# 4. Configure training and create a trainer
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Trainer replaces the deprecated SetFitTrainer in SetFit v1.x
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # Maps dataset column names to the "text"/"label" columns SetFit expects;
    # an identity mapping here, since this dataset already uses those names
    column_mapping={"text": "text", "label": "label"},
)

# 5. Train, then evaluate on the test split (accuracy by default)
trainer.train()
metrics = trainer.evaluate()
print(metrics)

# 6. Run inference; returns a predicted class id for each input text
preds = model.predict(["i loved the spicy food!", "it was quite cold today"])
print(preds)
```
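Step 2 of the snippet is meant to simulate a few-shot setting with 8 labeled examples per class. Shuffling a split and keeping the first 8 * 20 rows fixes only the total size: on an imbalanced (or merely unlucky) split, some classes can end up with several examples and others with none. SetFit provides a `sample_dataset` helper for balanced sampling; the standard-library sketch below only illustrates the idea on a plain list of labels, and `sample_per_class` is a hypothetical helper, not SetFit's implementation.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List


def sample_per_class(labels: List[int], num_samples: int, seed: int = 42) -> List[int]:
    """Return indices of up to `num_samples` randomly chosen examples per label."""
    # Group example indices by their label
    by_label: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    # Draw from each label separately so every class gets the same budget
    rng = random.Random(seed)
    chosen: List[int] = []
    for label in sorted(by_label):
        indices = by_label[label]
        chosen.extend(rng.sample(indices, min(num_samples, len(indices))))
    return chosen


if __name__ == "__main__":
    # Hypothetical imbalanced labels: class c has 10 + 10*c examples (20 classes)
    labels = [c for c in range(20) for _ in range(10 + 10 * c)]

    # Naive approach: shuffle all indices and keep the first 8 * 20
    indices = list(range(len(labels)))
    random.Random(42).shuffle(indices)
    naive = indices[: 8 * 20]
    print("shuffle + select:", sorted(Counter(labels[i] for i in naive).items()))

    # Balanced approach: exactly 8 per class here, since every class has >= 10 examples
    balanced = sample_per_class(labels, num_samples=8)
    print("per-class sample:", sorted(Counter(labels[i] for i in balanced).items()))
```

If a class has fewer than `num_samples` examples, this sketch simply keeps all of them instead of failing or sampling with replacement; that is a choice made for the illustration.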