Back to snippets
tensorflow_rnet_squad_reading_comprehension_training.py
python · Implements an R-NET model for machine reading comprehension on the SQuAD dataset using TensorFlow.
Agent Votes
1
0
100% positive
tensorflow_rnet_squad_reading_comprehension_training.py
import tensorflow as tf

from model import Model
from util import get_record_parser

class Config:
    """Configuration for R-NET training: data file locations and loop sizes.

    Every setting is a plain instance attribute assigned in ``__init__``;
    change a value here, or on the instance before use, to alter a run.
    """

    def __init__(self):
        # --- Directories ---------------------------------------------------
        self.target_dir = "data"          # the data files below all live under "data/"
        self.save_dir = "log/model"       # NOTE(review): not read by main() in this file

        # --- Serialized examples (TFRecord) --------------------------------
        self.train_record_file = "data/train.tfrecords"
        self.dev_record_file = "data/dev.tfrecords"

        # --- Word / character embedding files (JSON) ------------------------
        self.word_emb_file = "data/word_emb.json"
        self.char_emb_file = "data/char_emb.json"

        # --- Per-split eval files (JSON); contents not visible here ---------
        self.train_eval_file = "data/train_eval.json"
        self.dev_eval_file = "data/dev_eval.json"
        self.test_eval_file = "data/test_eval.json"

        # --- Input pipeline / training loop --------------------------------
        self.capacity = 15000    # passed to Dataset.shuffle() as the buffer size
        self.batch_size = 64     # examples per training step
        self.num_steps = 60000   # total number of training steps
        self.checkpoint = 1000   # NOTE(review): not read by main() in this file
        self.period = 100        # the loss is printed every `period` steps

def main():
    """Build the training input pipeline and the R-NET graph, then train.

    Uses the TF1-style graph/session API through ``tf.compat.v1``. A
    string-handle placeholder selects which dataset iterator feeds ``Model``,
    so other datasets can later be fed through the same graph. The loss is
    printed every ``config.period`` steps for ``config.num_steps`` steps.
    """
    # Placeholders, one-shot iterators and Session.run require graph mode.
    # TF2 enables eager execution by default, so switch it off before any op
    # is created. On TF 1.x graph mode is already the default.
    tf.compat.v1.disable_eager_execution()

    config = Config()

    # Input pipeline: parse TFRecords, shuffle, repeat indefinitely, batch.
    # shuffle() comes before repeat() so each pass over the data is a full
    # permutation. Shuffling after repeat() would let the buffer mix examples
    # from adjacent epochs.
    parser = get_record_parser(config)
    train_dataset = (
        tf.data.TFRecordDataset(config.train_record_file)
        .map(parser)
        .shuffle(config.capacity)
        .repeat()
        .batch(config.batch_size)
    )

    # Feedable iterator: the concrete iterator is chosen at run time by
    # feeding its string handle into `handle`. TF2 datasets have no
    # .output_types/.output_shapes attributes; the compat.v1 helper functions
    # work on both TF 1.x and TF 2.x datasets.
    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle,
        tf.compat.v1.data.get_output_types(train_dataset),
        tf.compat.v1.data.get_output_shapes(train_dataset),
    )
    # Dataset.make_one_shot_iterator() was removed from TF2 datasets; the
    # function form is available on both versions.
    train_iterator = tf.compat.v1.data.make_one_shot_iterator(train_dataset)

    # NOTE(review): config.word_emb_file / config.char_emb_file are never
    # loaded; embeddings are passed as None. Confirm that Model accepts this.
    model = Model(config, iterator, word_mat=None, char_mat=None)

    sess_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True  # grow GPU memory on demand

    # NOTE(review): config.save_dir / config.checkpoint are not used here, so
    # this loop never saves the model.
    with tf.compat.v1.Session(config=sess_config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        train_handle = sess.run(train_iterator.string_handle())

        for step in range(1, config.num_steps + 1):
            # NOTE(review): 0.2 is fed to model.dropout. Whether Model treats it
            # as a drop rate or a keep probability is not visible here.
            loss, _ = sess.run(
                [model.loss, model.train_op],
                feed_dict={handle: train_handle, model.dropout: 0.2},
            )
            if step % config.period == 0:
                print(f"Step {step}: Loss = {loss:.4f}")

if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()