更新了一大堆东西,我也说不清有什么。

This commit is contained in:
xinyang
2019-04-29 14:42:47 +08:00
parent 4015ee5910
commit 25927f0b34
9 changed files with 140 additions and 112 deletions

79
tools/TrainCNN/backward.py Normal file → Executable file
View File

@@ -1,5 +1,6 @@
#!/usr/bin/python3
import tensorflow as tf
from progressive.bar import Bar
from tqdm import tqdm
import generate
import forward
import cv2
@@ -51,7 +52,7 @@ def save_para(folder, paras):
STEPS = 20000
BATCH = 10
BATCH = 30
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99
@@ -85,18 +86,13 @@ def train(dataset, show_bar=False):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = 0
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
if show_bar:
bar = Bar(max_value=STEPS, width=u'50%')
bar.cursor.clear_lines(1)
bar.cursor.save()
for i in range(STEPS):
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
for i in bar:
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
_, loss_value, step = sess.run(
@@ -107,30 +103,49 @@ def train(dataset, show_bar=False):
if i % 100 == 0:
if i % 1000 == 0:
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
bar.set_postfix({"loss": loss_value, "acc": acc})
if show_bar:
bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
bar.cursor.restore()
bar.draw(value=i+1)
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True
# while _:
# _, frame = video.read()
# cv2.imshow("Video", frame)
# if cv2.waitKey(10) == 113:
# bbox = cv2.selectROI("frame", frame, False)
# print(bbox)
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
# roi = cv2.resize(roi, (48, 36))
# cv2.imshow("roi", roi)
# cv2.waitKey(0)
# roi = roi.astype(np.float32)
# roi /= 255.0
# roi = roi.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: roi})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
# if show_bar:
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
# bar.cursor.restore()
# bar.draw(value=i+1)
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
_ = True
while _:
_, frame = video.read()
cv2.imshow("Video", frame)
k = cv2.waitKey(10)
if k == ord(" "):
bbox = cv2.selectROI("frame", frame, False)
print(bbox)
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
roi = cv2.resize(roi, (48, 36))
cv2.imshow("roi", roi)
cv2.waitKey(0)
roi = roi.astype(np.float32)
roi /= 255.0
roi = roi.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: roi})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
elif k==ord("q"):
break
keep = True
while keep:
n = input()
im = cv2.imread(n)
im = cv2.resize(im, (48, 36))
cv2.imshow("im", im)
if cv2.waitKey(0) == ord("q"):
keep = False
im = im.astype(np.float32)
im /= 255.0
im = im.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: im})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
@@ -139,5 +154,5 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
train(dataset, show_bar=True)

View File

@@ -29,13 +29,13 @@ def max_pool_2x2(x):
CONV1_KERNAL_SIZE = 5
# 第一层卷积输出通道数
CONV1_OUTPUT_CHANNELS = 4
CONV1_OUTPUT_CHANNELS = 6
# 第二层卷积核大小
CONV2_KERNAL_SIZE = 3
# 第二层卷积输出通道数
CONV2_OUTPUT_CHANNELS = 8
CONV2_OUTPUT_CHANNELS = 10
# 第一层全连接宽度
FC1_OUTPUT_NODES = 16

View File

@@ -12,6 +12,7 @@ SRC_COLS = 48
# 原图像通道数
SRC_CHANNELS = 3
class DataSet:
def __init__(self, folder):
self.train_samples = []
@@ -37,12 +38,13 @@ class DataSet:
dir = "%s/%d" % (folder, i)
files = os.listdir(dir)
for file in files:
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
if file[-3:] == "jpg":
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.train_labels.append(self.id2label(i))
else:
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
self.test_labels.append(self.id2label(i))
self.train_samples = np.array(self.train_samples)
self.train_labels = np.array(self.train_labels)
self.test_samples = np.array(self.test_samples)