使用旧版候选区寻找方式。

This commit is contained in:
xinyang
2019-05-02 20:19:17 +08:00
parent ee83a373d3
commit 2f20367677
20 changed files with 12502 additions and 12395 deletions

View File

@@ -53,7 +53,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 20000
STEPS = 10000
BATCH = 30
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
@@ -61,8 +61,6 @@ MOVING_AVERAGE_DECAY = 0.99
def train(dataset, show_bar=False):
test_images, test_labels = dataset.all_test_sets()
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
nodes, vars = forward.forward(x, 0.001)
@@ -104,7 +102,8 @@ def train(dataset, show_bar=False):
if i % 100 == 0:
if i % 1000 == 0:
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
test_samples, test_labels = dataset.sample_train_sets(5000)
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc})
@@ -151,8 +150,8 @@ def train(dataset, show_bar=False):
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
nodes_val = sess.run(nodes, feed_dict={x:test_images})
return vars_val, nodes_val
nodes_val = sess.run(nodes, feed_dict={x:test_samples})
return vars_val, nodes_val, test_samples
if __name__ == "__main__":
@@ -160,3 +159,4 @@ if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
print("Finish!")
train(dataset, show_bar=True)
input("Press any key to continue...")