Back to snippets

tensorflow_rnet_squad_reading_comprehension_training.py

python

Implements an R-NET model for machine reading comprehension on the SQuAD dataset using TensorFlow.

15d ago · 49 lines · HKUST-KnowComp/R-Net
Agent Votes
1
0
100% positive
tensorflow_rnet_squad_reading_comprehension_training.py
import json
import os

import numpy as np
import tensorflow as tf

from model import Model
from util import get_record_parser

# Configuration for the model
class Config:
    """Hyper-parameters and file locations for R-NET training on SQuAD.

    Calling ``Config()`` with no arguments reproduces the original hard-coded
    values (data under ``data/``, checkpoints under ``log/model``).

    Args:
        target_dir: Directory holding the preprocessed data files. Every
            record/embedding/eval path below is derived from it.
        save_dir: Directory where model checkpoints are written.
    """

    def __init__(self, target_dir="data", save_dir="log/model"):
        self.target_dir = target_dir
        self.save_dir = save_dir

        # Derive the data paths from target_dir. Previously each path
        # hard-coded "data/", so changing target_dir silently had no effect.
        self.train_record_file = f"{target_dir}/train.tfrecords"
        self.dev_record_file = f"{target_dir}/dev.tfrecords"
        self.word_emb_file = f"{target_dir}/word_emb.json"
        self.char_emb_file = f"{target_dir}/char_emb.json"
        self.train_eval_file = f"{target_dir}/train_eval.json"
        self.dev_eval_file = f"{target_dir}/dev_eval.json"
        self.test_eval_file = f"{target_dir}/test_eval.json"

        # Buffer size (in examples) handed to Dataset.shuffle.
        self.capacity = 15000
        self.batch_size = 64
        # Total number of training steps.
        self.num_steps = 60000
        # Intended checkpoint interval, in steps.
        self.checkpoint = 1000
        # Loss-logging interval, in steps.
        self.period = 100

def _load_embedding_matrix(path):
    """Load a JSON-encoded embedding matrix from *path* as a float32 array.

    Presumably the file holds a JSON list of equal-length numeric rows, one
    per vocabulary entry. Confirm this against the preprocessing step that
    wrote it.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return np.array(json.load(fh), dtype=np.float32)


def main():
    """Build the training input pipeline and run the R-NET training loop.

    The loop:
      - streams shuffled, repeated batches from ``config.train_record_file``;
      - feeds them to ``Model`` through a feedable (string-handle) iterator;
      - prints the loss every ``config.period`` steps;
      - writes a checkpoint to ``config.save_dir`` every
        ``config.checkpoint`` steps.
    """
    # This is TF1 graph-mode code (placeholders + Session). Under TF2 eager
    # execution, tf.compat.v1.placeholder raises, so switch to graph mode
    # before any op is created.
    tf.compat.v1.disable_eager_execution()

    config = Config()

    # The config names pretrained embedding files. Load them instead of
    # building the model with word_mat=None / char_mat=None.
    word_mat = _load_embedding_matrix(config.word_emb_file)
    char_mat = _load_embedding_matrix(config.char_emb_file)

    # Load data using a record parser.
    parser = get_record_parser(config)
    # Shuffle before repeat so each example is seen once per epoch before any
    # example repeats. repeat().shuffle() mixes examples across epochs.
    train_dataset = (
        tf.data.TFRecordDataset(config.train_record_file)
        .map(parser)
        .shuffle(config.capacity)
        .repeat()
        .batch(config.batch_size)
    )

    # Feedable iterator: the concrete dataset is chosen at run time by feeding
    # its string handle into `handle`. The compat.v1 accessors work on both
    # TF1 datasets and TF2 DatasetV2, which lacks .output_types/.output_shapes.
    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle,
        tf.compat.v1.data.get_output_types(train_dataset),
        tf.compat.v1.data.get_output_shapes(train_dataset),
    )

    # Initialize the R-NET model.
    # NOTE(review): assumes Model accepts float32 arrays for word_mat/char_mat
    # and exposes `loss`, `train_op` and `dropout`. Confirm against model.py.
    model = Model(config, iterator, word_mat=word_mat, char_mat=char_mat)

    # DatasetV2 has no make_one_shot_iterator method, so use the compat.v1
    # function form.
    train_handle_op = tf.compat.v1.data.make_one_shot_iterator(
        train_dataset
    ).string_handle()

    # Created after the model so it captures the model's variables.
    saver = tf.compat.v1.train.Saver()
    os.makedirs(config.save_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(config.save_dir, "model.ckpt")

    sess_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=sess_config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        train_handle = sess.run(train_handle_op)

        for step in range(1, config.num_steps + 1):
            # NOTE(review): whether 0.2 is a drop rate or a keep probability
            # depends on how Model uses `dropout`. Confirm.
            loss, _ = sess.run(
                [model.loss, model.train_op],
                feed_dict={handle: train_handle, model.dropout: 0.2},
            )
            if step % config.period == 0:
                print(f"Step {step}: Loss = {loss:.4f}")
            # Persist progress; previously save_dir/checkpoint were configured
            # but nothing was ever saved.
            if step % config.checkpoint == 0:
                saver.save(sess, checkpoint_prefix, global_step=step)


if __name__ == "__main__":
    main()