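更新了一大堆东西,我也说不清有什么。 (Updated many things at once. Visible in this diff: the training progress bar switches from progressive.bar to tqdm, BATCH goes from 10 to 30, the post-training video ROI check is re-enabled and an image-path prediction loop is added, the dataset path now points to DataSets/box, the conv layer output channels go from 4 to 6 and from 8 to 10, and dataset loading is restricted to files ending in "jpg".)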
This commit is contained in:
79
tools/TrainCNN/backward.py
Normal file → Executable file
79
tools/TrainCNN/backward.py
Normal file → Executable file
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/python3
|
||||
import tensorflow as tf
|
||||
from progressive.bar import Bar
|
||||
from tqdm import tqdm
|
||||
import generate
|
||||
import forward
|
||||
import cv2
|
||||
@@ -51,7 +52,7 @@ def save_para(folder, paras):
|
||||
|
||||
|
||||
STEPS = 20000
|
||||
BATCH = 10
|
||||
BATCH = 30
|
||||
LEARNING_RATE_BASE = 0.01
|
||||
LEARNING_RATE_DECAY = 0.99
|
||||
MOVING_AVERAGE_DECAY = 0.99
|
||||
@@ -85,18 +86,13 @@ def train(dataset, show_bar=False):
|
||||
|
||||
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||
acc = 0
|
||||
|
||||
with tf.Session() as sess:
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
if show_bar:
|
||||
bar = Bar(max_value=STEPS, width=u'50%')
|
||||
bar.cursor.clear_lines(1)
|
||||
bar.cursor.save()
|
||||
|
||||
for i in range(STEPS):
|
||||
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
|
||||
for i in bar:
|
||||
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
|
||||
|
||||
_, loss_value, step = sess.run(
|
||||
@@ -107,30 +103,49 @@ def train(dataset, show_bar=False):
|
||||
if i % 100 == 0:
|
||||
if i % 1000 == 0:
|
||||
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
|
||||
bar.set_postfix({"loss": loss_value, "acc": acc})
|
||||
|
||||
if show_bar:
|
||||
bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
|
||||
bar.cursor.restore()
|
||||
bar.draw(value=i+1)
|
||||
|
||||
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
# _ = True
|
||||
# while _:
|
||||
# _, frame = video.read()
|
||||
# cv2.imshow("Video", frame)
|
||||
# if cv2.waitKey(10) == 113:
|
||||
# bbox = cv2.selectROI("frame", frame, False)
|
||||
# print(bbox)
|
||||
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
# roi = cv2.resize(roi, (48, 36))
|
||||
# cv2.imshow("roi", roi)
|
||||
# cv2.waitKey(0)
|
||||
# roi = roi.astype(np.float32)
|
||||
# roi /= 255.0
|
||||
# roi = roi.reshape([1, 36, 48, 3])
|
||||
# res = sess.run(y, feed_dict={x: roi})
|
||||
# res = res.reshape([forward.OUTPUT_NODES])
|
||||
# print(np.argmax(res))
|
||||
# if show_bar:
|
||||
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
|
||||
# bar.cursor.restore()
|
||||
# bar.draw(value=i+1)
|
||||
|
||||
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
_ = True
|
||||
while _:
|
||||
_, frame = video.read()
|
||||
cv2.imshow("Video", frame)
|
||||
k = cv2.waitKey(10)
|
||||
if k == ord(" "):
|
||||
bbox = cv2.selectROI("frame", frame, False)
|
||||
print(bbox)
|
||||
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
roi = cv2.resize(roi, (48, 36))
|
||||
cv2.imshow("roi", roi)
|
||||
cv2.waitKey(0)
|
||||
roi = roi.astype(np.float32)
|
||||
roi /= 255.0
|
||||
roi = roi.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: roi})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
elif k==ord("q"):
|
||||
break
|
||||
keep = True
|
||||
while keep:
|
||||
n = input()
|
||||
im = cv2.imread(n)
|
||||
im = cv2.resize(im, (48, 36))
|
||||
cv2.imshow("im", im)
|
||||
if cv2.waitKey(0) == ord("q"):
|
||||
keep = False
|
||||
im = im.astype(np.float32)
|
||||
im /= 255.0
|
||||
im = im.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: im})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||
@@ -139,5 +154,5 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||
train(dataset, show_bar=True)
|
||||
|
||||
@@ -29,13 +29,13 @@ def max_pool_2x2(x):
|
||||
CONV1_KERNAL_SIZE = 5
|
||||
|
||||
# 第一层卷积输出通道数
|
||||
CONV1_OUTPUT_CHANNELS = 4
|
||||
CONV1_OUTPUT_CHANNELS = 6
|
||||
|
||||
# 第二层卷积核大小
|
||||
CONV2_KERNAL_SIZE = 3
|
||||
|
||||
# 第二层卷积输出通道数
|
||||
CONV2_OUTPUT_CHANNELS = 8
|
||||
CONV2_OUTPUT_CHANNELS = 10
|
||||
|
||||
# 第一层全连接宽度
|
||||
FC1_OUTPUT_NODES = 16
|
||||
|
||||
@@ -12,6 +12,7 @@ SRC_COLS = 48
|
||||
# 原图像通道数
|
||||
SRC_CHANNELS = 3
|
||||
|
||||
|
||||
class DataSet:
|
||||
def __init__(self, folder):
|
||||
self.train_samples = []
|
||||
@@ -37,12 +38,13 @@ class DataSet:
|
||||
dir = "%s/%d" % (folder, i)
|
||||
files = os.listdir(dir)
|
||||
for file in files:
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
self.test_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.test_labels.append(self.id2label(i))
|
||||
self.train_samples = np.array(self.train_samples)
|
||||
self.train_labels = np.array(self.train_labels)
|
||||
self.test_samples = np.array(self.test_samples)
|
||||
|
||||
Reference in New Issue
Block a user