# 📚 The CoCalc Library - books, templates and other resources
# License: OTHER
""" Solution for simple linear regression example using tf.data
Created by Chip Huyen (author e-mail address obfuscated in this copy)
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Lecture 03

Fits Y = w * X + b by gradient descent, feeding one (x, y) example per
training step through a tf.data initializable iterator. Column 0 of the data
file is used as X and column 1 as Y. Uses the TensorFlow 1.x graph API
(tf.Session, tf.get_variable, Dataset.make_initializable_iterator).
"""
import os
# Read by TensorFlow's C++ runtime to filter its log output; it is set before
# tensorflow is imported below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import utils

DATA_FILE = 'data/birth_life_2010.txt'
LEARNING_RATE = 0.001  # step size for gradient descent
N_EPOCHS = 100  # full passes over the dataset
LOG_DIR = './graphs/linear_reg'  # where the graph is written for TensorBoard

# Step 1: read in the data
# NOTE(review): utils.read_birth_life_data is project-local; it presumably
# returns a 2-D array (X in column 0, Y in column 1) and its row count.
data, n_samples = utils.read_birth_life_data(DATA_FILE)

# Step 2: create Dataset and iterator
# Slicing the two columns yields one scalar (x, y) pair per dataset element.
dataset = tf.data.Dataset.from_tensor_slices((data[:, 0], data[:, 1]))

# An initializable iterator can be rewound by re-running its initializer,
# which is how each epoch starts over from the first example.
iterator = dataset.make_initializable_iterator()
X, Y = iterator.get_next()

# Step 3: create weight and bias, initialized to 0
w = tf.get_variable('weights', initializer=tf.constant(0.0))
b = tf.get_variable('bias', initializer=tf.constant(0.0))

# Step 4: build model to predict Y
Y_predicted = X * w + b

# Step 5: use the square error as the loss function
loss = tf.square(Y - Y_predicted, name='loss')
# Alternative: Huber loss instead of squared error.
# loss = utils.huber_loss(Y, Y_predicted)

# Step 6: use gradient descent with learning rate LEARNING_RATE to minimize loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

start = time.perf_counter()
with tf.Session() as sess:
    # Step 7: initialize the necessary variables, in this case, w and b
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    try:
        # Step 8: train the model for N_EPOCHS epochs
        for i in range(N_EPOCHS):
            sess.run(iterator.initializer)  # rewind the iterator to the first example
            total_loss = 0
            try:
                # Each run pulls one example via get_next(); the iterator
                # signals the end of the epoch by raising OutOfRangeError.
                while True:
                    _, step_loss = sess.run([optimizer, loss])
                    total_loss += step_loss
            except tf.errors.OutOfRangeError:
                pass

            print('Epoch {0}: {1}'.format(i, total_loss / n_samples))
    finally:
        # Close the writer even if training fails, so its event file is flushed.
        writer.close()

    # Step 9: output the values of w and b
    w_out, b_out = sess.run([w, b])
    print('w: %f, b: %f' % (w_out, b_out))
print('Took: %f seconds' % (time.perf_counter() - start))

# plot the results
plt.plot(data[:, 0], data[:, 1], 'bo', label='Real data')
plt.plot(data[:, 0], data[:, 0] * w_out + b_out, 'r', label='Predicted data with squared error')
# Alternative: line fitted with Huber loss, for comparison.
# plt.plot(data[:,0], data[:,0] * (-5.883589) + 85.124306, 'g', label='Predicted data with Huber loss')
plt.legend()
plt.show()