更新CNN训练代码。

This commit is contained in:
xinyang
2019-07-02 20:56:46 +08:00
parent cb6cf4b2fc
commit 14d1544113
3 changed files with 20 additions and 20 deletions

View File

@@ -87,8 +87,8 @@ def train(dataset, show_bar=False):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
with tf.Session(config=config) as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
@@ -152,7 +152,7 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
print("Loading data sets...")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
dataset = generate.DataSet("/home/xinyang/Desktop/dataset/box")
print("Finish!")
train(dataset, show_bar=True)
input("Press any key to end...")

View File

@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 16
FC1_OUTPUT_NODES = 16
# 第二层全连接宽度(输出标签类型数)
FC2_OUTPUT_NODES = 11
FC2_OUTPUT_NODES = 15
# 输出标签类型数
OUTPUT_NODES = FC2_OUTPUT_NODES

View File

@@ -24,15 +24,11 @@ class DataSet:
self.test_labels = []
self.generate_data_sets(folder)
def file2nparray(self, name):
try:
def file2nparray(self, name, random=False):
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
except:
print(name)
sys.exit(-1)
def id2label(self, id):
a = np.zeros([OUTPUT_NODES])
@@ -42,16 +38,20 @@ class DataSet:
def generate_data_sets(self, folder):
sets = []
for i in range(OUTPUT_NODES):
dir = "%s/%d" % (folder, i)
dir = "%s/id%d" % (folder, i)
files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
try:
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
except:
print("%s/%s" % (dir, file))
continue
self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples)