同步master上的bug fix。

This commit is contained in:
xinyang
2019-05-03 11:13:17 +08:00
parent 64d2554069
commit 218a4b56c6
5 changed files with 50 additions and 57 deletions

View File

@@ -3,7 +3,7 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
PROJECT(AutoAim) PROJECT(AutoAim)
SET(CMAKE_CXX_STANDARD 11) SET(CMAKE_CXX_STANDARD 11)
SET(CMAKE_BUILD_TYPE RELEASE) SET(CMAKE_BUILD_TYPE RELEASE)
SET(CMAKE_CXX_FLAGS "-DPROJECT_DIR=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"") SET(CMAKE_CXX_FLAGS "-DPATH=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"")
FIND_PROGRAM(CCACHE_FOUND ccache) FIND_PROGRAM(CCACHE_FOUND ccache)
IF(CCACHE_FOUND) IF(CCACHE_FOUND)

View File

@@ -15,10 +15,10 @@
#include <options/options.h> #include <options/options.h>
#include <thread> #include <thread>
#define DO_NOT_CNT_TIME //#define DO_NOT_CNT_TIME
#include <log.h> #include <log.h>
#define PATH PROJECT_DIR #define PROJECT_DIR PATH
#define ENERGY_STATE 1 #define ENERGY_STATE 1
#define ARMOR_STATE 0 #define ARMOR_STATE 0
@@ -38,6 +38,7 @@ int main(int argc, char *argv[]){
Uart uart; Uart uart;
thread receive(uartReceive, &uart); thread receive(uartReceive, &uart);
bool flag = true; bool flag = true;
while (flag){ while (flag){
int ally_color = ALLY_RED; int ally_color = ALLY_RED;
int energy_part_rotation = CLOCKWISE; int energy_part_rotation = CLOCKWISE;
@@ -54,8 +55,8 @@ int main(int argc, char *argv[]){
video_armor = new CameraWrapper(0); video_armor = new CameraWrapper(0);
// video_energy = new CameraWrapper(1); // video_energy = new CameraWrapper(1);
}else { }else {
video_armor = new VideoWrapper("/home/xinyang/Desktop/Video0.mp4"); video_armor = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
video_energy = new VideoWrapper("/home/xinyang/Desktop/Video0.mp4"); video_energy = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
} }
if (video_armor->init()) { if (video_armor->init()) {
cout << "Video source initialization successfully." << endl; cout << "Video source initialization successfully." << endl;
@@ -63,7 +64,7 @@ int main(int argc, char *argv[]){
Mat energy_src, armor_src; Mat energy_src, armor_src;
ArmorFinder armorFinder(ENEMY_BLUE, uart, PATH"/tools/para/"); ArmorFinder armorFinder(ENEMY_BLUE, uart, PROJECT_DIR"/tools/para/");
Energy energy(uart); Energy energy(uart);
energy.setAllyColor(ally_color); energy.setAllyColor(ally_color);
@@ -73,10 +74,11 @@ int main(int argc, char *argv[]){
while (ok){ while (ok){
CNT_TIME(WORD_LIGHT_CYAN, "Total", { CNT_TIME(WORD_LIGHT_CYAN, "Total", {
CNT_TIME(WORD_LIGHT_PURPLE, "Read", { ok = video_armor->read(energy_src) && video_armor->read(armor_src);
ok = video_armor->read(armor_src); if (show_origin) {
// ok &&= video_energy->read(energy_src); imshow("enery src", energy_src);
}); imshow("armor src", armor_src);
}
if (state == ENERGY_STATE) { if (state == ENERGY_STATE) {
if (from_camera == 0) { if (from_camera == 0) {
energy.extract(energy_src); energy.extract(energy_src);
@@ -100,6 +102,7 @@ int main(int argc, char *argv[]){
return 0; return 0;
} }
#define RECEIVE_LOG_LEVEL LOG_NOTHING
void uartReceive(Uart* uart){ void uartReceive(Uart* uart){
char buffer[100]; char buffer[100];
@@ -109,28 +112,28 @@ void uartReceive(Uart* uart){
while((data=uart->receive()) != '\n'){ while((data=uart->receive()) != '\n'){
buffer[cnt++] = data; buffer[cnt++] = data;
if(cnt >= 100){ if(cnt >= 100){
// LOGE("data receive over flow!"); LOG(RECEIVE_LOG_LEVEL, "data receive over flow!");
cnt = 0; cnt = 0;
} }
} }
if(cnt == 10){ if(cnt == 10){
if(buffer[8] == 'e'){ if(buffer[8] == 'e'){
state = ENERGY_STATE; state = ENERGY_STATE;
// LOGM("Energy state"); LOG(RECEIVE_LOG_LEVEL, "Energy state");
}else if(buffer[8] == 'a'){ }else if(buffer[8] == 'a'){
state = ARMOR_STATE; state = ARMOR_STATE;
// LOGM("Armor state"); LOG(RECEIVE_LOG_LEVEL, "Armor state");
} }
memcpy(&curr_yaw, buffer, 4); memcpy(&curr_yaw, buffer, 4);
memcpy(&curr_pitch, buffer+4, 4); memcpy(&curr_pitch, buffer+4, 4);
// LOGM("Get yaw:%f pitch:%f", curr_yaw, curr_pitch); LOG(RECEIVE_LOG_LEVEL, "Get yaw:%f pitch:%f", curr_yaw, curr_pitch);
if(buffer[9] == 1){ if(buffer[9] == 1){
if(mark == 0){ if(mark == 0){
mark = 1; mark = 1;
mark_yaw = curr_yaw; mark_yaw = curr_yaw;
mark_pitch = curr_pitch; mark_pitch = curr_pitch;
} }
// LOGM("Marked"); LOG(RECEIVE_LOG_LEVEL, "Marked");
} }
} }
cnt = 0; cnt = 0;

View File

@@ -5,6 +5,7 @@ from tqdm import tqdm
import generate import generate
import forward import forward
import cv2 import cv2
import sys
import numpy as np import numpy as np
print("Finish!") print("Finish!")
@@ -53,7 +54,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7]) save_bias(fp, paras[7])
STEPS = 10000 STEPS = 20000
BATCH = 30 BATCH = 30
LEARNING_RATE_BASE = 0.01 LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99 LEARNING_RATE_DECAY = 0.99
@@ -102,16 +103,10 @@ def train(dataset, show_bar=False):
if i % 100 == 0: if i % 100 == 0:
if i % 1000 == 0: if i % 1000 == 0:
test_samples, test_labels = dataset.sample_train_sets(5000) test_samples, test_labels = dataset.sample_test_sets(5000)
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels}) acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc}) bar.set_postfix({"loss": loss_value, "acc": acc})
# if show_bar:
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
# bar.cursor.restore()
# bar.draw(value=i+1)
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4") # video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True # _ = True
# while _: # while _:
@@ -148,6 +143,7 @@ def train(dataset, show_bar=False):
# res = res.reshape([forward.OUTPUT_NODES]) # res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res)) # print(np.argmax(res))
test_samples, test_labels = dataset.sample_test_sets(100)
vars_val = sess.run(vars) vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val) save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
nodes_val = sess.run(nodes, feed_dict={x:test_samples}) nodes_val = sess.run(nodes, feed_dict={x:test_samples})
@@ -159,4 +155,4 @@ if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box") dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
print("Finish!") print("Finish!")
train(dataset, show_bar=True) train(dataset, show_bar=True)
input("Press any key to continue...") input("Press any key to end...")

View File

@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 10
FC1_OUTPUT_NODES = 16 FC1_OUTPUT_NODES = 16
# 第二层全连接宽度(输出标签类型数) # 第二层全连接宽度(输出标签类型数)
FC2_OUTPUT_NODES = 4 FC2_OUTPUT_NODES = 8
# 输出标签类型数 # 输出标签类型数
OUTPUT_NODES = FC2_OUTPUT_NODES OUTPUT_NODES = FC2_OUTPUT_NODES

View File

@@ -24,40 +24,34 @@ class DataSet:
self.test_labels = [] self.test_labels = []
self.generate_data_sets(folder) self.generate_data_sets(folder)
def generate_data_sets(self, folder): def file2nparray(self, name):
def id2label(id): try:
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
except:
print(name)
sys.exit(-1)
def id2label(self, id):
a = np.zeros([OUTPUT_NODES]) a = np.zeros([OUTPUT_NODES])
a[id] = 1 a[id] = 1
return a[:] return a[:]
def file2nparray(name): def generate_data_sets(self, folder):
try:
image = cv2.imread(name)
if image.shape[0] < 15:
return None
elif image.shape[1] < 10:
return None
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0, id2label(int(name.split("/")[-2]))
except TypeError:
print(name)
sys.exit(-1)
sets = [] sets = []
for i in range(OUTPUT_NODES): for i in range(OUTPUT_NODES):
dir = "%s/%d" % (folder, i) dir = "%s/%d" % (folder, i)
files = os.listdir(dir) files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True): for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg": if file[-3:] == "jpg":
x = file2nparray("%s/%s" % (dir, file))
if x is not None:
if random.random() > 0.2: if random.random() > 0.2:
self.train_samples.append(x[0]) self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(x[1]) self.train_labels.append(self.id2label(i))
else: else:
self.test_samples.append(x[0]) self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(x[1]) self.test_labels.append(self.id2label(i))
self.train_samples = np.array(self.train_samples) self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels) self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples) self.test_samples = np.array(self.test_samples)