📚 The CoCalc Library - books, templates and other resources
License: OTHER
import time

import numpy as np
import tensorflow as tf

import config


class ChatBotModel:
    """Bucketed sequence-to-sequence chatbot graph built on TF 1.x ``tf.contrib``.

    Call :meth:`build_graph` once after construction; it creates the
    placeholders, the RNN cell, the per-bucket losses and (unless the model
    is forward-only) the per-bucket training ops as attributes on ``self``.
    """

    def __init__(self, forward_only, batch_size):
        """Store the model settings; no graph nodes are created here.

        forward_only: if set, we do not construct the backward pass in the
            model.
        batch_size: stored on ``self.batch_size``; not otherwise used in this
            class.
        """
        print('Initialize new model')
        self.fw_only = forward_only
        self.batch_size = batch_size

    def _create_placeholders(self):
        """Create one placeholder per time step for inputs, targets and masks.

        The number of encoder steps is ``config.BUCKETS[-1][0]`` and the
        number of decoder steps is ``config.BUCKETS[-1][1] + 1`` (presumably
        the last bucket is the largest one -- confirm against ``config``).
        """
        # Feeds for inputs. It's a list of placeholders
        print('Create placeholders')
        encoder_steps = config.BUCKETS[-1][0]
        decoder_steps = config.BUCKETS[-1][1] + 1
        self.encoder_inputs = [
            tf.placeholder(tf.int32, shape=[None], name='encoder{}'.format(i))
            for i in range(encoder_steps)
        ]
        self.decoder_inputs = [
            tf.placeholder(tf.int32, shape=[None], name='decoder{}'.format(i))
            for i in range(decoder_steps)
        ]
        self.decoder_masks = [
            tf.placeholder(tf.float32, shape=[None], name='mask{}'.format(i))
            for i in range(decoder_steps)
        ]

        # Our targets are decoder inputs shifted by one (to ignore <GO> symbol)
        self.targets = self.decoder_inputs[1:]

    def _inference(self):
        """Create the output projection / sampled loss (if enabled) and the RNN cell.

        Sets ``self.output_projection``, ``self.softmax_loss_function`` and
        ``self.cell``. The first two stay ``None`` when sampled softmax is
        not enabled, so later stages can always read them.
        """
        print('Create inference')
        # Always define these attributes: _create_loss reads them
        # unconditionally and would otherwise raise AttributeError whenever
        # the sampled-softmax condition below is false.
        self.output_projection = None
        self.softmax_loss_function = None

        # If we use sampled softmax, we need an output projection.
        # Sampled softmax only makes sense if we sample less than vocabulary size.
        if 0 < config.NUM_SAMPLES < config.DEC_VOCAB:
            w = tf.get_variable('proj_w', [config.HIDDEN_SIZE, config.DEC_VOCAB])
            b = tf.get_variable('proj_b', [config.DEC_VOCAB])
            self.output_projection = (w, b)

            def sampled_loss(logits, labels):
                # sampled_softmax_loss is given labels as a column and the
                # projection weights transposed relative to how `w` is stored.
                labels = tf.reshape(labels, [-1, 1])
                return tf.nn.sampled_softmax_loss(weights=tf.transpose(w),
                                                  biases=b,
                                                  inputs=logits,
                                                  labels=labels,
                                                  num_sampled=config.NUM_SAMPLES,
                                                  num_classes=config.DEC_VOCAB)
            self.softmax_loss_function = sampled_loss

        # Build a separate GRUCell per layer. Passing the same cell object for
        # every layer makes the layers collide on (or share) one set of
        # weights, depending on the TF 1.x release.
        # NOTE(review): checkpoints trained with the old shared-cell graph
        # will have different variable names -- confirm before restoring.
        layers = [tf.contrib.rnn.GRUCell(config.HIDDEN_SIZE)
                  for _ in range(config.NUM_LAYERS)]
        self.cell = tf.contrib.rnn.MultiRNNCell(layers)

    def _create_loss(self):
        """Build per-bucket outputs and losses with ``model_with_buckets``.

        Sets ``self.outputs`` and ``self.losses``. In forward-only mode the
        decoder is run with ``feed_previous=True`` and, when an output
        projection exists, the outputs of every bucket are projected with it.
        """
        print('Creating loss... \nIt might take a couple of minutes depending on how many buckets you have.')
        start = time.time()

        # Make "deep copies" of the cells return the cell itself, so the
        # deepcopy performed on the cell inside the seq2seq helper is a no-op.
        # This patches the classes globally; doing it once here is equivalent
        # to re-patching on every bucket callback.
        setattr(tf.contrib.rnn.GRUCell, '__deepcopy__', lambda self, _: self)
        setattr(tf.contrib.rnn.MultiRNNCell, '__deepcopy__', lambda self, _: self)

        def _seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
            return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
                encoder_inputs, decoder_inputs, self.cell,
                num_encoder_symbols=config.ENC_VOCAB,
                num_decoder_symbols=config.DEC_VOCAB,
                embedding_size=config.HIDDEN_SIZE,
                output_projection=self.output_projection,
                feed_previous=do_decode)

        # Forward-only (inference) feeds the decoder its own previous output;
        # training feeds it the provided decoder inputs.
        do_decode = bool(self.fw_only)
        self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
            self.encoder_inputs,
            self.decoder_inputs,
            self.targets,
            self.decoder_masks,
            config.BUCKETS,
            lambda x, y: _seq2seq_f(x, y, do_decode),
            softmax_loss_function=self.softmax_loss_function)

        # If we use output projection, we need to project outputs for decoding.
        if do_decode and self.output_projection is not None:
            proj_w, proj_b = self.output_projection
            for bucket in range(len(config.BUCKETS)):
                self.outputs[bucket] = [tf.matmul(output, proj_w) + proj_b
                                        for output in self.outputs[bucket]]
        print('Time:', time.time() - start)

    def _creat_optimizer(self):
        """Create the global step and, when training, one train op per bucket.

        Sets ``self.global_step`` always. Unless the model is forward-only it
        also sets ``self.optimizer``, ``self.gradient_norms`` and
        ``self.train_ops`` (one entry per bucket in ``config.BUCKETS``).

        The misspelled method name is kept so existing callers keep working.
        """
        print('Create optimizer... \nIt might take a couple of minutes depending on how many buckets you have.')
        with tf.variable_scope('training') as scope:
            self.global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

            if not self.fw_only:
                self.optimizer = tf.train.GradientDescentOptimizer(config.LR)
                trainables = tf.trainable_variables()
                self.gradient_norms = []
                self.train_ops = []
                start = time.time()
                for bucket in range(len(config.BUCKETS)):
                    # Clip the gradients of this bucket's loss by their global
                    # norm before applying them.
                    clipped_grads, norm = tf.clip_by_global_norm(
                        tf.gradients(self.losses[bucket], trainables),
                        config.MAX_GRAD_NORM)
                    self.gradient_norms.append(norm)
                    self.train_ops.append(
                        self.optimizer.apply_gradients(zip(clipped_grads, trainables),
                                                       global_step=self.global_step))
                    print('Creating opt for bucket {} took {} seconds'.format(bucket, time.time() - start))
                    # Restart the timer so each bucket is timed on its own.
                    start = time.time()

    def _create_summary(self):
        """Placeholder for summary ops; intentionally does nothing."""
        pass

    def build_graph(self):
        """Build the full graph by running every construction stage in order."""
        self._create_placeholders()
        self._inference()
        self._create_loss()
        self._creat_optimizer()
        self._create_summary()