📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" Solution for simple linear regression example using placeholders
Created by Chip Huyen ([email protected])
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Lecture 03

Fits Y_predicted = w * X + b, where X is the birth rate and Y the life
expectancy. Each optimizer step is fed a single (x, y) sample. After training,
the script plots the data points together with the fitted line.

NOTE: uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session,
tf.train.GradientDescentOptimizer, tf.summary.FileWriter).
"""
import os
# Hide TensorFlow's C++ INFO/WARNING log output; must be set before
# tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import utils

DATA_FILE = 'data/birth_life_2010.txt'
LEARNING_RATE = 0.001  # step size for gradient descent
N_EPOCHS = 100         # number of full passes over the dataset

# Step 1: read in data from the .txt file
data, n_samples = utils.read_birth_life_data(DATA_FILE)

# Step 2: create placeholders for X (birth rate) and Y (life expectancy)
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

# Step 3: create weight and bias, initialized to 0
w = tf.get_variable('weights', initializer=tf.constant(0.0))
b = tf.get_variable('bias', initializer=tf.constant(0.0))

# Step 4: build model to predict Y
Y_predicted = w * X + b

# Step 5: use the squared error as the loss function
# you can use either mean squared error or Huber loss
# (the training loop feeds one sample at a time, so this is per-sample error)
loss = tf.square(Y - Y_predicted, name='loss')
# loss = utils.huber_loss(Y, Y_predicted)

# Step 6: using gradient descent with learning rate of LEARNING_RATE to minimize loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE).minimize(loss)


# perf_counter is monotonic and meant for measuring elapsed durations.
start = time.perf_counter()
writer = tf.summary.FileWriter('./graphs/linear_reg', tf.get_default_graph())
try:
    with tf.Session() as sess:
        # Step 7: initialize the necessary variables, in this case, w and b
        sess.run(tf.global_variables_initializer())

        # Step 8: train the model for N_EPOCHS epochs
        for i in range(N_EPOCHS):
            total_loss = 0
            for x, y in data:
                # Session executes one optimizer step on a single sample
                # and fetches that sample's loss
                _, sample_loss = sess.run([optimizer, loss], feed_dict={X: x, Y: y})
                total_loss += sample_loss
            print('Epoch {0}: {1}'.format(i, total_loss / n_samples))

        # Step 9: fetch the trained values of w and b
        w_out, b_out = sess.run([w, b])
finally:
    # Close the writer even if training raises, so the graph event file
    # is flushed and the file handle released.
    writer.close()

print('Took: %f seconds' % (time.perf_counter() - start))

# plot the results
plt.plot(data[:, 0], data[:, 1], 'bo', label='Real data')
plt.plot(data[:, 0], data[:, 0] * w_out + b_out, 'r', label='Predicted data')
plt.legend()
plt.show()