大版本更新。

This commit is contained in:
xinyang
2019-07-06 20:20:48 +08:00
parent 9f9050e04a
commit e9a0e9ad7b
23 changed files with 15297 additions and 20022 deletions

View File

@@ -2,10 +2,8 @@ import numpy as np
import os
import cv2
import random
from forward import OUTPUT_NODES
import sys
import os
from tqdm import tqdm
from forward import OUTPUT_NODES
# 原图像行数
SRC_ROWS = 36
@@ -24,7 +22,7 @@ class DataSet:
self.test_labels = []
self.generate_data_sets(folder)
def file2nparray(self, name, random=False):
def file2nparray(self, name):
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
@@ -42,16 +40,12 @@ class DataSet:
files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
try:
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
except:
print("%s/%s" % (dir, file))
continue
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples)
@@ -67,15 +61,6 @@ class DataSet:
labels.append(self.train_labels[id])
return np.array(samples), np.array(labels)
def sample_test_sets(self, length):
samples = []
labels = []
for i in range(length):
id = random.randint(0, len(self.test_samples)-1)
samples.append(self.test_samples[id])
labels.append(self.test_labels[id])
return np.array(samples), np.array(labels)
def all_train_sets(self):
return self.train_samples[:], self.train_labels[:]