更新CNN训练代码。
This commit is contained in:
@@ -87,8 +87,8 @@ def train(dataset, show_bar=False):
|
|||||||
|
|
||||||
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||||
|
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
|
||||||
with tf.Session() as sess:
|
with tf.Session(config=config) as sess:
|
||||||
init_op = tf.global_variables_initializer()
|
init_op = tf.global_variables_initializer()
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ def train(dataset, show_bar=False):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Loading data sets...")
|
print("Loading data sets...")
|
||||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
|
dataset = generate.DataSet("/home/xinyang/Desktop/dataset/box")
|
||||||
print("Finish!")
|
print("Finish!")
|
||||||
train(dataset, show_bar=True)
|
train(dataset, show_bar=True)
|
||||||
input("Press any key to end...")
|
input("Press any key to end...")
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 16
|
|||||||
FC1_OUTPUT_NODES = 16
|
FC1_OUTPUT_NODES = 16
|
||||||
|
|
||||||
# 第二层全连接宽度(输出标签类型数)
|
# 第二层全连接宽度(输出标签类型数)
|
||||||
FC2_OUTPUT_NODES = 11
|
FC2_OUTPUT_NODES = 15
|
||||||
# 输出标签类型数
|
# 输出标签类型数
|
||||||
OUTPUT_NODES = FC2_OUTPUT_NODES
|
OUTPUT_NODES = FC2_OUTPUT_NODES
|
||||||
|
|
||||||
|
|||||||
@@ -24,15 +24,11 @@ class DataSet:
|
|||||||
self.test_labels = []
|
self.test_labels = []
|
||||||
self.generate_data_sets(folder)
|
self.generate_data_sets(folder)
|
||||||
|
|
||||||
def file2nparray(self, name):
|
def file2nparray(self, name, random=False):
|
||||||
try:
|
image = cv2.imread(name)
|
||||||
image = cv2.imread(name)
|
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
image = image.astype(np.float32)
|
||||||
image = image.astype(np.float32)
|
return image / 255.0
|
||||||
return image / 255.0
|
|
||||||
except:
|
|
||||||
print(name)
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
def id2label(self, id):
|
def id2label(self, id):
|
||||||
a = np.zeros([OUTPUT_NODES])
|
a = np.zeros([OUTPUT_NODES])
|
||||||
@@ -42,16 +38,20 @@ class DataSet:
|
|||||||
def generate_data_sets(self, folder):
|
def generate_data_sets(self, folder):
|
||||||
sets = []
|
sets = []
|
||||||
for i in range(OUTPUT_NODES):
|
for i in range(OUTPUT_NODES):
|
||||||
dir = "%s/%d" % (folder, i)
|
dir = "%s/id%d" % (folder, i)
|
||||||
files = os.listdir(dir)
|
files = os.listdir(dir)
|
||||||
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
||||||
if file[-3:] == "jpg":
|
if file[-3:] == "jpg":
|
||||||
if random.random() > 0.2:
|
try:
|
||||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
if random.random() > 0.2:
|
||||||
self.train_labels.append(self.id2label(i))
|
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||||
else:
|
self.train_labels.append(self.id2label(i))
|
||||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
else:
|
||||||
self.test_labels.append(self.id2label(i))
|
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||||
|
self.test_labels.append(self.id2label(i))
|
||||||
|
except:
|
||||||
|
print("%s/%s" % (dir, file))
|
||||||
|
continue
|
||||||
self.train_samples = np.array(self.train_samples)
|
self.train_samples = np.array(self.train_samples)
|
||||||
self.train_labels = np.array(self.train_labels)
|
self.train_labels = np.array(self.train_labels)
|
||||||
self.test_samples = np.array(self.test_samples)
|
self.test_samples = np.array(self.test_samples)
|
||||||
|
|||||||
Reference in New Issue
Block a user