反小陀螺前置代码已完成。

This commit is contained in:
xinyang
2019-07-16 16:31:58 +08:00
parent 298c4c9114
commit 73230d91f1
24 changed files with 18873 additions and 18675 deletions

View File

@@ -54,16 +54,14 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 10000
BATCH = 30
STEPS = 20000
BATCH = 50
LEARNING_RATE_BASE = 0.005
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99
def train(dataset, show_bar=False):
test_images, test_labels = dataset.all_test_sets()
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
keep_rate = tf.placeholder(tf.float32)
@@ -107,7 +105,8 @@ def train(dataset, show_bar=False):
)
if i % 100 == 0:
if i % 1000 == 0:
if i % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(10000)
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
bar.set_postfix({"loss": loss_value, "acc": acc})
@@ -116,6 +115,9 @@ def train(dataset, show_bar=False):
vars_val = sess.run(vars)
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
print("save done!")
# pred = sess.run(y, feed_dict={x: test_images, keep_rate:1.0})
# nodes_val = sess.run(nodes, feed_dict={x:test_images})
# return vars_val, nodes_val
DevList = mvsdk.CameraEnumerateDevice()
@@ -204,9 +206,9 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut")
train(dataset, show_bar=True)
input("press enter to continue...")