更新了一大堆东西,我也说不清有什么。
This commit is contained in:
@@ -12,6 +12,7 @@ SRC_COLS = 48
|
||||
# 原图像通道数
|
||||
SRC_CHANNELS = 3
|
||||
|
||||
|
||||
class DataSet:
|
||||
def __init__(self, folder):
|
||||
self.train_samples = []
|
||||
@@ -37,12 +38,13 @@ class DataSet:
|
||||
dir = "%s/%d" % (folder, i)
|
||||
files = os.listdir(dir)
|
||||
for file in files:
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
self.train_samples = np.array(self.train_samples)
|
||||
self.train_labels = np.array(self.train_labels)
|
||||
self.test_samples = np.array(self.test_samples)
|
||||
|
||||
Reference in New Issue
Block a user