# Lab 2 Linear Regression
import tensorflow as tf
tf.set_random_seed(777) # for reproducibility
# Try to find values for W and b to compute y_data = W * x_data + b
# We know that W should be 1 and b should be 0
# But let's use TensorFlow to figure it out
W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
# Now we can use X and Y in place of x_data and y_data
# Placeholders for tensors that will always be fed using feed_dict
# See http://stackoverflow.com/questions/36693740/
X = tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32, shape=[None])
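# shape=[None] makes each placeholder a 1-D tensor of any length, so batches
# of different sizes can be fed through feed_dict at run time.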
# Our hypothesis XW+b
hypothesis = X * W + b
# cost/loss function
cost = tf.reduce_mean(tf.square(hypothesis - Y))
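# tf.reduce_mean(tf.square(hypothesis - Y)) is the mean squared error:
#   cost(W, b) = (1/m) * sum_i (H(x_i) - y_i)^2,  where H(x) = W * x + b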
# Minimize
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)
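# As a rough sketch (for reference only, not part of the original lab),
# minimize(cost) behaves like manually built gradient-descent update ops:
#   gW, gb = tf.gradients(cost, [W, b])
#   train = tf.group(tf.assign(W, W - 0.01 * gW),
#                    tf.assign(b, b - 0.01 * gb))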
# Launch the graph in a session.
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())
# Fit the line
for step in range(2001):
    cost_val, W_val, b_val, _ = \
        sess.run([cost, W, b, train],
                 feed_dict={X: [1, 2, 3], Y: [1, 2, 3]})
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
# Learns best fit W:[ 1.], b:[ 0]
'''
...
1980 1.32962e-05 [ 1.00423515] [-0.00962736]
2000 1.20761e-05 [ 1.00403607] [-0.00917497]
'''
# Testing our model
print(sess.run(hypothesis, feed_dict={X: [5]}))
print(sess.run(hypothesis, feed_dict={X: [2.5]}))
print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]}))
'''
[ 5.0110054]
[ 2.50091505]
[ 1.49687922 3.50495124]
'''
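# With W close to 1 and b close to 0, the predictions above track the fed X
# values almost exactly (5, 2.5, 1.5 and 3.5).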
# Fit the line with new training data
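# The new data satisfies y = x + 1.1, so starting from the already-trained
# W ~ 1, b ~ 0, training should keep W near 1 and move b toward 1.1.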
for step in range(2001):
    cost_val, W_val, b_val, _ = \
        sess.run([cost, W, b, train],
                 feed_dict={X: [1, 2, 3, 4, 5],
                            Y: [2.1, 3.1, 4.1, 5.1, 6.1]})
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
# Testing our model
print(sess.run(hypothesis, feed_dict={X: [5]}))
print(sess.run(hypothesis, feed_dict={X: [2.5]}))
print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]}))
'''
1960 3.32396e-07 [ 1.00037301] [ 1.09865296]
1980 2.90429e-07 [ 1.00034881] [ 1.09874094]
2000 2.5373e-07 [ 1.00032604] [ 1.09882331]
[ 6.10045338]
[ 3.59963846]
[ 2.59931231 4.59996414]
'''