使用旧版候选区寻找方式。

This commit is contained in:
xinyang
2019-05-02 20:19:17 +08:00
parent ee83a373d3
commit 2f20367677
20 changed files with 12502 additions and 12395 deletions

View File

@@ -24,34 +24,40 @@ class DataSet:
self.test_labels = []
self.generate_data_sets(folder)
def file2nparray(self, name):
try:
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
except:
print(name)
sys.exit(-1)
def id2label(self, id):
a = np.zeros([OUTPUT_NODES])
a[id] = 1
return a[:]
def generate_data_sets(self, folder):
def id2label(id):
a = np.zeros([OUTPUT_NODES])
a[id] = 1
return a[:]
def file2nparray(name):
try:
image = cv2.imread(name)
if image.shape[0] < 15:
return None
elif image.shape[1] < 10:
return None
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0, id2label(int(name.split("/")[-2]))
except TypeError:
print(name)
sys.exit(-1)
sets = []
for i in range(OUTPUT_NODES):
dir = "%s/%d" % (folder, i)
files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
x = file2nparray("%s/%s" % (dir, file))
if x is not None:
if random.random() > 0.2:
self.train_samples.append(x[0])
self.train_labels.append(x[1])
else:
self.test_samples.append(x[0])
self.test_labels.append(x[1])
self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples)
@@ -67,6 +73,15 @@ class DataSet:
labels.append(self.train_labels[id])
return np.array(samples), np.array(labels)
def sample_test_sets(self, length):
samples = []
labels = []
for i in range(length):
id = random.randint(0, len(self.test_samples)-1)
samples.append(self.test_samples[id])
labels.append(self.test_labels[id])
return np.array(samples), np.array(labels)
def all_train_sets(self):
return self.train_samples[:], self.train_labels[:]