[TensorFlow] Building a Neural Network for Linear Regression

This post uses TensorFlow to build a minimal neural network (a single linear unit) and work through a simple example: linear regression, fitting y = Xw + b by gradient descent on the mean squared error.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Simulated linear relation: y = x0 + x1 + 1, where x0 and x1 are the two input features
input_num = 2
output_num = 1
batch_size = 5
epochs = 200

# 1000 samples drawn from a standard normal; targets use true weights [1, 1] and bias 1
x_data = np.random.normal(size=(1000, 2))
w = np.array([[1], [1]])
y_data = np.dot(x_data, w).reshape((-1, 1)) + 1

X = tf.placeholder(tf.float32, [None, input_num])
y = tf.placeholder(tf.float32, [None, output_num])

weights = tf.Variable(tf.truncated_normal([input_num, output_num]))
bias = tf.Variable(tf.constant(0.1, shape=[output_num]))

with tf.name_scope('linear'):
    y_ = tf.add(tf.matmul(X, weights), bias)
with tf.name_scope('mse'):
    mse = tf.reduce_mean(tf.square(y_ - y))
with tf.name_scope('optimizer'):
    opt = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(mse)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    # Run the initializer
    sess.run(init)

    # Each "epoch" here is really one mini-batch step: 200 steps x 5 samples
    # sweep through the 1000 training samples exactly once
    mse_list = []
    for epoch in range(epochs):
        x_train = x_data[epoch * batch_size: (epoch + 1) * batch_size, :]
        y_train = y_data[epoch * batch_size: (epoch + 1) * batch_size, :]
        mse_, w_, b_, o_ = sess.run([mse, weights, bias, opt], feed_dict={X: x_train, y: y_train})

        # Record and print the loss every 5 steps
        if epoch % 5 == 0:
            mse_list.append(mse_)
            print('epoch: %d' % epoch, 'mse=', mse_, 'weights=', w_, 'bias=', b_)
    plt.plot(mse_list)
    plt.show()
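
A note on API versions: the script uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session). If only TensorFlow 2.x is installed, one common workaround, sketched here assuming the tf.compat.v1 module, is to change the import and disable eager execution; the rest of the script then runs unchanged:

import tensorflow.compat.v1 as tf  # stands in for `import tensorflow as tf`
tf.disable_eager_execution()       # restore the graph-and-Session execution model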

Loss curve: [figure omitted; it plots mse_list over the recorded training steps]

Training output at each recorded step:

epoch: 0 mse= 0.95470846 weights= [[0.37130916]
 [0.5389818 ]] bias= [0.11342223]
epoch: 5 mse= 1.6448042 weights= [[0.42115292]
 [0.5673728 ]] bias= [0.1854951]
epoch: 10 mse= 0.9690776 weights= [[0.47701502]
 [0.63785   ]] bias= [0.27959865]
epoch: 15 mse= 0.5702162 weights= [[0.55095905]
 [0.6883684 ]] bias= [0.36654383]
epoch: 20 mse= 0.38775384 weights= [[0.56151056]
 [0.71580756]] bias= [0.41816548]
epoch: 25 mse= 0.71004486 weights= [[0.60053504]
 [0.71535575]] bias= [0.48022684]
epoch: 30 mse= 0.21135831 weights= [[0.64272714]
 [0.7356484 ]] bias= [0.52237076]
epoch: 35 mse= 0.6288473 weights= [[0.70176554]
 [0.7579718 ]] bias= [0.5681559]
epoch: 40 mse= 0.14831199 weights= [[0.7259164]
 [0.7772914]] bias= [0.6122901]
epoch: 45 mse= 0.17581545 weights= [[0.75020236]
 [0.7950606 ]] bias= [0.6455283]
epoch: 50 mse= 0.31268176 weights= [[0.777244 ]
 [0.8116195]] bias= [0.67804706]
epoch: 55 mse= 0.26954082 weights= [[0.799988 ]
 [0.8274919]] bias= [0.71080756]
epoch: 60 mse= 0.110923626 weights= [[0.81554675]
 [0.84473526]] bias= [0.73851544]
epoch: 65 mse= 0.3364119 weights= [[0.8351821 ]
 [0.86834836]] bias= [0.76178485]
epoch: 70 mse= 0.11775869 weights= [[0.8496617 ]
 [0.88230723]] bias= [0.78296524]
epoch: 75 mse= 0.15410557 weights= [[0.8716818]
 [0.8911777]] bias= [0.80560607]
epoch: 80 mse= 0.06427675 weights= [[0.8916166 ]
 [0.89697725]] bias= [0.82531697]
epoch: 85 mse= 0.10565728 weights= [[0.90494376]
 [0.9080882 ]] bias= [0.84353626]
epoch: 90 mse= 0.010155492 weights= [[0.9199746 ]
 [0.91889423]] bias= [0.8615363]
epoch: 95 mse= 0.01568078 weights= [[0.9244    ]
 [0.92285496]] bias= [0.8725688]
epoch: 100 mse= 0.037834167 weights= [[0.9321953 ]
 [0.93196994]] bias= [0.88768554]
epoch: 105 mse= 0.006278408 weights= [[0.9405269]
 [0.9397645]] bias= [0.89913684]
epoch: 110 mse= 0.016463825 weights= [[0.9462303]
 [0.9466858]] bias= [0.9079533]
epoch: 115 mse= 0.016840978 weights= [[0.9486531]
 [0.9503261]] bias= [0.91491836]
epoch: 120 mse= 0.02027573 weights= [[0.95348746]
 [0.95571154]] bias= [0.92315996]
epoch: 125 mse= 0.005527997 weights= [[0.9594867]
 [0.9587776]] bias= [0.93095905]
epoch: 130 mse= 0.008891637 weights= [[0.96296716]
 [0.9639074 ]] bias= [0.9394498]
epoch: 135 mse= 0.0022668275 weights= [[0.96632165]
 [0.96931326]] bias= [0.94564474]
epoch: 140 mse= 0.0031881095 weights= [[0.9703764 ]
 [0.97241443]] bias= [0.95126355]
epoch: 145 mse= 0.001595168 weights= [[0.9707678]
 [0.9744987]] bias= [0.95516884]
epoch: 150 mse= 0.0019761706 weights= [[0.97468543]
 [0.9750065 ]] bias= [0.9587112]
epoch: 155 mse= 0.0023416674 weights= [[0.97562784]
 [0.9777103 ]] bias= [0.9612137]
epoch: 160 mse= 0.0012770665 weights= [[0.97862244]
 [0.98051095]] bias= [0.96506375]
epoch: 165 mse= 0.0005977473 weights= [[0.9810191]
 [0.9808582]] bias= [0.9682847]
epoch: 170 mse= 0.004160084 weights= [[0.98460144]
 [0.9831914 ]] bias= [0.9717908]
epoch: 175 mse= 0.0006283476 weights= [[0.9864475 ]
 [0.98375654]] bias= [0.9745427]
epoch: 180 mse= 0.0011961826 weights= [[0.9877456 ]
 [0.98501045]] bias= [0.9769379]
epoch: 185 mse= 0.0008161774 weights= [[0.9886735]
 [0.9869667]] bias= [0.9794328]
epoch: 190 mse= 0.00035172925 weights= [[0.9895241]
 [0.9880932]] bias= [0.98158604]
epoch: 195 mse= 0.00052211457 weights= [[0.9904907]
 [0.9897504]] bias= [0.98362017]

Both weights and bias converge to 1, which is consistent with the simulated function y = 1 * x0 + 1 * x1 + 1.
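
As a quick sanity check (a minimal NumPy sketch, not part of the original script), the same parameters can be recovered in closed form with ordinary least squares; appending a column of ones to x_data lets lstsq estimate the bias jointly with the weights:

import numpy as np

# Reuses x_data and y_data generated by the script above
x_aug = np.hstack([x_data, np.ones((x_data.shape[0], 1))])  # add a bias column of ones
theta, *_ = np.linalg.lstsq(x_aug, y_data, rcond=None)      # least-squares solution
print('weights =', theta[:2].ravel())  # expected: values close to [1, 1]
print('bias    =', theta[2, 0])        # expected: a value close to 1

Because the simulated targets are noiseless, the closed-form solution is exact, so gradient descent is converging toward exactly these values.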
