From 218a4b56c6f2ca4587fa2dbd58ee2339f7da1241 Mon Sep 17 00:00:00 2001 From: xinyang Date: Fri, 3 May 2019 11:13:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8C=E6=AD=A5master=E4=B8=8A=E7=9A=84bug?= =?UTF-8?q?=20fix=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 2 +- main.cpp | 35 ++++++++++++++------------- tools/TrainCNN/backward.py | 16 +++++-------- tools/TrainCNN/forward.py | 6 ++--- tools/TrainCNN/generate.py | 48 +++++++++++++++++--------------------- 5 files changed, 50 insertions(+), 57 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dcd4f8..3395bc8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.5) PROJECT(AutoAim) SET(CMAKE_CXX_STANDARD 11) SET(CMAKE_BUILD_TYPE RELEASE) -SET(CMAKE_CXX_FLAGS "-DPROJECT_DIR=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"") +SET(CMAKE_CXX_FLAGS "-DPATH=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"") FIND_PROGRAM(CCACHE_FOUND ccache) IF(CCACHE_FOUND) diff --git a/main.cpp b/main.cpp index 67bf00e..378370e 100644 --- a/main.cpp +++ b/main.cpp @@ -15,12 +15,12 @@ #include #include -#define DO_NOT_CNT_TIME +//#define DO_NOT_CNT_TIME #include -#define PATH PROJECT_DIR +#define PROJECT_DIR PATH #define ENERGY_STATE 1 -#define ARMOR_STATE 0 +#define ARMOR_STATE 0 using namespace cv; using namespace std; @@ -38,6 +38,7 @@ int main(int argc, char *argv[]){ Uart uart; thread receive(uartReceive, &uart); bool flag = true; + while (flag){ int ally_color = ALLY_RED; int energy_part_rotation = CLOCKWISE; @@ -54,8 +55,8 @@ int main(int argc, char *argv[]){ video_armor = new CameraWrapper(0); // video_energy = new CameraWrapper(1); }else { - video_armor = new VideoWrapper("/home/xinyang/Desktop/Video0.mp4"); - video_energy = new VideoWrapper("/home/xinyang/Desktop/Video0.mp4"); + video_armor = new VideoWrapper("/home/xinyang/Desktop/Video.mp4"); + video_energy = new VideoWrapper("/home/xinyang/Desktop/Video.mp4"); } if (video_armor->init()) { cout << "Video source initialization successfully." << endl; @@ -63,7 +64,7 @@ int main(int argc, char *argv[]){ Mat energy_src, armor_src; - ArmorFinder armorFinder(ENEMY_BLUE, uart, PATH"/tools/para/"); + ArmorFinder armorFinder(ENEMY_BLUE, uart, PROJECT_DIR"/tools/para/"); Energy energy(uart); energy.setAllyColor(ally_color); @@ -73,10 +74,11 @@ int main(int argc, char *argv[]){ while (ok){ CNT_TIME(WORD_LIGHT_CYAN, "Total", { - CNT_TIME(WORD_LIGHT_PURPLE, "Read", { - ok = video_armor->read(armor_src); -// ok &&= video_energy->read(energy_src); - }); + ok = video_armor->read(energy_src) && video_armor->read(armor_src); + if (show_origin) { + imshow("enery src", energy_src); + imshow("armor src", armor_src); + } if (state == ENERGY_STATE) { if (from_camera == 0) { energy.extract(energy_src); @@ -100,6 +102,7 @@ int main(int argc, char *argv[]){ return 0; } +#define RECEIVE_LOG_LEVEL LOG_NOTHING void uartReceive(Uart* uart){ char buffer[100]; @@ -109,28 +112,28 @@ void uartReceive(Uart* uart){ while((data=uart->receive()) != '\n'){ buffer[cnt++] = data; if(cnt >= 100){ -// LOGE("data receive over flow!"); - cnt = 0; + LOG(RECEIVE_LOG_LEVEL, "data receive over flow!"); + cnt = 0; } } if(cnt == 10){ if(buffer[8] == 'e'){ state = ENERGY_STATE; -// LOGM("Energy state"); + LOG(RECEIVE_LOG_LEVEL, "Energy state"); }else if(buffer[8] == 'a'){ state = ARMOR_STATE; -// LOGM("Armor state"); + LOG(RECEIVE_LOG_LEVEL, "Armor state"); } memcpy(&curr_yaw, buffer, 4); memcpy(&curr_pitch, buffer+4, 4); -// LOGM("Get yaw:%f pitch:%f", curr_yaw, curr_pitch); + LOG(RECEIVE_LOG_LEVEL, "Get yaw:%f pitch:%f", curr_yaw, curr_pitch); if(buffer[9] == 1){ if(mark == 0){ mark = 1; mark_yaw = curr_yaw; mark_pitch = curr_pitch; } -// LOGM("Marked"); + LOG(RECEIVE_LOG_LEVEL, "Marked"); } } cnt = 0; diff --git a/tools/TrainCNN/backward.py b/tools/TrainCNN/backward.py index 5ad6ec0..9d21e0a 100755 --- a/tools/TrainCNN/backward.py +++ b/tools/TrainCNN/backward.py @@ -5,6 +5,7 @@ from tqdm import tqdm import generate import forward import cv2 +import sys import numpy as np print("Finish!") @@ -53,7 +54,7 @@ def save_para(folder, paras): save_bias(fp, paras[7]) -STEPS = 10000 +STEPS = 20000 BATCH = 30 LEARNING_RATE_BASE = 0.01 LEARNING_RATE_DECAY = 0.99 @@ -102,16 +103,10 @@ def train(dataset, show_bar=False): if i % 100 == 0: if i % 1000 == 0: - test_samples, test_labels = dataset.sample_train_sets(5000) + test_samples, test_labels = dataset.sample_test_sets(5000) acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels}) bar.set_postfix({"loss": loss_value, "acc": acc}) - - # if show_bar: - # bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc) - # bar.cursor.restore() - # bar.draw(value=i+1) - # video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4") # _ = True # while _: @@ -147,7 +142,8 @@ def train(dataset, show_bar=False): # res = sess.run(y, feed_dict={x: im}) # res = res.reshape([forward.OUTPUT_NODES]) # print(np.argmax(res)) - + + test_samples, test_labels = dataset.sample_test_sets(100) vars_val = sess.run(vars) save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val) nodes_val = sess.run(nodes, feed_dict={x:test_samples}) @@ -159,4 +155,4 @@ if __name__ == "__main__": dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box") print("Finish!") train(dataset, show_bar=True) - input("Press any key to continue...") + input("Press any key to end...") diff --git a/tools/TrainCNN/forward.py b/tools/TrainCNN/forward.py index bc91e6e..1ec3084 100644 --- a/tools/TrainCNN/forward.py +++ b/tools/TrainCNN/forward.py @@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 10 FC1_OUTPUT_NODES = 16 # 第二层全连接宽度(输出标签类型数) -FC2_OUTPUT_NODES = 4 +FC2_OUTPUT_NODES = 8 # 输出标签类型数 OUTPUT_NODES = FC2_OUTPUT_NODES @@ -64,8 +64,8 @@ def forward(x, regularizer=None): [CONV2_KERNAL_SIZE, CONV2_KERNAL_SIZE, CONV1_OUTPUT_CHANNELS, CONV2_OUTPUT_CHANNELS] ) conv2_b = get_bias([CONV2_OUTPUT_CHANNELS]) - conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b)) - pool2 = avg_pool_2x2(conv2) + conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b)) + pool2 = avg_pool_2x2(conv2) vars.extend([conv2_w, conv2_b]) nodes.extend([conv2, pool2]) diff --git a/tools/TrainCNN/generate.py b/tools/TrainCNN/generate.py index 24245f2..20c2a63 100644 --- a/tools/TrainCNN/generate.py +++ b/tools/TrainCNN/generate.py @@ -24,40 +24,34 @@ class DataSet: self.test_labels = [] self.generate_data_sets(folder) + def file2nparray(self, name): + try: + image = cv2.imread(name) + image = cv2.resize(image, (SRC_COLS, SRC_ROWS)) + image = image.astype(np.float32) + return image / 255.0 + except: + print(name) + sys.exit(-1) + + def id2label(self, id): + a = np.zeros([OUTPUT_NODES]) + a[id] = 1 + return a[:] + def generate_data_sets(self, folder): - def id2label(id): - a = np.zeros([OUTPUT_NODES]) - a[id] = 1 - return a[:] - - def file2nparray(name): - try: - image = cv2.imread(name) - if image.shape[0] < 15: - return None - elif image.shape[1] < 10: - return None - image = cv2.resize(image, (SRC_COLS, SRC_ROWS)) - image = image.astype(np.float32) - return image / 255.0, id2label(int(name.split("/")[-2])) - except TypeError: - print(name) - sys.exit(-1) - sets = [] for i in range(OUTPUT_NODES): dir = "%s/%d" % (folder, i) files = os.listdir(dir) for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True): if file[-3:] == "jpg": - x = file2nparray("%s/%s" % (dir, file)) - if x is not None: - if random.random() > 0.2: - self.train_samples.append(x[0]) - self.train_labels.append(x[1]) - else: - self.test_samples.append(x[0]) - self.test_labels.append(x[1]) + if random.random() > 0.2: + self.train_samples.append(self.file2nparray("%s/%s" % (dir, file))) + self.train_labels.append(self.id2label(i)) + else: + self.test_samples.append(self.file2nparray("%s/%s" % (dir, file))) + self.test_labels.append(self.id2label(i)) self.train_samples = np.array(self.train_samples) self.train_labels = np.array(self.train_labels) self.test_samples = np.array(self.test_samples)