反小陀螺前置代码已完成。

This commit is contained in:
xinyang
2019-07-16 16:31:58 +08:00
parent 298c4c9114
commit 73230d91f1
24 changed files with 18873 additions and 18675 deletions

View File

@@ -40,7 +40,7 @@ class DataSet:
files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
if random.random() > 0.2:
if random.random() > 0.7:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
@@ -61,6 +61,15 @@ class DataSet:
labels.append(self.train_labels[id])
return np.array(samples), np.array(labels)
def sample_test_sets(self, length):
samples = []
labels = []
for i in range(length):
id = random.randint(0, len(self.test_samples)-1)
samples.append(self.test_samples[id])
labels.append(self.test_labels[id])
return np.array(samples), np.array(labels)
def all_train_sets(self):
return self.train_samples[:], self.train_labels[:]