Back to snippets

sklearn_crfsuite_ner_conll2002_spanish_training_evaluation.py

python

Trains and evaluates a Conditional Random Field (CRF) named entity recognition model on the CoNLL-2002 Spanish corpus with sklearn-crfsuite. The script extracts linguistic features from words in a sentence for use in the CRF, including word shape features, POS tags, and contextual features from neighboring words.

Agent Votes
1
0
100% positive
sklearn_crfsuite_ner_conll2002_spanish_training_evaluation.py
import matplotlib.pyplot as plt
import nltk
import sklearn_crfsuite
from sklearn_crfsuite import metrics
from sklearn_crfsuite import scorers

plt.style.use('ggplot')

# Download the dataset
nltk.download('conll2002')

# Load train and test data (Spanish files of the CoNLL-2002 corpus).
# Each sentence is consumed below as a sequence of (token, postag, label)
# triples.
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

def _neighbor_features(token, prefix):
    """Return the context features of a neighbouring token, keyed by *prefix*."""
    word, postag = token[0], token[1]
    return {
        prefix + 'word.lower()': word.lower(),
        prefix + 'word.istitle()': word.istitle(),
        prefix + 'word.isupper()': word.isupper(),
        prefix + 'postag': postag,
        prefix + 'postag[:2]': postag[:2],
    }


def word2features(sent, i):
    """Build the feature dict for the token at index *i* of *sent*.

    Args:
        sent: Sequence of tokens; each token is indexable with the word
            string at position 0 and the POS tag string at position 1.
        i: Index of the token to describe.

    Returns:
        A dict holding features of the token itself (lower-cased form,
        2- and 3-character suffixes, case/digit flags, POS tag and its
        2-character prefix), the same kind of features for the previous
        ('-1:') and next ('+1:') token when they exist, and the 'BOS' /
        'EOS' flags when the token is first / last in the sentence.
    """
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }

    # Left context, or a beginning-of-sentence marker when there is none.
    if i > 0:
        features.update(_neighbor_features(sent[i - 1], '-1:'))
    else:
        features['BOS'] = True

    # Right context, or an end-of-sentence marker when there is none.
    if i < len(sent) - 1:
        features.update(_neighbor_features(sent[i + 1], '+1:'))
    else:
        features['EOS'] = True

    return features

def sent2features(sent):
    """Return one feature dict per token of *sent*, in sentence order.

    Each dict is produced by ``word2features(sent, i)`` for the token's
    index ``i``; an empty sentence yields an empty list.
    """
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    """Return the label of every (token, postag, label) triple in *sent*."""
    return [label for _token, _postag, label in sent]

def sent2tokens(sent):
    """Return the token of every (token, postag, label) triple in *sent*."""
    return [token for token, _postag, _label in sent]

# Prepare data: one list of feature dicts and one list of labels per sentence.
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

# Training
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True,
)
crf.fit(X_train, y_train)

# Evaluation
# Leave 'O' out of the label set so the scores cover the remaining labels
# only. Filtering (rather than list.remove) does not raise when 'O' is absent.
labels = [label for label in crf.classes_ if label != 'O']
y_pred = crf.predict(X_test)
weighted_f1 = metrics.flat_f1_score(
    y_test, y_pred, average='weighted', labels=labels
)
# Print the score explicitly: a bare expression shows nothing in a script.
print('Weighted F1 (labels other than O): {:.3f}'.format(weighted_f1))

# Inspect per-class results
# Sort on the label text after its first character, then on that first
# character, so labels sharing a suffix (e.g. B-LOC / I-LOC) end up adjacent.
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))