整理代码
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
//
|
//
|
||||||
// Created by xinyang on 19-4-19.
|
// Created by xinyang on 19-4-19.
|
||||||
//
|
//
|
||||||
|
// 为了一时方便,使用循环和Eigen自行编写的CNN前向传播类。
|
||||||
|
// 没有显著的性能损失。
|
||||||
|
// 但类定义了网络结构,同时实现的操作较少,可扩展性较差
|
||||||
|
|
||||||
#ifndef _CLASSIFIER_H_
|
#ifndef _CLASSIFIER_H_
|
||||||
#define _CLASSIFIER_H_
|
#define _CLASSIFIER_H_
|
||||||
@@ -16,18 +19,19 @@ using namespace Eigen;
|
|||||||
|
|
||||||
class Classifier {
|
class Classifier {
|
||||||
private:
|
private:
|
||||||
bool state;
|
bool state; // 标志分类器是否正确初始化
|
||||||
|
|
||||||
|
// 所有网络参数
|
||||||
vector<vector<MatrixXd>> conv1_w, conv2_w, conv3_w;
|
vector<vector<MatrixXd>> conv1_w, conv2_w, conv3_w;
|
||||||
vector<double> conv1_b, conv2_b, conv3_b;
|
vector<double> conv1_b, conv2_b, conv3_b;
|
||||||
MatrixXd fc1_w, fc2_w;
|
MatrixXd fc1_w, fc2_w;
|
||||||
VectorXd fc1_b, fc2_b;
|
VectorXd fc1_b, fc2_b;
|
||||||
|
// 读取网络参数的函数
|
||||||
vector<vector<MatrixXd>> load_conv_w(const string &file);
|
vector<vector<MatrixXd>> load_conv_w(const string &file);
|
||||||
vector<double> load_conv_b(const string &file);
|
vector<double> load_conv_b(const string &file);
|
||||||
MatrixXd load_fc_w(const string &file);
|
MatrixXd load_fc_w(const string &file);
|
||||||
VectorXd load_fc_b(const string &file);
|
VectorXd load_fc_b(const string &file);
|
||||||
|
// 目前支持的所有操作
|
||||||
MatrixXd softmax(const MatrixXd &input);
|
MatrixXd softmax(const MatrixXd &input);
|
||||||
MatrixXd relu(const MatrixXd &input);
|
MatrixXd relu(const MatrixXd &input);
|
||||||
MatrixXd leaky_relu(const MatrixXd &input, float alpha);
|
MatrixXd leaky_relu(const MatrixXd &input, float alpha);
|
||||||
@@ -51,4 +55,4 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#endif //RUNCNN_CLASSIFIER_H
|
#endif /* _CLASSIFIER_H */
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ static systime getFrontTime(const vector<systime> time_seq, const vector<float>
|
|||||||
|
|
||||||
void ArmorFinder::antiTop() {
|
void ArmorFinder::antiTop() {
|
||||||
if (target_box.rect == cv::Rect2d()) return;
|
if (target_box.rect == cv::Rect2d()) return;
|
||||||
|
// 判断是否发生装甲目标切换。
|
||||||
|
// 记录切换前一段时间目标装甲的角度和时间
|
||||||
|
// 通过线性拟合计算出角度为0时对应的时间点
|
||||||
|
// 通过两次装甲角度为零的时间差计算陀螺旋转周期
|
||||||
|
// 根据旋转周期计算下一次装甲出现在角度为零的时间点
|
||||||
if (getPointLength(last_box.getCenter() - target_box.getCenter()) > last_box.rect.height * 1.5) {
|
if (getPointLength(last_box.getCenter() - target_box.getCenter()) > last_box.rect.height * 1.5) {
|
||||||
auto front_time = getFrontTime(time_seq, angle_seq);
|
auto front_time = getFrontTime(time_seq, angle_seq);
|
||||||
auto once_periodms = getTimeIntervalms(front_time, last_front_time);
|
auto once_periodms = getTimeIntervalms(front_time, last_front_time);
|
||||||
@@ -39,7 +44,6 @@ void ArmorFinder::antiTop() {
|
|||||||
// sendBoxPosition(0);
|
// sendBoxPosition(0);
|
||||||
// return;
|
// return;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
LOGM(STR_CTR(WORD_GREEN, "Top period: %.1lf"), once_periodms);
|
LOGM(STR_CTR(WORD_GREEN, "Top period: %.1lf"), once_periodms);
|
||||||
top_periodms.push(once_periodms);
|
top_periodms.push(once_periodms);
|
||||||
auto periodms = mean(top_periodms);
|
auto periodms = mean(top_periodms);
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ ArmorFinder::ArmorFinder(uint8_t &color, Serial &u, const string ¶s_folder,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ArmorFinder::run(cv::Mat &src) {
|
void ArmorFinder::run(cv::Mat &src) {
|
||||||
getsystime(frame_time);
|
getsystime(frame_time); // 获取当前帧时间(不是足够精确)
|
||||||
// stateSearchingTarget(src); // for debug
|
// stateSearchingTarget(src); // for debug
|
||||||
// goto end;
|
// goto end;
|
||||||
switch (state) {
|
switch (state) {
|
||||||
@@ -87,10 +87,10 @@ void ArmorFinder::run(cv::Mat &src) {
|
|||||||
break;
|
break;
|
||||||
case STANDBY_STATE:
|
case STANDBY_STATE:
|
||||||
default:
|
default:
|
||||||
stateStandBy();
|
stateStandBy(); // currently meaningless
|
||||||
}
|
}
|
||||||
end:
|
end:
|
||||||
if(is_anti_top) {
|
if(is_anti_top) { // 判断当前是否为反陀螺模式
|
||||||
antiTop();
|
antiTop();
|
||||||
}else if(target_box.rect != cv::Rect2d()) {
|
}else if(target_box.rect != cv::Rect2d()) {
|
||||||
anti_top_cnt = 0;
|
anti_top_cnt = 0;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
//
|
||||||
// Created by xinyang on 19-4-19.
|
// Created by xinyang on 19-4-19.
|
||||||
//
|
//
|
||||||
|
// 对本文件的大致描述请看classifier.h
|
||||||
|
|
||||||
//#define LOG_LEVEL LOG_NONE
|
//#define LOG_LEVEL LOG_NONE
|
||||||
#include <armor_finder/classifier/classifier.h>
|
#include <armor_finder/classifier/classifier.h>
|
||||||
@@ -321,7 +322,6 @@ int Classifier::operator()(const cv::Mat &image) {
|
|||||||
vector<MatrixXd> sub = {b, g, r};
|
vector<MatrixXd> sub = {b, g, r};
|
||||||
vector<vector<MatrixXd>> in = {sub};
|
vector<vector<MatrixXd>> in = {sub};
|
||||||
MatrixXd result = calculate(in);
|
MatrixXd result = calculate(in);
|
||||||
// cout << result << "==============" <<endl;
|
|
||||||
MatrixXd::Index minRow, minCol;
|
MatrixXd::Index minRow, minCol;
|
||||||
result.maxCoeff(&minRow, &minCol);
|
result.maxCoeff(&minRow, &minCol);
|
||||||
if(result(minRow, minCol) > 0.50){
|
if(result(minRow, minCol) > 0.50){
|
||||||
|
|||||||
@@ -8,13 +8,13 @@
|
|||||||
#include <log.h>
|
#include <log.h>
|
||||||
|
|
||||||
bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
||||||
if (findArmorBox(src, target_box)) {
|
if (findArmorBox(src, target_box)) { // 在原图中寻找目标,并返回是否找到
|
||||||
if (last_box.rect != cv::Rect2d() &&
|
if (last_box.rect != cv::Rect2d() &&
|
||||||
(getPointLength(last_box.getCenter() - target_box.getCenter()) > last_box.rect.height * 2.0) &&
|
(getPointLength(last_box.getCenter() - target_box.getCenter()) > last_box.rect.height * 2.0) &&
|
||||||
anti_switch_cnt++ < 3) {
|
anti_switch_cnt++ < 3) { // 判断当前目标和上次有效目标是否为同一个目标
|
||||||
target_box = ArmorBox();
|
target_box = ArmorBox(); // 并给3帧的时间,试图找到相同目标
|
||||||
LOGM("anti-switch!");
|
LOGM("anti-switch!"); // 即刚发生目标切换内的3帧内不发送目标位置
|
||||||
return false;
|
return false; // 可以一定程度避免频繁多目标切换
|
||||||
} else {
|
} else {
|
||||||
anti_switch_cnt = 0;
|
anti_switch_cnt = 0;
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
||||||
auto pos = target_box.rect;
|
auto pos = target_box.rect;
|
||||||
if(!tracker->update(src, pos)){
|
if(!tracker->update(src, pos)){ // 使用KCFTracker进行追踪
|
||||||
target_box = ArmorBox();
|
target_box = ArmorBox();
|
||||||
LOGW("Track fail!");
|
LOGW("Track fail!");
|
||||||
return false;
|
return false;
|
||||||
@@ -20,23 +20,20 @@ bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取相较于追踪区域两倍长款的区域,用于重新搜索,获取灯条信息
|
||||||
cv::Rect2d bigger_rect;
|
cv::Rect2d bigger_rect;
|
||||||
|
|
||||||
bigger_rect.x = pos.x - pos.width / 2.0;
|
bigger_rect.x = pos.x - pos.width / 2.0;
|
||||||
bigger_rect.y = pos.y - pos.height / 2.0;
|
bigger_rect.y = pos.y - pos.height / 2.0;
|
||||||
bigger_rect.height = pos.height * 2;
|
bigger_rect.height = pos.height * 2;
|
||||||
bigger_rect.width = pos.width * 2;
|
bigger_rect.width = pos.width * 2;
|
||||||
bigger_rect &= cv::Rect2d(0, 0, 640, 480);
|
bigger_rect &= cv::Rect2d(0, 0, 640, 480);
|
||||||
|
|
||||||
// if(show_armor_box)
|
|
||||||
// showTrackSearchingPos("track", src, bigger_rect);
|
|
||||||
|
|
||||||
cv::Mat roi = src(bigger_rect).clone();
|
cv::Mat roi = src(bigger_rect).clone();
|
||||||
|
|
||||||
ArmorBox box;
|
ArmorBox box;
|
||||||
if(findArmorBox(roi, box)) {
|
// 在区域内重新搜索。
|
||||||
|
if(findArmorBox(roi, box)) { // 如果成功获取目标,则利用搜索区域重新更新追踪器
|
||||||
target_box = box;
|
target_box = box;
|
||||||
target_box.rect.x += bigger_rect.x;
|
target_box.rect.x += bigger_rect.x; // 添加roi偏移量
|
||||||
target_box.rect.y += bigger_rect.y;
|
target_box.rect.y += bigger_rect.y;
|
||||||
for(auto &blob : target_box.light_blobs){
|
for(auto &blob : target_box.light_blobs){
|
||||||
blob.rect.center.x += bigger_rect.x;
|
blob.rect.center.x += bigger_rect.x;
|
||||||
@@ -44,16 +41,16 @@ bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
|||||||
}
|
}
|
||||||
tracker = TrackerToUse::create();
|
tracker = TrackerToUse::create();
|
||||||
tracker->init(src, target_box.rect);
|
tracker->init(src, target_box.rect);
|
||||||
}else{
|
}else{ // 如果没有成功搜索目标,则使用判断是否跟丢。
|
||||||
roi = src(pos).clone();
|
roi = src(pos).clone();
|
||||||
if(classifier){
|
if(classifier){ // 分类器可用,使用分类器判断。
|
||||||
cv::resize(roi, roi, cv::Size(48, 36));
|
cv::resize(roi, roi, cv::Size(48, 36));
|
||||||
if(classifier(roi) == 0){
|
if(classifier(roi) == 0){
|
||||||
target_box = ArmorBox();
|
target_box = ArmorBox();
|
||||||
LOGW("Track classify fail range!");
|
LOGW("Track classify fail range!");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}else{
|
}else{ // 分类器不可用,使用常规方法判断
|
||||||
cv::Mat roi_gray;
|
cv::Mat roi_gray;
|
||||||
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
|
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
|
||||||
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
|
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
|
||||||
@@ -65,7 +62,6 @@ bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
|||||||
}
|
}
|
||||||
target_box.rect = pos;
|
target_box.rect = pos;
|
||||||
target_box.light_blobs.clear();
|
target_box.light_blobs.clear();
|
||||||
target_box = ArmorBox();
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
15
main.cpp
15
main.cpp
@@ -47,8 +47,6 @@ ArmorFinder armor_finder(mcu_data.enemy_color, serial, PROJECT_DIR"/tools/para/"
|
|||||||
// 能量机关主程序对象
|
// 能量机关主程序对象
|
||||||
Energy energy(serial, mcu_data.enemy_color);
|
Energy energy(serial, mcu_data.enemy_color);
|
||||||
|
|
||||||
int box_distance = 0;
|
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
processOptions(argc, argv); // 处理命令行参数
|
processOptions(argc, argv); // 处理命令行参数
|
||||||
thread receive(uartReceive, &serial); // 开启串口接收线程
|
thread receive(uartReceive, &serial); // 开启串口接收线程
|
||||||
@@ -107,11 +105,10 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
energy.setEnergyInit();
|
energy.setEnergyInit();
|
||||||
}
|
}
|
||||||
last_state = curr_state;//更新上一帧状态
|
|
||||||
ok = checkReconnect(video->read(src));
|
ok = checkReconnect(video->read(src));
|
||||||
#ifdef GIMBAL_FLIP_MODE
|
#ifdef GIMBAL_FLIP_MODE
|
||||||
flip(src, src, GIMBAL_FLIP_MODE);
|
flip(src, src, GIMBAL_FLIP_MODE);
|
||||||
#endif
|
#endif
|
||||||
if (!from_camera) extract(src);
|
if (!from_camera) extract(src);
|
||||||
if (save_video) saveVideos(src);//保存视频
|
if (save_video) saveVideos(src);//保存视频
|
||||||
if (show_origin) showOrigin(src);//显示原始图像
|
if (show_origin) showOrigin(src);//显示原始图像
|
||||||
@@ -130,23 +127,23 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
last_state = curr_state;
|
|
||||||
CNT_TIME(STR_CTR(WORD_GREEN, "read img"), {
|
CNT_TIME(STR_CTR(WORD_GREEN, "read img"), {
|
||||||
if(!checkReconnect(video->read(src))) continue;
|
if(!checkReconnect(video->read(src))) continue;
|
||||||
});
|
});
|
||||||
#ifdef GIMBAL_FLIP_MODE
|
#ifdef GIMBAL_FLIP_MODE
|
||||||
flip(src, src, GIMBAL_FLIP_MODE);
|
flip(src, src, GIMBAL_FLIP_MODE);
|
||||||
#endif
|
#endif
|
||||||
// CNT_TIME("something whatever", {
|
CNT_TIME("something whatever", {
|
||||||
if (!from_camera) extract(src);
|
if (!from_camera) extract(src);
|
||||||
if (save_video) saveVideos(src);
|
if (save_video) saveVideos(src);
|
||||||
if (show_origin) showOrigin(src);
|
if (show_origin) showOrigin(src);
|
||||||
// });
|
});
|
||||||
CNT_TIME(STR_CTR(WORD_CYAN, "Armor Time"), {
|
CNT_TIME(STR_CTR(WORD_CYAN, "Armor Time"), {
|
||||||
armor_finder.run(src);
|
armor_finder.run(src);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// cv::waitKey(0);
|
last_state = curr_state;//更新上一帧状态
|
||||||
|
if(run_by_frame) cv::waitKey(0);
|
||||||
});
|
});
|
||||||
} while (ok);
|
} while (ok);
|
||||||
delete video;
|
delete video;
|
||||||
|
|||||||
@@ -1,7 +1,30 @@
|
|||||||
//
|
//
|
||||||
// Created by xinyang on 19-2-19.
|
// Created by xinyang on 19-2-19.
|
||||||
//
|
//
|
||||||
|
// 该文件提供一个更加方便的调试信息输出方式
|
||||||
|
// 所有输出信息分为三个LEVEL:MSG,WARNING,ERROR
|
||||||
|
// 可使用宏LOG_LEVEL定义当前文件使用的输出LEVEL
|
||||||
|
// 高于该LEVEL的输出讲不会被显示
|
||||||
|
// ============================================================
|
||||||
|
// 输出API:
|
||||||
|
// LOG(level, format, ...)
|
||||||
|
// arguments: level:当前输出的level
|
||||||
|
// format:标准printf格式化字符串
|
||||||
|
// LOGM(format, ...) 使用MSG level进行输出
|
||||||
|
// LOGW(format, ...) 使用WARNING level进行输出
|
||||||
|
// LOGE(format, ...) 使用ERROR level进行输出
|
||||||
|
// ============================================================
|
||||||
|
// 输出颜色API:(仅对部分终端生效)
|
||||||
|
// STR_CTR(ctrs, str)
|
||||||
|
// arguments: ctrs:该字符串对应的颜色(所有以WORD开头的宏)
|
||||||
|
// str:需要上色的字符串
|
||||||
|
// ============================================================
|
||||||
|
// 时间计算API:(需要配合systime.h使用)
|
||||||
|
// CNT_TIME(tag, codes, ...)
|
||||||
|
// arguments: tag:显示代码块执行时间前的用户信息,支持printf格式化字符串
|
||||||
|
// codes:需要被统计时间的代码块
|
||||||
|
// attention: 代码块内定义的局部变量作用域仅限于该代码块
|
||||||
|
//
|
||||||
#ifndef _LOG_H_
|
#ifndef _LOG_H_
|
||||||
#define _LOG_H_
|
#define _LOG_H_
|
||||||
|
|
||||||
@@ -35,6 +58,7 @@
|
|||||||
#define BACK_GRAY_CODE ";47"
|
#define BACK_GRAY_CODE ";47"
|
||||||
|
|
||||||
#define CTRS(ctrs) START_CTR ctrs END_CTR
|
#define CTRS(ctrs) START_CTR ctrs END_CTR
|
||||||
|
#define STR_CTR(ctrs, str) START_CTR ctrs END_CTR str CLEAR_ALL
|
||||||
|
|
||||||
#define WORD_WHITE WORD_WHITE_CODE
|
#define WORD_WHITE WORD_WHITE_CODE
|
||||||
#define WORD_RED WORD_RED_CODE
|
#define WORD_RED WORD_RED_CODE
|
||||||
@@ -106,7 +130,6 @@
|
|||||||
#define LOG_4(format, ...) ((void)0)
|
#define LOG_4(format, ...) ((void)0)
|
||||||
#define LOG_(level, format, ...) LOG_##level (format, ##__VA_ARGS__)
|
#define LOG_(level, format, ...) LOG_##level (format, ##__VA_ARGS__)
|
||||||
#define LOG(level, format, ...) LOG_(level, format"\n", ##__VA_ARGS__)
|
#define LOG(level, format, ...) LOG_(level, format"\n", ##__VA_ARGS__)
|
||||||
#define STR_CTR(ctrs, str) START_CTR ctrs END_CTR str CLEAR_ALL
|
|
||||||
|
|
||||||
#define LOGA(format, ...) LOG(LOG_NONE, format, ##__VA_ARGS__)
|
#define LOGA(format, ...) LOG(LOG_NONE, format, ##__VA_ARGS__)
|
||||||
#define LOGA_INFO(format, ...) LOG(LOG_NONE, "<%s:%d>: " format, ##__VA_ARGS__)
|
#define LOGA_INFO(format, ...) LOG(LOG_NONE, "<%s:%d>: " format, ##__VA_ARGS__)
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ extern bool show_process;
|
|||||||
extern bool show_energy;
|
extern bool show_energy;
|
||||||
extern bool save_mark;
|
extern bool save_mark;
|
||||||
extern bool show_info;
|
extern bool show_info;
|
||||||
|
extern bool run_by_frame;
|
||||||
|
|
||||||
|
|
||||||
void processOptions(int argc, char **argv);
|
void processOptions(int argc, char **argv);
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ bool show_process = false;
|
|||||||
bool show_energy = false;
|
bool show_energy = false;
|
||||||
bool save_mark = false;
|
bool save_mark = false;
|
||||||
bool show_info = false;
|
bool show_info = false;
|
||||||
|
bool run_by_frame = false;
|
||||||
|
|
||||||
// 使用map保存所有选项及其描述和操作,加快查找速度。
|
// 使用map保存所有选项及其描述和操作,加快查找速度。
|
||||||
std::map<std::string, std::pair<std::string, void(*)(void)>> options = {
|
std::map<std::string, std::pair<std::string, void(*)(void)>> options = {
|
||||||
@@ -78,6 +79,12 @@ std::map<std::string, std::pair<std::string, void(*)(void)>> options = {
|
|||||||
LOGM("Enable wait uart!");
|
LOGM("Enable wait uart!");
|
||||||
}
|
}
|
||||||
}},
|
}},
|
||||||
|
{"--run-by-frame",{
|
||||||
|
"run the code frame by frame.(normally used when run video)", [](){
|
||||||
|
run_by_frame = true;
|
||||||
|
LOGM("Enable run frame by frame");
|
||||||
|
}
|
||||||
|
}},
|
||||||
{"--show-process", {
|
{"--show-process", {
|
||||||
"", [](){
|
"", [](){
|
||||||
show_process = true;
|
show_process = true;
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class DataSet:
|
|||||||
if file[-3:] == "jpg":
|
if file[-3:] == "jpg":
|
||||||
sample = self.file2nparray("%s/%s" % (dir, file))
|
sample = self.file2nparray("%s/%s" % (dir, file))
|
||||||
label = self.id2label(i)
|
label = self.id2label(i)
|
||||||
if random.random() > 0.7:
|
if random.random() < 0.7:
|
||||||
self.train_samples.append(sample)
|
self.train_samples.append(sample)
|
||||||
self.train_labels.append(label)
|
self.train_labels.append(label)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user