更新了一大堆东西,我也说不清有什么。

This commit is contained in:
xinyang
2019-04-29 14:42:47 +08:00
parent 4015ee5910
commit 25927f0b34
9 changed files with 140 additions and 112 deletions

View File

@@ -1,34 +1,34 @@
cmake_minimum_required(VERSION 3.5) CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
project(AutoAim) PROJECT(AutoAim)
set(CMAKE_CXX_STANDARD 11) SET(CMAKE_CXX_STANDARD 11)
SET(CMAKE_BUILD_TYPE DEBUG) SET(CMAKE_BUILD_TYPE RELEASE)
SET(CMAKE_CXX_FLAGS "-DPROJECT_DIR=\"\\\"${PROJECT_SOURCE_DIR}\\\"\"")
FIND_PROGRAM(CCACHE_FOUND ccache) FIND_PROGRAM(CCACHE_FOUND ccache)
IF(CCACHE_FOUND) IF(CCACHE_FOUND)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) SET_PROPERTY(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache) SET_PROPERTY(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
message("< Use ccache for compiler >") MESSAGE("< Use ccache for compiler >")
ENDIF() ENDIF()
FIND_PACKAGE(OpenCV 3 REQUIRED) FIND_PACKAGE(OpenCV 3 REQUIRED)
FIND_PACKAGE(Eigen3 REQUIRED) FIND_PACKAGE(Eigen3 REQUIRED)
FIND_PACKAGE(Threads) FIND_PACKAGE(Threads)
include_directories( ${EIGEN3_INCLUDE_DIR} ) INCLUDE_DIRECTORIES(${EIGEN3_INCLUDE_DIR})
include_directories( ${PROJECT_SOURCE_DIR}/energy/include ) INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/energy/include)
include_directories( ${PROJECT_SOURCE_DIR}/armor/include ) INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/armor/include)
include_directories( ${PROJECT_SOURCE_DIR}/others/include ) INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/others/include)
FILE(GLOB_RECURSE sourcefiles "others/src/*.cpp" "energy/src/*cpp" "armor/src/*.cpp") FILE(GLOB_RECURSE sourcefiles "others/src/*.cpp" "energy/src/*cpp" "armor/src/*.cpp")
add_executable(run main.cpp ${sourcefiles} ) ADD_EXECUTABLE(run main.cpp ${sourcefiles} )
TARGET_LINK_LIBRARIES(run ${CMAKE_THREAD_LIBS_INIT}) TARGET_LINK_LIBRARIES(run ${CMAKE_THREAD_LIBS_INIT})
TARGET_LINK_LIBRARIES(run ${OpenCV_LIBS}) TARGET_LINK_LIBRARIES(run ${OpenCV_LIBS})
TARGET_LINK_LIBRARIES(run ${PROJECT_SOURCE_DIR}/others/libMVSDK.so) TARGET_LINK_LIBRARIES(run ${PROJECT_SOURCE_DIR}/others/libMVSDK.so)
# Todo ADD_CUSTOM_TARGET(train COMMAND "gnome-terminal" "-x" "bash" "-c" "\"${PROJECT_SOURCE_DIR}/tools/TrainCNN/backward.py\"" )
# ADD_CUSTOM_TARGET(bind-monitor COMMAND "")
# Todo # Todo
# ADD_CUSTOM_TARGET(train COMMAND "") # ADD_CUSTOM_TARGET(bind-monitor COMMAND "")

View File

@@ -2,6 +2,9 @@
// Created by xinyang on 19-3-27. // Created by xinyang on 19-3-27.
// //
#include <log.h> #include <log.h>
#include <options/options.h>
#include <show_images/show_images.h>
#include <opencv2/highgui.hpp>
#include <armor_finder/armor_finder.h> #include <armor_finder/armor_finder.h>
ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) : ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) :
@@ -11,38 +14,48 @@ ArmorFinder::ArmorFinder(EnemyColor color, Uart &u, string paras_folder) :
classifier(std::move(paras_folder)), classifier(std::move(paras_folder)),
contour_area(0) contour_area(0)
{ {
auto para = TrackerToUse::Params(); // auto para = TrackerToUse::Params();
para.desc_npca = 1; // para.desc_npca = 1;
para.desc_pca = 0; // para.desc_pca = 0;
tracker = TrackerToUse::create(para); // tracker = TrackerToUse::create(para);
if(!tracker){ // if(!tracker){
LOGW("Tracker Not init"); // LOGW("Tracker Not init");
} // }
} }
void ArmorFinder::run(cv::Mat &src) { void ArmorFinder::run(cv::Mat &src) {
cv::Mat src_use; cv::Mat src_use;
// if (src.type() == CV_8UC3) { src_use = src.clone();
// cv::cvtColor(src, src_use, CV_RGB2GRAY);
// }else{
src_use = src.clone();
// }
cv::cvtColor(src_use, src_gray, CV_RGB2GRAY); cv::cvtColor(src_use, src_gray, CV_RGB2GRAY);
stateSearchingTarget(src_use); if(show_armor_box){
return; showArmorBox("box", src, armor_box);
cv::waitKey(1);
}
// stateSearchingTarget(src_use);
// return;
switch (state){ switch (state){
case SEARCHING_STATE: case SEARCHING_STATE:
if(stateSearchingTarget(src_use)){ if(stateSearchingTarget(src_use)){
if((armor_box & cv::Rect2d(0, 0, 640, 480)) == armor_box) { if((armor_box & cv::Rect2d(0, 0, 640, 480)) == armor_box) {
cv::Mat roi = src_use.clone()(armor_box); // cv::Mat roi = src_gray.clone()(armor_box);
cv::threshold(roi, roi, 200, 255, cv::THRESH_BINARY); // cv::threshold(roi, roi, 200, 255, cv::THRESH_BINARY);
contour_area = cv::countNonZero(roi); // contour_area = cv::countNonZero(roi);
auto para = TrackerToUse::Params(); // auto para = TrackerToUse::Params();
para.desc_npca = 1; // para.desc_npca = 1;
para.desc_pca = 0; // para.desc_pca = 0;
tracker = TrackerToUse::create(para); // tracker = TrackerToUse::create(para);
// tracker->init(src_gray, armor_box);
// tracker->update(src_gray, armor_box);
cv::Mat roi = src_use.clone()(armor_box), roi_gray;
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
// cv::imshow("boxroi", roi);
// cv::waitKey(0);
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
contour_area = cv::countNonZero(roi_gray);
LOGW("%d", contour_area);
tracker = TrackerToUse::create();
tracker->init(src_use, armor_box); tracker->init(src_use, armor_box);
state = TRACKING_STATE; state = TRACKING_STATE;
LOGW("into track"); LOGW("into track");
@@ -50,7 +63,7 @@ void ArmorFinder::run(cv::Mat &src) {
} }
break; break;
case TRACKING_STATE: case TRACKING_STATE:
if(!stateTrackingTarget(src_gray)){ if(!stateTrackingTarget(src_use)){
state = SEARCHING_STATE; state = SEARCHING_STATE;
//std::cout << "into search!" << std::endl; //std::cout << "into search!" << std::endl;
} }

View File

@@ -259,7 +259,7 @@ Classifier::Classifier(const string &folder) : state(true){
fc2_w = load_fc_w(folder+"fc2_w"); fc2_w = load_fc_w(folder+"fc2_w");
fc2_b = load_fc_b(folder+"fc2_b"); fc2_b = load_fc_b(folder+"fc2_b");
if(state){ if(state){
LOGM("Load paras success!"); LOGM("Load para success!");
} }
} }

View File

@@ -7,7 +7,6 @@
#include "image_process/image_process.h" #include "image_process/image_process.h"
#include <log.h> #include <log.h>
#include <show_images/show_images.h> #include <show_images/show_images.h>
#include <options/options.h> #include <options/options.h>
typedef std::vector<LightBlob> LightBlobs; typedef std::vector<LightBlob> LightBlobs;
@@ -32,11 +31,10 @@ static void pipelineLightBlobPreprocess(cv::Mat &src) {
} }
static bool findLightBlobs(const cv::Mat &src, LightBlobs &light_blobs) { static bool findLightBlobs(const cv::Mat &src, LightBlobs &light_blobs) {
static cv::Mat src_bin; // static cv::Mat src_bin;
cv::threshold(src, src_bin, 80, 255, CV_THRESH_BINARY);
std::vector<std::vector<cv::Point> > light_contours; std::vector<std::vector<cv::Point> > light_contours;
cv::findContours(src_bin, light_contours, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE); cv::findContours(src, light_contours, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);
for (auto &light_contour : light_contours) { for (auto &light_contour : light_contours) {
cv::RotatedRect rect = cv::minAreaRect(light_contour); cv::RotatedRect rect = cv::minAreaRect(light_contour);
if(isValidLightBlob(rect)){ if(isValidLightBlob(rect)){
@@ -117,7 +115,7 @@ static bool findArmorBoxes(LightBlobs &light_blobs, std::vector<cv::Rect2d> &arm
min_x = fmin(rect_left.x, rect_right.x); min_x = fmin(rect_left.x, rect_right.x);
max_x = fmax(rect_left.x + rect_left.width, rect_right.x + rect_right.width); max_x = fmax(rect_left.x + rect_left.width, rect_right.x + rect_right.width);
min_y = fmin(rect_left.y, rect_right.y) - 5; min_y = fmin(rect_left.y, rect_right.y) - 5;
max_y = fmax(rect_left.y + rect_left.height, rect_right.y + rect_right.height) + 5; max_y = fmax(rect_left.y + rect_left.height, rect_right.y + rect_right.height);
if (min_x < 0 || max_x > 640 || min_y < 0 || max_y > 480) { if (min_x < 0 || max_x > 640 || min_y < 0 || max_y > 480) {
continue; continue;
} }
@@ -149,19 +147,21 @@ bool judge_light_color(std::vector<LightBlob> &light, std::vector<LightBlob> &co
} }
bool ArmorFinder::stateSearchingTarget(cv::Mat &src) { bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
cv::Mat split, pmsrc=src.clone(); cv::Mat split, pmsrc=src.clone(), src_bin;
LightBlobs light_blobs, pm_light_blobs, light_blobs_real; LightBlobs light_blobs, pm_light_blobs, light_blobs_real;
std::vector<cv::Rect2d> armor_boxes, boxes_one, boxes_two, boxes_three; std::vector<cv::Rect2d> armor_boxes, boxes_one, boxes_two, boxes_three;
// cv::resize(src, pmsrc, cv::Size(320, 240)); // cv::resize(src, pmsrc, cv::Size(320, 240));
imageColorSplit(src, split, enemy_color); imageColorSplit(src, split, enemy_color);
imagePreProcess(split); cv::threshold(split, src_bin, 130, 255, CV_THRESH_BINARY);
cv::resize(split, split, cv::Size(640, 480)); imagePreProcess(src_bin);
cv::imshow("bin", src_bin);
// cv::resize(split, split, cv::Size(640, 480));
// pipelineLightBlobPreprocess(pmsrc); // pipelineLightBlobPreprocess(pmsrc);
// if(!findLightBlobs(pmsrc, pm_light_blobs)){ // if(!findLightBlobs(pmsrc, pm_light_blobs)){
// return false; // return false;
// } // }
if(!findLightBlobs(split, light_blobs)){ if(!findLightBlobs(src_bin, light_blobs)){
return false; return false;
} }
// if(!judge_light_color(light_blobs, pm_light_blobs, light_blobs_real)){ // if(!judge_light_color(light_blobs, pm_light_blobs, light_blobs_real)){
@@ -182,10 +182,7 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
for(auto box : armor_boxes){ for(auto box : armor_boxes){
cv::Mat roi = src(box).clone(); cv::Mat roi = src(box).clone();
cv::resize(roi, roi, cv::Size(48, 36)); cv::resize(roi, roi, cv::Size(48, 36));
// cv::imshow("roi", roi);
// cv::waitKey(0);
int c = classifier(roi); int c = classifier(roi);
// cout << c << endl;
switch(c){ switch(c){
case 1: case 1:
boxes_one.emplace_back(box); boxes_one.emplace_back(box);
@@ -204,6 +201,8 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
armor_box = boxes_two[0]; armor_box = boxes_two[0];
}else if(!boxes_three.empty()){ }else if(!boxes_three.empty()){
armor_box = boxes_three[0]; armor_box = boxes_three[0];
} else{
return false;
} }
if(show_armor_box){ if(show_armor_box){
showArmorBoxClass("class", src, boxes_one, boxes_two, boxes_three); showArmorBoxClass("class", src, boxes_one, boxes_two, boxes_three);
@@ -211,10 +210,6 @@ bool ArmorFinder::stateSearchingTarget(cv::Mat &src) {
}else{ }else{
armor_box = armor_boxes[0]; armor_box = armor_boxes[0];
} }
if(show_armor_box){
showArmorBox("box", src, armor_box);
cv::waitKey(1);
}
if(split.size() == cv::Size(320, 240)){ if(split.size() == cv::Size(320, 240)){
armor_box.x *= 2; armor_box.x *= 2;
armor_box.y *= 2; armor_box.y *= 2;

View File

@@ -5,16 +5,18 @@
#include <armor_finder/armor_finder.h> #include <armor_finder/armor_finder.h>
bool ArmorFinder::stateTrackingTarget(cv::Mat &src) { bool ArmorFinder::stateTrackingTarget(cv::Mat &src) {
auto last = armor_box; if(!tracker->update(src, armor_box)){
tracker->update(src, armor_box); return false;
}
if((armor_box & cv::Rect2d(0, 0, 640, 480)) != armor_box){ if((armor_box & cv::Rect2d(0, 0, 640, 480)) != armor_box){
return false; return false;
} }
cv::Mat roi = src(armor_box);
threshold(roi, roi, 200, 255, cv::THRESH_BINARY);
if(abs(cv::countNonZero(roi) - contour_area) > contour_area * 0.3){ cv::Mat roi = src.clone()(armor_box), roi_gray;
cv::cvtColor(roi, roi_gray, CV_RGB2GRAY);
cv::threshold(roi_gray, roi_gray, 180, 255, cv::THRESH_BINARY);
contour_area = cv::countNonZero(roi_gray);
if(abs(cv::countNonZero(roi_gray) - contour_area) > contour_area * 0.3){
return false; return false;
} }
return sendBoxPosition(); return sendBoxPosition();

View File

@@ -13,32 +13,33 @@
#include <camera/wrapper_head.h> #include <camera/wrapper_head.h>
#include <armor_finder/armor_finder.h> #include <armor_finder/armor_finder.h>
#include <options/options.h> #include <options/options.h>
#include <thread>
#define DO_NOT_CNT_TIME
#include <log.h> #include <log.h>
#include <thread> #define PATH PROJECT_DIR
#define ENERGY_STATE 1
#define ARMOR_STATE 0
using namespace cv; using namespace cv;
using namespace std; using namespace std;
#define ENERGY_STATE 1
#define ARMOR_STATE 0
int state = ENERGY_STATE; int state = ARMOR_STATE;
float curr_yaw=0, curr_pitch=0; float curr_yaw=0, curr_pitch=0;
float mark_yaw=0, mark_pitch=0; float mark_yaw=0, mark_pitch=0;
int mark = 0; int mark = 0;
void uartReceive(Uart* uart); void uartReceive(Uart* uart);
int main(int argc, char *argv[]) int main(int argc, char *argv[]){
{
process_options(argc, argv); process_options(argc, argv);
Uart uart; Uart uart;
thread receive(uartReceive, &uart); thread receive(uartReceive, &uart);
bool flag = true; bool flag = true;
while (flag) while (flag){
{
int ally_color = ALLY_RED; int ally_color = ALLY_RED;
int energy_part_rotation = CLOCKWISE; int energy_part_rotation = CLOCKWISE;
@@ -54,8 +55,8 @@ int main(int argc, char *argv[])
// video_armor = new CameraWrapper(); // video_armor = new CameraWrapper();
video_energy = new CameraWrapper(); video_energy = new CameraWrapper();
}else { }else {
video_armor = new VideoWrapper("r_l_640.avi"); video_armor = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
video_energy = new VideoWrapper("r_l_640.avi"); video_energy = new VideoWrapper("/home/xinyang/Desktop/Video.mp4");
} }
if (video_energy->init()) { if (video_energy->init()) {
cout << "Video source initialization successfully." << endl; cout << "Video source initialization successfully." << endl;
@@ -63,7 +64,7 @@ int main(int argc, char *argv[])
Mat energy_src, armor_src; Mat energy_src, armor_src;
ArmorFinder armorFinder(ENEMY_BLUE, uart, "/home/xinyang/Desktop/AutoAim/tools/para/"); ArmorFinder armorFinder(ENEMY_BLUE, uart, PATH"/tools/para/");
Energy energy(uart); Energy energy(uart);
energy.setAllyColor(ally_color); energy.setAllyColor(ally_color);

79
tools/TrainCNN/backward.py Normal file → Executable file
View File

@@ -1,5 +1,6 @@
#!/usr/bin/python3
import tensorflow as tf import tensorflow as tf
from progressive.bar import Bar from tqdm import tqdm
import generate import generate
import forward import forward
import cv2 import cv2
@@ -51,7 +52,7 @@ def save_para(folder, paras):
STEPS = 20000 STEPS = 20000
BATCH = 10 BATCH = 30
LEARNING_RATE_BASE = 0.01 LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99 LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99 MOVING_AVERAGE_DECAY = 0.99
@@ -85,18 +86,13 @@ def train(dataset, show_bar=False):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = 0
with tf.Session() as sess: with tf.Session() as sess:
init_op = tf.global_variables_initializer() init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(init_op)
if show_bar: bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
bar = Bar(max_value=STEPS, width=u'50%') for i in bar:
bar.cursor.clear_lines(1)
bar.cursor.save()
for i in range(STEPS):
images_samples, labels_samples = dataset.sample_train_sets(BATCH) images_samples, labels_samples = dataset.sample_train_sets(BATCH)
_, loss_value, step = sess.run( _, loss_value, step = sess.run(
@@ -107,30 +103,49 @@ def train(dataset, show_bar=False):
if i % 100 == 0: if i % 100 == 0:
if i % 1000 == 0: if i % 1000 == 0:
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels}) acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc})
if show_bar:
bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
bar.cursor.restore()
bar.draw(value=i+1)
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4") # if show_bar:
# _ = True # bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
# while _: # bar.cursor.restore()
# _, frame = video.read() # bar.draw(value=i+1)
# cv2.imshow("Video", frame)
# if cv2.waitKey(10) == 113: video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# bbox = cv2.selectROI("frame", frame, False) _ = True
# print(bbox) while _:
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]] _, frame = video.read()
# roi = cv2.resize(roi, (48, 36)) cv2.imshow("Video", frame)
# cv2.imshow("roi", roi) k = cv2.waitKey(10)
# cv2.waitKey(0) if k == ord(" "):
# roi = roi.astype(np.float32) bbox = cv2.selectROI("frame", frame, False)
# roi /= 255.0 print(bbox)
# roi = roi.reshape([1, 36, 48, 3]) roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
# res = sess.run(y, feed_dict={x: roi}) roi = cv2.resize(roi, (48, 36))
# res = res.reshape([forward.OUTPUT_NODES]) cv2.imshow("roi", roi)
# print(np.argmax(res)) cv2.waitKey(0)
roi = roi.astype(np.float32)
roi /= 255.0
roi = roi.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: roi})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
elif k==ord("q"):
break
keep = True
while keep:
n = input()
im = cv2.imread(n)
im = cv2.resize(im, (48, 36))
cv2.imshow("im", im)
if cv2.waitKey(0) == ord("q"):
keep = False
im = im.astype(np.float32)
im /= 255.0
im = im.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: im})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
vars_val = sess.run(vars) vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val) save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
@@ -139,5 +154,5 @@ def train(dataset, show_bar=False):
if __name__ == "__main__": if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets") dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
train(dataset, show_bar=True) train(dataset, show_bar=True)

View File

@@ -29,13 +29,13 @@ def max_pool_2x2(x):
CONV1_KERNAL_SIZE = 5 CONV1_KERNAL_SIZE = 5
# 第一层卷积输出通道数 # 第一层卷积输出通道数
CONV1_OUTPUT_CHANNELS = 4 CONV1_OUTPUT_CHANNELS = 6
# 第二层卷积核大小 # 第二层卷积核大小
CONV2_KERNAL_SIZE = 3 CONV2_KERNAL_SIZE = 3
# 第二层卷积输出通道数 # 第二层卷积输出通道数
CONV2_OUTPUT_CHANNELS = 8 CONV2_OUTPUT_CHANNELS = 10
# 第一层全连接宽度 # 第一层全连接宽度
FC1_OUTPUT_NODES = 16 FC1_OUTPUT_NODES = 16

View File

@@ -12,6 +12,7 @@ SRC_COLS = 48
# 原图像通道数 # 原图像通道数
SRC_CHANNELS = 3 SRC_CHANNELS = 3
class DataSet: class DataSet:
def __init__(self, folder): def __init__(self, folder):
self.train_samples = [] self.train_samples = []
@@ -37,12 +38,13 @@ class DataSet:
dir = "%s/%d" % (folder, i) dir = "%s/%d" % (folder, i)
files = os.listdir(dir) files = os.listdir(dir)
for file in files: for file in files:
if random.random() > 0.2: if file[-3:] == "jpg":
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file))) if random.random() > 0.2:
self.train_labels.append(self.id2label(i)) self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
else: self.train_labels.append(self.id2label(i))
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file))) else:
self.test_labels.append(self.id2label(i)) self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
self.train_samples = np.array(self.train_samples) self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels) self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples) self.test_samples = np.array(self.test_samples)