diff --git a/tools/TrainCNN/backward.py b/tools/TrainCNN/backward.py index 8503c5a..66f421c 100755 --- a/tools/TrainCNN/backward.py +++ b/tools/TrainCNN/backward.py @@ -5,6 +5,7 @@ from tqdm import tqdm import generate import forward import cv2 +import sys import numpy as np print("Finish!") @@ -61,8 +62,6 @@ MOVING_AVERAGE_DECAY = 0.99 def train(dataset, show_bar=False): - test_images, test_labels = dataset.all_test_sets() - x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS]) y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES]) nodes, vars = forward.forward(x, 0.001) @@ -104,15 +103,10 @@ def train(dataset, show_bar=False): if i % 100 == 0: if i % 1000 == 0: - acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels}) + test_samples, test_labels = dataset.sample_test_sets(5000) + acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels}) bar.set_postfix({"loss": loss_value, "acc": acc}) - - # if show_bar: - # bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc) - # bar.cursor.restore() - # bar.draw(value=i+1) - # video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4") # _ = True # while _: @@ -148,11 +142,12 @@ def train(dataset, show_bar=False): # res = sess.run(y, feed_dict={x: im}) # res = res.reshape([forward.OUTPUT_NODES]) # print(np.argmax(res)) - + + test_samples, test_labels = dataset.sample_test_sets(100) vars_val = sess.run(vars) save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val) - nodes_val = sess.run(nodes, feed_dict={x:test_images}) - return vars_val, nodes_val + nodes_val = sess.run(nodes, feed_dict={x:test_samples}) + return vars_val, nodes_val, test_samples if __name__ == "__main__": @@ -160,3 +155,4 @@ if __name__ == "__main__": dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box") print("Finish!") train(dataset, show_bar=True) + input("Press any key to end...") diff --git a/tools/TrainCNN/forward.py b/tools/TrainCNN/forward.py index 07ed70f..1ec3084 100644 --- a/tools/TrainCNN/forward.py +++ b/tools/TrainCNN/forward.py @@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 10 FC1_OUTPUT_NODES = 16 # 第二层全连接宽度(输出标签类型数) -FC2_OUTPUT_NODES = 4 +FC2_OUTPUT_NODES = 8 # 输出标签类型数 OUTPUT_NODES = FC2_OUTPUT_NODES diff --git a/tools/TrainCNN/generate.py b/tools/TrainCNN/generate.py index 5b2f3f4..20c2a63 100644 --- a/tools/TrainCNN/generate.py +++ b/tools/TrainCNN/generate.py @@ -67,6 +67,15 @@ class DataSet: labels.append(self.train_labels[id]) return np.array(samples), np.array(labels) + def sample_test_sets(self, length): + samples = [] + labels = [] + for i in range(length): + id = random.randint(0, len(self.test_samples)-1) + samples.append(self.test_samples[id]) + labels.append(self.test_labels[id]) + return np.array(samples), np.array(labels) + def all_train_sets(self): return self.train_samples[:], self.train_labels[:]