同步master上的bug fix。

This commit is contained in:
xinyang
2019-05-03 11:13:17 +08:00
parent 64d2554069
commit 218a4b56c6
5 changed files with 50 additions and 57 deletions

View File

@@ -5,6 +5,7 @@ from tqdm import tqdm
import generate
import forward
import cv2
import sys
import numpy as np
print("Finish!")
@@ -53,7 +54,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 10000
STEPS = 20000
BATCH = 30
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
@@ -102,16 +103,10 @@ def train(dataset, show_bar=False):
if i % 100 == 0:
if i % 1000 == 0:
test_samples, test_labels = dataset.sample_train_sets(5000)
test_samples, test_labels = dataset.sample_test_sets(5000)
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc})
# if show_bar:
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
# bar.cursor.restore()
# bar.draw(value=i+1)
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True
# while _:
@@ -147,7 +142,8 @@ def train(dataset, show_bar=False):
# res = sess.run(y, feed_dict={x: im})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
test_samples, test_labels = dataset.sample_test_sets(100)
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
nodes_val = sess.run(nodes, feed_dict={x:test_samples})
@@ -159,4 +155,4 @@ if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
print("Finish!")
train(dataset, show_bar=True)
input("Press any key to continue...")
input("Press any key to end...")

View File

@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 10
FC1_OUTPUT_NODES = 16
# 第二层全连接宽度(输出标签类型数)
FC2_OUTPUT_NODES = 4
FC2_OUTPUT_NODES = 8
# 输出标签类型数
OUTPUT_NODES = FC2_OUTPUT_NODES
@@ -64,8 +64,8 @@ def forward(x, regularizer=None):
[CONV2_KERNAL_SIZE, CONV2_KERNAL_SIZE, CONV1_OUTPUT_CHANNELS, CONV2_OUTPUT_CHANNELS]
)
conv2_b = get_bias([CONV2_OUTPUT_CHANNELS])
conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b))
pool2 = avg_pool_2x2(conv2)
conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(pool1, conv2_w), conv2_b))
pool2 = avg_pool_2x2(conv2)
vars.extend([conv2_w, conv2_b])
nodes.extend([conv2, pool2])

View File

@@ -24,40 +24,34 @@ class DataSet:
self.test_labels = []
self.generate_data_sets(folder)
def file2nparray(self, name):
try:
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
except:
print(name)
sys.exit(-1)
def id2label(self, id):
a = np.zeros([OUTPUT_NODES])
a[id] = 1
return a[:]
def generate_data_sets(self, folder):
def id2label(id):
a = np.zeros([OUTPUT_NODES])
a[id] = 1
return a[:]
def file2nparray(name):
try:
image = cv2.imread(name)
if image.shape[0] < 15:
return None
elif image.shape[1] < 10:
return None
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0, id2label(int(name.split("/")[-2]))
except TypeError:
print(name)
sys.exit(-1)
sets = []
for i in range(OUTPUT_NODES):
dir = "%s/%d" % (folder, i)
files = os.listdir(dir)
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
x = file2nparray("%s/%s" % (dir, file))
if x is not None:
if random.random() > 0.2:
self.train_samples.append(x[0])
self.train_labels.append(x[1])
else:
self.test_samples.append(x[0])
self.test_labels.append(x[1])
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples)