更加简单的深度学习使用方式

tensorflow_权重更新

Posted on By duimu

流程:

import tensorflow as tf
import numpy as np

# --- Hyper-parameters -------------------------------------------------------
eta = 0.4            # learning rate: scales every weight update
episilon = 0.001     # training stops once the mean squared error is <= this
max_epochs = 100     # upper bound compared against the epoch counter

# Weight update formula used by the training step:
#     dW = eta * tf.matmul(tf.transpose(X), (y - y1))
# where X is the input matrix, y the target labels and y1 the predictions.

# Float encodings of the boolean values used in the truth table below.
T = 1.
F = 0.

# Input matrix: every combination of three boolean inputs, enumerated from
# (T, T, T) down to (F, F, F), each followed by a fourth column fixed at T
# (presumably the bias input -- the code only shows that it is constant).
x_in = [
    [a, b, c, T]
    for a in (T, F)
    for b in (T, F)
    for c in (T, F)
]

# Target labels, one single-element row per row of x_in (same order).
y = [[label] for label in (T, T, F, F, T, F, F, F)]
def threshold(x):
    """Element-wise unit step activation.

    Args:
        x: a tensor (anything exposing ``.dtype`` that TensorFlow can
            convert), holding the pre-activation values.

    Returns:
        A float32 tensor with the same shape as ``x``: 0.0 where the
        element is strictly less than zero, 1.0 everywhere else
        (zero itself maps to 1.0).
    """
    # Compare against zeros of x's own dtype so tf.less sees matching types.
    is_negative = tf.less(x, tf.zeros_like(x))
    # "Not negative" is exactly the set of positions that should output 1.
    return tf.cast(tf.logical_not(is_negative), tf.float32)


# --- Model ------------------------------------------------------------------
# Weight column (4 x 1), initialised from a normal distribution with
# stddev 2; the seed is passed to the random op.
initial_weights = tf.random_normal([4, 1], stddev=2, seed=0)
W = tf.Variable(initial_weights)

# Forward pass: linear combination of the inputs followed by the step
# activation.
h = tf.matmul(x_in, W)
y1 = threshold(h)

# Per-sample error (target minus prediction) and its mean square, which is
# the quantity compared against `episilon` in the loop below.
error = y - y1
mean_error = tf.reduce_mean(tf.square(error))

# Weight update: dW = eta * X^T (y - y1), applied by assigning W + dW to W.
dW = eta * tf.matmul(x_in, error, transpose_a=True)
train = tf.assign(W, W + dW)

init = tf.global_variables_initializer()

# --- Training loop ----------------------------------------------------------
# NOTE: `epoch` starts at 1 and is incremented *before* each update, so the
# first line printed is "epoch:2" and at most max_epochs - 1 updates are run.
err = 1.
epoch = 1
step_fetches = [mean_error, train]
with tf.Session() as sess:
    sess.run(init)
    while epoch < max_epochs and err > episilon:
        epoch += 1
        # The error is fetched in the same run call that applies the update.
        err = sess.run(step_fetches)[0]
        print('epoch:{0} mean error:{1}'.format(epoch, err))
        # Weights are fetched in a separate call, i.e. after the update.
        print(sess.run(W))


----------------------------
epoch:2 mean error:0.625
[[-1.5983152 ]
 [ 3.8088784 ]
 [-0.45785552]
 [-0.9069637 ]]
epoch:3 mean error:0.125
[[-1.5983152 ]
 [ 3.4088783 ]
 [-0.45785552]
 [-1.3069637 ]]
epoch:4 mean error:0.125
[[-1.5983152 ]
 [ 3.0088782 ]
 [-0.45785552]
 [-1.7069637 ]]
epoch:5 mean error:0.375
[[-0.7983152 ]
 [ 3.4088783 ]
 [-0.05785552]
 [-1.3069637 ]]
epoch:6 mean error:0.125
[[-0.7983152 ]
 [ 3.0088782 ]
 [-0.05785552]
 [-1.7069637 ]]
epoch:7 mean error:0.125
[[-0.7983152 ]
 [ 2.6088781 ]
 [-0.05785552]
 [-2.1069636 ]]
epoch:8 mean error:0.375
[[ 1.6847849e-03]
 [ 3.0088782e+00]
 [ 3.4214449e-01]
 [-1.7069637e+00]]
epoch:9 mean error:0.125
[[ 1.6847849e-03]
 [ 2.6088781e+00]
 [ 3.4214449e-01]
 [-2.1069636e+00]]
epoch:10 mean error:0.125
[[ 1.6847849e-03]
 [ 2.2088780e+00]
 [ 3.4214449e-01]
 [-2.5069637e+00]]
epoch:11 mean error:0.125
[[ 0.4016848]
 [ 2.6088781]
 [ 0.3421445]
 [-2.1069636]]
epoch:12 mean error:0.125
[[ 0.4016848]
 [ 2.208878 ]
 [ 0.3421445]
 [-2.5069637]]
epoch:13 mean error:0.0
[[ 0.4016848]
 [ 2.208878 ]
 [ 0.3421445]
 [-2.5069637]]