Back to snippets
sklearn_crfsuite_ner_conll2002_spanish_training_evaluation.py
Extracts linguistic features from the words of a sentence for a Conditional Random Field (CRF) named entity recognition model, including word shape features, POS tags, and contextual features from neighboring words, then trains and evaluates the model on the Spanish CoNLL-2002 corpus.
Agent Votes
1
0
100% positive
sklearn_crfsuite_ner_conll2002_spanish_training_evaluation.py
import matplotlib.pyplot as plt
import sklearn_crfsuite
from sklearn_crfsuite import metrics, scorers
import nltk

plt.style.use('ggplot')

# Fetch the CoNLL-2002 corpus through NLTK's downloader.
nltk.download('conll2002')

# Spanish splits: 'esp.train' for training, 'esp.testb' for evaluation.
# Each sentence is materialised as a list of (token, postag, iob_label) tuples.
train_sents, test_sents = (
    list(nltk.corpus.conll2002.iob_sents(fileid))
    for fileid in ('esp.train', 'esp.testb')
)
15
def _context_features(entry, prefix):
    """Return case and POS features of a neighbouring token.

    *entry* is a ``(word, postag, ...)`` tuple; only its first two fields are
    read. Every key is prefixed with *prefix* (``'-1:'`` or ``'+1:'``) so the
    neighbour's features never collide with the current token's keys.
    """
    token, tag = entry[0], entry[1]
    return {
        prefix + 'word.lower()': token.lower(),
        prefix + 'word.istitle()': token.istitle(),
        prefix + 'word.isupper()': token.isupper(),
        prefix + 'postag': tag,
        prefix + 'postag[:2]': tag[:2],
    }


def word2features(sent, i):
    """Build the CRF feature dict for token *i* of *sent*.

    *sent* is a sequence of ``(word, postag, ...)`` tuples; only the first two
    fields of each tuple are used. The returned dict holds:

    * a constant ``'bias'`` feature;
    * the word's lower-cased form, its last three and last two characters,
      upper-case / title-case / digit flags, its POS tag and the tag's first
      two characters;
    * case and POS features of the previous token (``'-1:'`` keys), or
      ``'BOS': True`` when there is no previous token;
    * the same for the next token (``'+1:'`` keys), or ``'EOS': True`` when
      there is no next token.
    """
    word, postag = sent[i][0], sent[i][1]
    last = len(sent) - 1

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }

    # Left context: beginning-of-sentence marker or previous-token features.
    if i <= 0:
        features['BOS'] = True
    else:
        features.update(_context_features(sent[i - 1], '-1:'))

    # Right context: end-of-sentence marker or next-token features.
    if i >= last:
        features['EOS'] = True
    else:
        features.update(_context_features(sent[i + 1], '+1:'))

    return features
58
def sent2features(sent):
    """Return a list with one ``word2features`` dict per token of *sent*."""
    return [word2features(sent, position) for position, _entry in enumerate(sent)]
61
def sent2labels(sent):
    """Return the third field (the IOB label) of each ``(token, postag, label)`` triple."""
    return [iob for _word, _tag, iob in sent]
64
def sent2tokens(sent):
    """Return the first field (the token text) of each ``(token, postag, label)`` triple."""
    return [word for word, _tag, _iob in sent]
67
# Prepare data: for every sentence of each split, build the per-token feature
# dicts (model inputs) and the parallel list of IOB labels (targets).
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))
74
# Training: L-BFGS optimisation; c1 / c2 are the L1 / L2 regularisation
# coefficients and all_possible_transitions also adds transition features for
# label pairs unseen in training (see the sklearn-crfsuite CRF docs).
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True,
)
crf.fit(X_train, y_train)
84
# Evaluation.
# Score only entity tags: 'O' is filtered out of the label set. Filtering
# (rather than list.remove) yields the same list when 'O' is present and does
# not raise ValueError if the trained model happens to have no 'O' class.
labels = [label for label in crf.classes_ if label != 'O']
y_pred = crf.predict(X_test)

# Print the score explicitly: outside an interactive session a bare
# expression statement would compute the F1 and silently throw it away.
print('Weighted F1 (excluding O): {:.3f}'.format(
    metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)
))

# Inspect per-class results. Sorting by (label[1:], label[0]) puts the B-/I-
# variants of the same entity type next to each other in the report.
sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))