参数更新。

This commit is contained in:
xinyang
2019-08-07 02:42:54 +08:00
parent 19c07c441d
commit 93d578085b
12 changed files with 31255 additions and 41148 deletions

View File

@@ -52,8 +52,8 @@ def save_para(folder, paras, names, info):
STEPS = 100000
BATCH = 40
LEARNING_RATE_BASE = 0.0002
BATCH = 50
LEARNING_RATE_BASE = 0.0005
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99
@@ -98,7 +98,7 @@ def train(dataset, show_bar=False):
_, loss_value, step = sess.run(
[train_op, loss, global_step],
feed_dict={x: images_samples, y_: labels_samples, keep_rate: 0.3}
feed_dict={x: images_samples, y_: labels_samples, keep_rate: 0.4}
)
if step % 500 == 0:

View File

@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 8
CONV3_KERNAL_SIZE = 3
# 第三层卷积输出通道数
CONV3_OUTPUT_CHANNELS = 16
CONV3_OUTPUT_CHANNELS = 12
# 第一层全连接宽度
FC1_OUTPUT_NODES = 60
@@ -77,7 +77,6 @@ def forward(x, regularizer=None, keep_rate=tf.constant(1.0)):
conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b))
pool2 = avg_pool_2x2(conv2)
print("conv2: ", conv2.shape)
print("pool2: ", pool2.shape)
vars.extend([conv2_w, conv2_b])
vars_name.extend(["conv2_w", "conv2_b"])
nodes.extend([conv2, pool2])