更新了一大堆东西,我也说不清有什么。
This commit is contained in:
@@ -1,34 +1,34 @@
|
|||||||
cmake_minimum_required(VERSION 3.5)
|
CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
|
||||||
|
|
||||||
project(AutoAim)
|
PROJECT(AutoAim)
|
||||||
set(CMAKE_CXX_STANDARD 11)
|
SET(CMAKE_CXX_STANDARD 11)
|
||||||
SET(CMAKE_BUILD_TYPE DEBUG)
|
SET(CMAKE_BUILD_TYPE RELEASE)
|
||||||
|
SET(CMAKE_CXX_FLAGS "-DPROJECT_DIR=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"")
|
||||||
|
|
||||||
FIND_PROGRAM(CCACHE_FOUND ccache)
|
FIND_PROGRAM(CCACHE_FOUND ccache)
|
||||||
IF(CCACHE_FOUND)
|
IF(CCACHE_FOUND)
|
||||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
|
SET_PROPERTY(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
|
||||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
|
SET_PROPERTY(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
|
||||||
message("< Use ccache for compiler >")
|
MESSAGE("< Use ccache for compiler >")
|
||||||
ENDIF()
|
ENDIF()
|
||||||
|
|
||||||
FIND_PACKAGE(OpenCV 3 REQUIRED)
|
FIND_PACKAGE(OpenCV 3 REQUIRED)
|
||||||
FIND_PACKAGE(Eigen3 REQUIRED)
|
FIND_PACKAGE(Eigen3 REQUIRED)
|
||||||
FIND_PACKAGE(Threads)
|
FIND_PACKAGE(Threads)
|
||||||
|
|
||||||
include_directories( ${EIGEN3_INCLUDE_DIR} )
|
INCLUDE_DIRECTORIES(${EIGEN3_INCLUDE_DIR})
|
||||||
include_directories( ${PROJECT_SOURCE_DIR}/energy/include )
|
INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/energy/include)
|
||||||
include_directories( ${PROJECT_SOURCE_DIR}/armor/include )
|
INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/armor/include)
|
||||||
include_directories( ${PROJECT_SOURCE_DIR}/others/include )
|
INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/others/include)
|
||||||
|
|
||||||
FILE(GLOB_RECURSE sourcefiles "others/src/*.cpp" "energy/src/*cpp" "armor/src/*.cpp")
|
FILE(GLOB_RECURSE sourcefiles "others/src/*.cpp" "energy/src/*cpp" "armor/src/*.cpp")
|
||||||
add_executable(run main.cpp ${sourcefiles} )
|
ADD_EXECUTABLE(run main.cpp ${sourcefiles} )
|
||||||
|
|
||||||
TARGET_LINK_LIBRARIES(run ${CMAKE_THREAD_LIBS_INIT})
|
TARGET_LINK_LIBRARIES(run ${CMAKE_THREAD_LIBS_INIT})
|
||||||
TARGET_LINK_LIBRARIES(run ${OpenCV_LIBS})
|
TARGET_LINK_LIBRARIES(run ${OpenCV_LIBS})
|
||||||
TARGET_LINK_LIBRARIES(run ${PROJECT_SOURCE_DIR}/others/libMVSDK.so)
|
TARGET_LINK_LIBRARIES(run ${PROJECT_SOURCE_DIR}/others/libMVSDK.so)
|
||||||
|
|
||||||
# Todo
|
ADD_CUSTOM_TARGET(train COMMAND "gnome-terminal" "-x" "bash" "-c" "\"${PROJECT_SOURCE_DIR}/tools/TrainCNN/backward.py\"" )
|
||||||
# ADD_CUSTOM_TARGET(bind-monitor COMMAND "")
|
|
||||||
|
|
||||||
# Todo
|
# Todo
|
||||||
# ADD_CUSTOM_TARGET(train COMMAND "")
|
# ADD_CUSTOM_TARGET(bind-monitor COMMAND "")
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
// Created by xinyang on 19-3-27.
|
// Created by xinyang on 19-3-27.
|
||||||
//
|
//
|
||||||
#include <log.h>
|
#include <log.h>
|
||||||
|
#include <options/options.h>
|
||||||
|
#include <show_images/show_images.h>
|
||||||
|
#include <opencv2/highgui.hpp>
|
||||||
#include <armor_finder/armor_finder.h>
|
#include <armor_finder/armor_finder.h>
|
||||||
|
|
||||||
ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) :
|
ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) :
|
||||||
@@ -11,38 +14,48 @@ ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) :
|
|||||||
classifier(std::move(paras_folder)),
|
classifier(std::move(paras_folder)),
|
||||||
contour_area(0)
|
contour_area(0)
|
||||||
{
|
{
|
||||||
auto para = TrackerToUse::Params();
|
// auto para = TrackerToUse::Params();
|
||||||
para.desc_npca = 1;
|
// para.desc_npca = 1;
|
||||||
para.desc_pca = 0;
|
// para.desc_pca = 0;
|
||||||
tracker = TrackerToUse::create(para);
|
// tracker = TrackerToUse::create(para);
|
||||||
if(!tracker){
|
// if(!tracker){
|
||||||
LOGW("Tracker Not init");
|
// LOGW("Tracker Not init");
|
||||||
}
|
// }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArmorFinder::run(cv::Mat &src) {
|
void ArmorFinder::run(cv::Mat &src) {
|
||||||
cv::Mat src_use;
|
cv::Mat src_use;
|
||||||
// if (src.type() == CV_8UC3) {
|
src_use = src.clone();
|
||||||
// cv::cvtColor(src, src_use, CV_RGB2GRAY);
|
|
||||||
// }else{
|
|
||||||
src_use = src.clone();
|
|
||||||
// }
|
|
||||||
cv::cvtColor(src_use, src_gray, CV_RGB2GRAY);
|
cv::cvtColor(src_use, src_gray, CV_RGB2GRAY);
|
||||||
|
|
||||||
stateSearchingTarget(src_use);
|
if(show_armor_box){
|
||||||
return;
|
showArmorBox("box", src, armor_box);
|
||||||
|
cv::waitKey(1);
|
||||||
|
}
|
||||||
|
// stateSearchingTarget(src_use);
|
||||||
|
// return;
|
||||||
switch (state){
|
switch (state){
|
||||||
case SEARCHING_STATE:
|
case SEARCHING_STATE:
|
||||||
if(stateSearchingTarget(src_use)){
|
if(stateSearchingTarget(src_use)){
|
||||||
if((armor_box & cv::Rect2d(0, 0, 640, 480)) == armor_box) {
|
if((armor_box & cv::Rect2d(0, 0, 640, 480)) == armor_box) {
|
||||||
cv::Mat roi = src_use.clone()(armor_box);
|
// cv::Mat roi = src_gray.clone()(armor_box);
|
||||||
cv::threshold(roi, roi, 200, 255, cv::THRESH_BINARY);
|
// cv::threshold(roi, roi, 200, 255, cv::THRESH_BINARY);
|
||||||
contour_area = cv::countNonZero(roi);
|
// contour_area = cv::countNonZero(roi);
|
||||||
auto para = TrackerToUse::Params();
|
// auto para = TrackerToUse::Params();
|
||||||
para.desc_npca = 1;
|
// para.desc_npca = 1;
|
||||||
para.desc_pca = 0;
|
// para.desc_pca = 0;
|
||||||
tracker = TrackerToUse::create(para);
|
// tracker = TrackerToUse::create(para);
|
||||||
|
// tracker->init(src_gray, armor_box);
|
||||||
|
// tracker->update(src_gray, armor_box);
|
||||||
|
cv::Mat roi = src_use.clone()(armor_box), roi_gray;
|
||||||
|
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
|
||||||
|
// cv::imshow("boxroi", roi);
|
||||||
|
// cv::waitKey(0);
|
||||||
|
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
|
||||||
|
contour_area = cv::countNonZero(roi_gray);
|
||||||
|
LOGW("%d", contour_area);
|
||||||
|
tracker = TrackerToUse::create();
|
||||||
tracker->init(src_use, armor_box);
|
tracker->init(src_use, armor_box);
|
||||||
state = TRACKING_STATE;
|
state = TRACKING_STATE;
|
||||||
LOGW("into track");
|
LOGW("into track");
|
||||||
@@ -50,7 +63,7 @@ void ArmorFinder::run(cv::Mat &src) {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case TRACKING_STATE:
|
case TRACKING_STATE:
|
||||||
if(!stateTrackingTarget(src_gray)){
|
if(!stateTrackingTarget(src_use)){
|
||||||
state = SEARCHING_STATE;
|
state = SEARCHING_STATE;
|
||||||
//std::cout << "into search!" << std::endl;
|
//std::cout << "into search!" << std::endl;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ Classifier::Classifier(const string &folder) : state(true){
|
|||||||
fc2_w = load_fc_w(folder+"fc2_w");
|
fc2_w = load_fc_w(folder+"fc2_w");
|
||||||
fc2_b = load_fc_b(folder+"fc2_b");
|
fc2_b = load_fc_b(folder+"fc2_b");
|
||||||
if(state){
|
if(state){
|
||||||
LOGM("Load paras success!");
|
LOGM("Load para success!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
#include "image_process/image_process.h"
|
#include "image_process/image_process.h"
|
||||||
#include <log.h>
|
#include <log.h>
|
||||||
#include <show_images/show_images.h>
|
#include <show_images/show_images.h>
|
||||||
|
|
||||||
#include <options/options.h>
|
#include <options/options.h>
|
||||||
|
|
||||||
typedef std::vector<LightBlob> LightBlobs;
|
typedef std::vector<LightBlob> LightBlobs;
|
||||||
@@ -32,11 +31,10 @@ static void pipelineLightBlobPreprocess(cv::Mat &src) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool findLightBlobs(const cv::Mat &src, LightBlobs &light_blobs) {
|
static bool findLightBlobs(const cv::Mat &src, LightBlobs &light_blobs) {
|
||||||
static cv::Mat src_bin;
|
// static cv::Mat src_bin;
|
||||||
|
|
||||||
cv::threshold(src, src_bin, 80, 255, CV_THRESH_BINARY);
|
|
||||||
std::vector<std::vector<cv::Point> > light_contours;
|
std::vector<std::vector<cv::Point> > light_contours;
|
||||||
cv::findContours(src_bin, light_contours, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);
|
cv::findContours(src, light_contours, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);
|
||||||
for (auto &light_contour : light_contours) {
|
for (auto &light_contour : light_contours) {
|
||||||
cv::RotatedRect rect = cv::minAreaRect(light_contour);
|
cv::RotatedRect rect = cv::minAreaRect(light_contour);
|
||||||
if(isValidLightBlob(rect)){
|
if(isValidLightBlob(rect)){
|
||||||
@@ -117,7 +115,7 @@ static bool findArmorBoxes(LightBlobs &light_blobs, std::vector<cv::Rect2d> &arm
|
|||||||
min_x = fmin(rect_left.x, rect_right.x);
|
min_x = fmin(rect_left.x, rect_right.x);
|
||||||
max_x = fmax(rect_left.x + rect_left.width, rect_right.x + rect_right.width);
|
max_x = fmax(rect_left.x + rect_left.width, rect_right.x + rect_right.width);
|
||||||
min_y = fmin(rect_left.y, rect_right.y) - 5;
|
min_y = fmin(rect_left.y, rect_right.y) - 5;
|
||||||
max_y = fmax(rect_left.y + rect_left.height, rect_right.y + rect_right.height) + 5;
|
max_y = fmax(rect_left.y + rect_left.height, rect_right.y + rect_right.height);
|
||||||
if (min_x < 0 || max_x > 640 || min_y < 0 || max_y > 480) {
|
if (min_x < 0 || max_x > 640 || min_y < 0 || max_y > 480) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -149,19 +147,21 @@ bool judge_light_color(std::vector<LightBlob> &light, std::vector<LightBlob> &co
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
||||||
cv::Mat split, pmsrc=src.clone();
|
cv::Mat split, pmsrc=src.clone(), src_bin;
|
||||||
LightBlobs light_blobs, pm_light_blobs, light_blobs_real;
|
LightBlobs light_blobs, pm_light_blobs, light_blobs_real;
|
||||||
std::vector<cv::Rect2d> armor_boxes, boxes_one, boxes_two, boxes_three;
|
std::vector<cv::Rect2d> armor_boxes, boxes_one, boxes_two, boxes_three;
|
||||||
|
|
||||||
// cv::resize(src, pmsrc, cv::Size(320, 240));
|
// cv::resize(src, pmsrc, cv::Size(320, 240));
|
||||||
imageColorSplit(src, split, enemy_color);
|
imageColorSplit(src, split, enemy_color);
|
||||||
imagePreProcess(split);
|
cv::threshold(split, src_bin, 130, 255, CV_THRESH_BINARY);
|
||||||
cv::resize(split, split, cv::Size(640, 480));
|
imagePreProcess(src_bin);
|
||||||
|
cv::imshow("bin", src_bin);
|
||||||
|
// cv::resize(split, split, cv::Size(640, 480));
|
||||||
// pipelineLightBlobPreprocess(pmsrc);
|
// pipelineLightBlobPreprocess(pmsrc);
|
||||||
// if(!findLightBlobs(pmsrc, pm_light_blobs)){
|
// if(!findLightBlobs(pmsrc, pm_light_blobs)){
|
||||||
// return false;
|
// return false;
|
||||||
// }
|
// }
|
||||||
if(!findLightBlobs(split, light_blobs)){
|
if(!findLightBlobs(src_bin, light_blobs)){
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// if(!judge_light_color(light_blobs, pm_light_blobs, light_blobs_real)){
|
// if(!judge_light_color(light_blobs, pm_light_blobs, light_blobs_real)){
|
||||||
@@ -182,10 +182,7 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
|||||||
for(auto box : armor_boxes){
|
for(auto box : armor_boxes){
|
||||||
cv::Mat roi = src(box).clone();
|
cv::Mat roi = src(box).clone();
|
||||||
cv::resize(roi, roi, cv::Size(48, 36));
|
cv::resize(roi, roi, cv::Size(48, 36));
|
||||||
// cv::imshow("roi", roi);
|
|
||||||
// cv::waitKey(0);
|
|
||||||
int c = classifier(roi);
|
int c = classifier(roi);
|
||||||
// cout << c << endl;
|
|
||||||
switch(c){
|
switch(c){
|
||||||
case 1:
|
case 1:
|
||||||
boxes_one.emplace_back(box);
|
boxes_one.emplace_back(box);
|
||||||
@@ -204,6 +201,8 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
|||||||
armor_box = boxes_two[0];
|
armor_box = boxes_two[0];
|
||||||
}else if(!boxes_three.empty()){
|
}else if(!boxes_three.empty()){
|
||||||
armor_box = boxes_three[0];
|
armor_box = boxes_three[0];
|
||||||
|
} else{
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
if(show_armor_box){
|
if(show_armor_box){
|
||||||
showArmorBoxClass("class", src, boxes_one, boxes_two, boxes_three);
|
showArmorBoxClass("class", src, boxes_one, boxes_two, boxes_three);
|
||||||
@@ -211,10 +210,6 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
|
|||||||
}else{
|
}else{
|
||||||
armor_box = armor_boxes[0];
|
armor_box = armor_boxes[0];
|
||||||
}
|
}
|
||||||
if(show_armor_box){
|
|
||||||
showArmorBox("box", src, armor_box);
|
|
||||||
cv::waitKey(1);
|
|
||||||
}
|
|
||||||
if(split.size() == cv::Size(320, 240)){
|
if(split.size() == cv::Size(320, 240)){
|
||||||
armor_box.x *= 2;
|
armor_box.x *= 2;
|
||||||
armor_box.y *= 2;
|
armor_box.y *= 2;
|
||||||
|
|||||||
@@ -5,16 +5,18 @@
|
|||||||
#include <armor_finder/armor_finder.h>
|
#include <armor_finder/armor_finder.h>
|
||||||
|
|
||||||
bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
|
||||||
auto last = armor_box;
|
if(!tracker->update(src, armor_box)){
|
||||||
tracker->update(src, armor_box);
|
return false;
|
||||||
|
}
|
||||||
if((armor_box & cv::Rect2d(0, 0, 640, 480)) != armor_box){
|
if((armor_box & cv::Rect2d(0, 0, 640, 480)) != armor_box){
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
cv::Mat roi = src(armor_box);
|
|
||||||
threshold(roi, roi, 200, 255, cv::THRESH_BINARY);
|
|
||||||
|
|
||||||
if(abs(cv::countNonZero(roi) - contour_area) > contour_area * 0.3){
|
cv::Mat roi = src.clone()(armor_box), roi_gray;
|
||||||
|
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
|
||||||
|
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
|
||||||
|
contour_area = cv::countNonZero(roi_gray);
|
||||||
|
if(abs(cv::countNonZero(roi_gray) - contour_area) > contour_area * 0.3){
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return sendBoxPosition();
|
return sendBoxPosition();
|
||||||
|
|||||||
23
main.cpp
23
main.cpp
@@ -13,32 +13,33 @@
|
|||||||
#include <camera/wrapper_head.h>
|
#include <camera/wrapper_head.h>
|
||||||
#include <armor_finder/armor_finder.h>
|
#include <armor_finder/armor_finder.h>
|
||||||
#include <options/options.h>
|
#include <options/options.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#define DO_NOT_CNT_TIME
|
||||||
#include <log.h>
|
#include <log.h>
|
||||||
|
|
||||||
#include <thread>
|
#define PATH PROJECT_DIR
|
||||||
|
#define ENERGY_STATE 1
|
||||||
|
#define ARMOR_STATE 0
|
||||||
|
|
||||||
using namespace cv;
|
using namespace cv;
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
#define ENERGY_STATE 1
|
|
||||||
#define ARMOR_STATE 0
|
|
||||||
|
|
||||||
int state = ENERGY_STATE;
|
int state = ARMOR_STATE;
|
||||||
float curr_yaw=0, curr_pitch=0;
|
float curr_yaw=0, curr_pitch=0;
|
||||||
float mark_yaw=0, mark_pitch=0;
|
float mark_yaw=0, mark_pitch=0;
|
||||||
int mark = 0;
|
int mark = 0;
|
||||||
|
|
||||||
void uartReceive(Uart* uart);
|
void uartReceive(Uart* uart);
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
int main(int argc, char *argv[]){
|
||||||
{
|
|
||||||
process_options(argc, argv);
|
process_options(argc, argv);
|
||||||
Uart uart;
|
Uart uart;
|
||||||
thread receive(uartReceive, &uart);
|
thread receive(uartReceive, &uart);
|
||||||
bool flag = true;
|
bool flag = true;
|
||||||
|
|
||||||
while (flag)
|
while (flag){
|
||||||
{
|
|
||||||
int ally_color = ALLY_RED;
|
int ally_color = ALLY_RED;
|
||||||
int energy_part_rotation = CLOCKWISE;
|
int energy_part_rotation = CLOCKWISE;
|
||||||
|
|
||||||
@@ -54,8 +55,8 @@ int main(int argc, char *argv[])
|
|||||||
// video_armor = new CameraWrapper();
|
// video_armor = new CameraWrapper();
|
||||||
video_energy = new CameraWrapper();
|
video_energy = new CameraWrapper();
|
||||||
}else {
|
}else {
|
||||||
video_armor = new VideoWrapper("r_l_640.avi");
|
video_armor = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
|
||||||
video_energy = new VideoWrapper("r_l_640.avi");
|
video_energy = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
|
||||||
}
|
}
|
||||||
if (video_energy->init()) {
|
if (video_energy->init()) {
|
||||||
cout << "Video source initialization successfully." << endl;
|
cout << "Video source initialization successfully." << endl;
|
||||||
@@ -63,7 +64,7 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
Mat energy_src, armor_src;
|
Mat energy_src, armor_src;
|
||||||
|
|
||||||
ArmorFinder armorFinder(ENEMY_BLUE, uart, "/home/xinyang/Desktop/AutoAim/tools/para/");
|
ArmorFinder armorFinder(ENEMY_BLUE, uart, PATH"/tools/para/");
|
||||||
|
|
||||||
Energy energy(uart);
|
Energy energy(uart);
|
||||||
energy.setAllyColor(ally_color);
|
energy.setAllyColor(ally_color);
|
||||||
|
|||||||
79
tools/TrainCNN/backward.py
Normal file → Executable file
79
tools/TrainCNN/backward.py
Normal file → Executable file
@@ -1,5 +1,6 @@
|
|||||||
|
#!/usr/bin/python3
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from progressive.bar import Bar
|
from tqdm import tqdm
|
||||||
import generate
|
import generate
|
||||||
import forward
|
import forward
|
||||||
import cv2
|
import cv2
|
||||||
@@ -51,7 +52,7 @@ def save_para(folder, paras):
|
|||||||
|
|
||||||
|
|
||||||
STEPS = 20000
|
STEPS = 20000
|
||||||
BATCH = 10
|
BATCH = 30
|
||||||
LEARNING_RATE_BASE = 0.01
|
LEARNING_RATE_BASE = 0.01
|
||||||
LEARNING_RATE_DECAY = 0.99
|
LEARNING_RATE_DECAY = 0.99
|
||||||
MOVING_AVERAGE_DECAY = 0.99
|
MOVING_AVERAGE_DECAY = 0.99
|
||||||
@@ -85,18 +86,13 @@ def train(dataset, show_bar=False):
|
|||||||
|
|
||||||
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||||
acc = 0
|
|
||||||
|
|
||||||
with tf.Session() as sess:
|
with tf.Session() as sess:
|
||||||
init_op = tf.global_variables_initializer()
|
init_op = tf.global_variables_initializer()
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
|
|
||||||
if show_bar:
|
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
|
||||||
bar = Bar(max_value=STEPS, width=u'50%')
|
for i in bar:
|
||||||
bar.cursor.clear_lines(1)
|
|
||||||
bar.cursor.save()
|
|
||||||
|
|
||||||
for i in range(STEPS):
|
|
||||||
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
|
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
|
||||||
|
|
||||||
_, loss_value, step = sess.run(
|
_, loss_value, step = sess.run(
|
||||||
@@ -107,30 +103,49 @@ def train(dataset, show_bar=False):
|
|||||||
if i % 100 == 0:
|
if i % 100 == 0:
|
||||||
if i % 1000 == 0:
|
if i % 1000 == 0:
|
||||||
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
|
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
|
||||||
|
bar.set_postfix({"loss": loss_value, "acc": acc})
|
||||||
|
|
||||||
if show_bar:
|
|
||||||
bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
|
|
||||||
bar.cursor.restore()
|
|
||||||
bar.draw(value=i+1)
|
|
||||||
|
|
||||||
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
# if show_bar:
|
||||||
# _ = True
|
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
|
||||||
# while _:
|
# bar.cursor.restore()
|
||||||
# _, frame = video.read()
|
# bar.draw(value=i+1)
|
||||||
# cv2.imshow("Video", frame)
|
|
||||||
# if cv2.waitKey(10) == 113:
|
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||||
# bbox = cv2.selectROI("frame", frame, False)
|
_ = True
|
||||||
# print(bbox)
|
while _:
|
||||||
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
_, frame = video.read()
|
||||||
# roi = cv2.resize(roi, (48, 36))
|
cv2.imshow("Video", frame)
|
||||||
# cv2.imshow("roi", roi)
|
k = cv2.waitKey(10)
|
||||||
# cv2.waitKey(0)
|
if k == ord(" "):
|
||||||
# roi = roi.astype(np.float32)
|
bbox = cv2.selectROI("frame", frame, False)
|
||||||
# roi /= 255.0
|
print(bbox)
|
||||||
# roi = roi.reshape([1, 36, 48, 3])
|
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||||
# res = sess.run(y, feed_dict={x: roi})
|
roi = cv2.resize(roi, (48, 36))
|
||||||
# res = res.reshape([forward.OUTPUT_NODES])
|
cv2.imshow("roi", roi)
|
||||||
# print(np.argmax(res))
|
cv2.waitKey(0)
|
||||||
|
roi = roi.astype(np.float32)
|
||||||
|
roi /= 255.0
|
||||||
|
roi = roi.reshape([1, 36, 48, 3])
|
||||||
|
res = sess.run(y, feed_dict={x: roi})
|
||||||
|
res = res.reshape([forward.OUTPUT_NODES])
|
||||||
|
print(np.argmax(res))
|
||||||
|
elif k==ord("q"):
|
||||||
|
break
|
||||||
|
keep = True
|
||||||
|
while keep:
|
||||||
|
n = input()
|
||||||
|
im = cv2.imread(n)
|
||||||
|
im = cv2.resize(im, (48, 36))
|
||||||
|
cv2.imshow("im", im)
|
||||||
|
if cv2.waitKey(0) == ord("q"):
|
||||||
|
keep = False
|
||||||
|
im = im.astype(np.float32)
|
||||||
|
im /= 255.0
|
||||||
|
im = im.reshape([1, 36, 48, 3])
|
||||||
|
res = sess.run(y, feed_dict={x: im})
|
||||||
|
res = res.reshape([forward.OUTPUT_NODES])
|
||||||
|
print(np.argmax(res))
|
||||||
|
|
||||||
vars_val = sess.run(vars)
|
vars_val = sess.run(vars)
|
||||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||||
@@ -139,5 +154,5 @@ def train(dataset, show_bar=False):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
|
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||||
train(dataset, show_bar=True)
|
train(dataset, show_bar=True)
|
||||||
|
|||||||
@@ -29,13 +29,13 @@ def max_pool_2x2(x):
|
|||||||
CONV1_KERNAL_SIZE = 5
|
CONV1_KERNAL_SIZE = 5
|
||||||
|
|
||||||
# 第一层卷积输出通道数
|
# 第一层卷积输出通道数
|
||||||
CONV1_OUTPUT_CHANNELS = 4
|
CONV1_OUTPUT_CHANNELS = 6
|
||||||
|
|
||||||
# 第二层卷积核大小
|
# 第二层卷积核大小
|
||||||
CONV2_KERNAL_SIZE = 3
|
CONV2_KERNAL_SIZE = 3
|
||||||
|
|
||||||
# 第二层卷积输出通道数
|
# 第二层卷积输出通道数
|
||||||
CONV2_OUTPUT_CHANNELS = 8
|
CONV2_OUTPUT_CHANNELS = 10
|
||||||
|
|
||||||
# 第一层全连接宽度
|
# 第一层全连接宽度
|
||||||
FC1_OUTPUT_NODES = 16
|
FC1_OUTPUT_NODES = 16
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ SRC_COLS = 48
|
|||||||
# 原图像通道数
|
# 原图像通道数
|
||||||
SRC_CHANNELS = 3
|
SRC_CHANNELS = 3
|
||||||
|
|
||||||
|
|
||||||
class DataSet:
|
class DataSet:
|
||||||
def __init__(self, folder):
|
def __init__(self, folder):
|
||||||
self.train_samples = []
|
self.train_samples = []
|
||||||
@@ -37,12 +38,13 @@ class DataSet:
|
|||||||
dir = "%s/%d" % (folder, i)
|
dir = "%s/%d" % (folder, i)
|
||||||
files = os.listdir(dir)
|
files = os.listdir(dir)
|
||||||
for file in files:
|
for file in files:
|
||||||
if random.random() > 0.2:
|
if file[-3:] == "jpg":
|
||||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
if random.random() > 0.2:
|
||||||
self.train_labels.append(self.id2label(i))
|
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||||
else:
|
self.train_labels.append(self.id2label(i))
|
||||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
else:
|
||||||
self.test_labels.append(self.id2label(i))
|
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||||
|
self.test_labels.append(self.id2label(i))
|
||||||
self.train_samples = np.array(self.train_samples)
|
self.train_samples = np.array(self.train_samples)
|
||||||
self.train_labels = np.array(self.train_labels)
|
self.train_labels = np.array(self.train_labels)
|
||||||
self.test_samples = np.array(self.test_samples)
|
self.test_samples = np.array(self.test_samples)
|
||||||
|
|||||||
Reference in New Issue
Block a user