使用旧版候选区寻找方式。
This commit is contained in:
@@ -53,7 +53,7 @@ def save_para(folder, paras):
|
||||
save_bias(fp, paras[7])
|
||||
|
||||
|
||||
STEPS = 20000
|
||||
STEPS = 10000
|
||||
BATCH = 30
|
||||
LEARNING_RATE_BASE = 0.01
|
||||
LEARNING_RATE_DECAY = 0.99
|
||||
@@ -61,8 +61,6 @@ MOVING_AVERAGE_DECAY = 0.99
|
||||
|
||||
|
||||
def train(dataset, show_bar=False):
|
||||
test_images, test_labels = dataset.all_test_sets()
|
||||
|
||||
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
|
||||
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
|
||||
nodes, vars = forward.forward(x, 0.001)
|
||||
@@ -104,7 +102,8 @@ def train(dataset, show_bar=False):
|
||||
|
||||
if i % 100 == 0:
|
||||
if i % 1000 == 0:
|
||||
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
|
||||
test_samples, test_labels = dataset.sample_train_sets(5000)
|
||||
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
|
||||
bar.set_postfix({"loss": loss_value, "acc": acc})
|
||||
|
||||
|
||||
@@ -151,8 +150,8 @@ def train(dataset, show_bar=False):
|
||||
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||
nodes_val = sess.run(nodes, feed_dict={x:test_images})
|
||||
return vars_val, nodes_val
|
||||
nodes_val = sess.run(nodes, feed_dict={x:test_samples})
|
||||
return vars_val, nodes_val, test_samples
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -160,3 +159,4 @@ if __name__ == "__main__":
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||
print("Finish!")
|
||||
train(dataset, show_bar=True)
|
||||
input("Press any key to continue...")
|
||||
|
||||
@@ -64,8 +64,8 @@ def forward(x, regularizer=None):
|
||||
[CONV2_KERNAL_SIZE, CONV2_KERNAL_SIZE, CONV1_OUTPUT_CHANNELS, CONV2_OUTPUT_CHANNELS]
|
||||
)
|
||||
conv2_b = get_bias([CONV2_OUTPUT_CHANNELS])
|
||||
conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b))
|
||||
pool2 = avg_pool_2x2(conv2)
|
||||
conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b))
|
||||
pool2 = avg_pool_2x2(conv2)
|
||||
vars.extend([conv2_w, conv2_b])
|
||||
nodes.extend([conv2, pool2])
|
||||
|
||||
|
||||
@@ -24,34 +24,40 @@ class DataSet:
|
||||
self.test_labels = []
|
||||
self.generate_data_sets(folder)
|
||||
|
||||
def file2nparray(self, name):
|
||||
try:
|
||||
image = cv2.imread(name)
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0
|
||||
except:
|
||||
print(name)
|
||||
sys.exit(-1)
|
||||
|
||||
def id2label(self, id):
|
||||
a = np.zeros([OUTPUT_NODES])
|
||||
a[id] = 1
|
||||
return a[:]
|
||||
|
||||
def generate_data_sets(self, folder):
|
||||
def id2label(id):
|
||||
a = np.zeros([OUTPUT_NODES])
|
||||
a[id] = 1
|
||||
return a[:]
|
||||
|
||||
def file2nparray(name):
|
||||
try:
|
||||
image = cv2.imread(name)
|
||||
if image.shape[0] < 15:
|
||||
return None
|
||||
elif image.shape[1] < 10:
|
||||
return None
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0, id2label(int(name.split("/")[-2]))
|
||||
except TypeError:
|
||||
print(name)
|
||||
sys.exit(-1)
|
||||
|
||||
sets = []
|
||||
for i in range(OUTPUT_NODES):
|
||||
dir = "%s/%d" % (folder, i)
|
||||
files = os.listdir(dir)
|
||||
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
x = file2nparray("%s/%s" % (dir, file))
|
||||
if x is not None:
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(x[0])
|
||||
self.train_labels.append(x[1])
|
||||
else:
|
||||
self.test_samples.append(x[0])
|
||||
self.test_labels.append(x[1])
|
||||
self.train_samples = np.array(self.train_samples)
|
||||
self.train_labels = np.array(self.train_labels)
|
||||
self.test_samples = np.array(self.test_samples)
|
||||
@@ -67,6 +73,15 @@ class DataSet:
|
||||
labels.append(self.train_labels[id])
|
||||
return np.array(samples), np.array(labels)
|
||||
|
||||
def sample_test_sets(self, length):
|
||||
samples = []
|
||||
labels = []
|
||||
for i in range(length):
|
||||
id = random.randint(0, len(self.test_samples)-1)
|
||||
samples.append(self.test_samples[id])
|
||||
labels.append(self.test_labels[id])
|
||||
return np.array(samples), np.array(labels)
|
||||
|
||||
def all_train_sets(self):
|
||||
return self.train_samples[:], self.train_labels[:]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user