更新CNN训练代码。
This commit is contained in:
@@ -24,15 +24,11 @@ class DataSet:
|
||||
self.test_labels = []
|
||||
self.generate_data_sets(folder)
|
||||
|
||||
def file2nparray(self, name):
|
||||
try:
|
||||
image = cv2.imread(name)
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0
|
||||
except:
|
||||
print(name)
|
||||
sys.exit(-1)
|
||||
def file2nparray(self, name, random=False):
|
||||
image = cv2.imread(name)
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0
|
||||
|
||||
def id2label(self, id):
|
||||
a = np.zeros([OUTPUT_NODES])
|
||||
@@ -42,16 +38,20 @@ class DataSet:
|
||||
def generate_data_sets(self, folder):
|
||||
sets = []
|
||||
for i in range(OUTPUT_NODES):
|
||||
dir = "%s/%d" % (folder, i)
|
||||
dir = "%s/id%d" % (folder, i)
|
||||
files = os.listdir(dir)
|
||||
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
try:
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
except:
|
||||
print("%s/%s" % (dir, file))
|
||||
continue
|
||||
self.train_samples = np.array(self.train_samples)
|
||||
self.train_labels = np.array(self.train_labels)
|
||||
self.test_samples = np.array(self.test_samples)
|
||||
|
||||
Reference in New Issue
Block a user