From 14d1544113cee136f93faa99074b125d951592bf Mon Sep 17 00:00:00 2001 From: xinyang <895639507@qq.com> Date: Tue, 2 Jul 2019 20:56:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0CNN=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/TrainCNN/backward.py | 6 +++--- tools/TrainCNN/forward.py | 2 +- tools/TrainCNN/generate.py | 32 ++++++++++++++++---------------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tools/TrainCNN/backward.py b/tools/TrainCNN/backward.py index 797defc..cead37b 100644 --- a/tools/TrainCNN/backward.py +++ b/tools/TrainCNN/backward.py @@ -87,8 +87,8 @@ def train(dataset, show_bar=False): correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - - with tf.Session() as sess: + config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) + with tf.Session(config=config) as sess: init_op = tf.global_variables_initializer() sess.run(init_op) @@ -152,7 +152,7 @@ def train(dataset, show_bar=False): if __name__ == "__main__": print("Loading data sets...") - dataset = generate.DataSet("/home/xinyang/Desktop/DataSets") + dataset = generate.DataSet("/home/xinyang/Desktop/dataset/box") print("Finish!") train(dataset, show_bar=True) input("Press any key to end...") diff --git a/tools/TrainCNN/forward.py b/tools/TrainCNN/forward.py index 6200957..cb627b8 100644 --- a/tools/TrainCNN/forward.py +++ b/tools/TrainCNN/forward.py @@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 16 FC1_OUTPUT_NODES = 16 # 第二层全连接宽度(输出标签类型数) -FC2_OUTPUT_NODES = 11 +FC2_OUTPUT_NODES = 15 # 输出标签类型数 OUTPUT_NODES = FC2_OUTPUT_NODES diff --git a/tools/TrainCNN/generate.py b/tools/TrainCNN/generate.py index 20c2a63..84d2112 100644 --- a/tools/TrainCNN/generate.py +++ b/tools/TrainCNN/generate.py @@ -24,15 +24,11 @@ class DataSet: self.test_labels = [] self.generate_data_sets(folder) - def file2nparray(self, name): - try: - image = cv2.imread(name) - image = cv2.resize(image, (SRC_COLS, SRC_ROWS)) - image = image.astype(np.float32) - return image / 255.0 - except: - print(name) - sys.exit(-1) + def file2nparray(self, name, random=False): + image = cv2.imread(name) + image = cv2.resize(image, (SRC_COLS, SRC_ROWS)) + image = image.astype(np.float32) + return image / 255.0 def id2label(self, id): a = np.zeros([OUTPUT_NODES]) @@ -42,16 +38,20 @@ class DataSet: def generate_data_sets(self, folder): sets = [] for i in range(OUTPUT_NODES): - dir = "%s/%d" % (folder, i) + dir = "%s/id%d" % (folder, i) files = os.listdir(dir) for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True): if file[-3:] == "jpg": - if random.random() > 0.2: - self.train_samples.append(self.file2nparray("%s/%s" % (dir, file))) - self.train_labels.append(self.id2label(i)) - else: - self.test_samples.append(self.file2nparray("%s/%s" % (dir, file))) - self.test_labels.append(self.id2label(i)) + try: + if random.random() > 0.2: + self.train_samples.append(self.file2nparray("%s/%s" % (dir, file))) + self.train_labels.append(self.id2label(i)) + else: + self.test_samples.append(self.file2nparray("%s/%s" % (dir, file))) + self.test_labels.append(self.id2label(i)) + except: + print("%s/%s" % (dir, file)) + continue self.train_samples = np.array(self.train_samples) self.train_labels = np.array(self.train_labels) self.test_samples = np.array(self.test_samples)