更新CNN参数,使用CNN判断tracking是否跟丢。

This commit is contained in:
xinyang
2019-04-30 20:20:35 +08:00
parent e3098fe3fa
commit ee83a373d3
16 changed files with 12416 additions and 9793 deletions

View File

@@ -1,10 +1,12 @@
#!/usr/bin/python3
print("Preparing...")
import tensorflow as tf
from tqdm import tqdm
import generate
import forward
import cv2
import numpy as np
print("Finish!")
def save_kernal(fp, val):
print(val.shape[2], file=fp)
@@ -91,7 +93,7 @@ def train(dataset, show_bar=False):
init_op = tf.global_variables_initializer()
sess.run(init_op)
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
bar = tqdm(range(STEPS), dynamic_ncols=True)
for i in bar:
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
@@ -111,41 +113,41 @@ def train(dataset, show_bar=False):
# bar.cursor.restore()
# bar.draw(value=i+1)
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
_ = True
while _:
_, frame = video.read()
cv2.imshow("Video", frame)
k = cv2.waitKey(10)
if k == ord(" "):
bbox = cv2.selectROI("frame", frame, False)
print(bbox)
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
roi = cv2.resize(roi, (48, 36))
cv2.imshow("roi", roi)
cv2.waitKey(0)
roi = roi.astype(np.float32)
roi /= 255.0
roi = roi.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: roi})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
elif k==ord("q"):
break
keep = True
while keep:
n = input()
im = cv2.imread(n)
im = cv2.resize(im, (48, 36))
cv2.imshow("im", im)
if cv2.waitKey(0) == ord("q"):
keep = False
im = im.astype(np.float32)
im /= 255.0
im = im.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: im})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True
# while _:
# _, frame = video.read()
# cv2.imshow("Video", frame)
# k = cv2.waitKey(10)
# if k == ord(" "):
# bbox = cv2.selectROI("frame", frame, False)
# print(bbox)
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
# roi = cv2.resize(roi, (48, 36))
# cv2.imshow("roi", roi)
# cv2.waitKey(0)
# roi = roi.astype(np.float32)
# roi /= 255.0
# roi = roi.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: roi})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
# elif k==ord("q"):
# break
# keep = True
# while keep:
# n = input()
# im = cv2.imread(n)
# im = cv2.resize(im, (48, 36))
# cv2.imshow("im", im)
# if cv2.waitKey(0) == ord("q"):
# keep = False
# im = im.astype(np.float32)
# im /= 255.0
# im = im.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: im})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
@@ -154,5 +156,7 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
print("Loading data sets...")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
print("Finish!")
train(dataset, show_bar=True)