修复CNN训练代码的小BUG

This commit is contained in:
xinyang
2019-05-03 10:58:24 +08:00
parent 174809e6a7
commit d1b9e8b530

View File

@@ -103,7 +103,7 @@ def train(dataset, show_bar=False):
if i % 100 == 0: if i % 100 == 0:
if i % 1000 == 0: if i % 1000 == 0:
test_samples, test_labels = dataset.sample_test_sets(5000) test_samples, test_labels = dataset.sample_test_sets(5000)
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels}) acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc}) bar.set_postfix({"loss": loss_value, "acc": acc})
@@ -143,7 +143,7 @@ def train(dataset, show_bar=False):
# res = res.reshape([forward.OUTPUT_NODES]) # res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res)) # print(np.argmax(res))
test_samples, test_labels = dataset.sample_test_sets(100) test_samples, test_labels = dataset.sample_test_sets(100)
vars_val = sess.run(vars) vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val) save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
nodes_val = sess.run(nodes, feed_dict={x:test_samples}) nodes_val = sess.run(nodes, feed_dict={x:test_samples})