更新CNN参数,使用CNN判断tracking是否跟丢。
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
#!/usr/bin/python3
|
||||
print("Preparing...")
|
||||
import tensorflow as tf
|
||||
from tqdm import tqdm
|
||||
import generate
|
||||
import forward
|
||||
import cv2
|
||||
import numpy as np
|
||||
print("Finish!")
|
||||
|
||||
def save_kernal(fp, val):
|
||||
print(val.shape[2], file=fp)
|
||||
@@ -91,7 +93,7 @@ def train(dataset, show_bar=False):
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
|
||||
bar = tqdm(range(STEPS), dynamic_ncols=True)
|
||||
for i in bar:
|
||||
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
|
||||
|
||||
@@ -111,41 +113,41 @@ def train(dataset, show_bar=False):
|
||||
# bar.cursor.restore()
|
||||
# bar.draw(value=i+1)
|
||||
|
||||
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
_ = True
|
||||
while _:
|
||||
_, frame = video.read()
|
||||
cv2.imshow("Video", frame)
|
||||
k = cv2.waitKey(10)
|
||||
if k == ord(" "):
|
||||
bbox = cv2.selectROI("frame", frame, False)
|
||||
print(bbox)
|
||||
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
roi = cv2.resize(roi, (48, 36))
|
||||
cv2.imshow("roi", roi)
|
||||
cv2.waitKey(0)
|
||||
roi = roi.astype(np.float32)
|
||||
roi /= 255.0
|
||||
roi = roi.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: roi})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
elif k==ord("q"):
|
||||
break
|
||||
keep = True
|
||||
while keep:
|
||||
n = input()
|
||||
im = cv2.imread(n)
|
||||
im = cv2.resize(im, (48, 36))
|
||||
cv2.imshow("im", im)
|
||||
if cv2.waitKey(0) == ord("q"):
|
||||
keep = False
|
||||
im = im.astype(np.float32)
|
||||
im /= 255.0
|
||||
im = im.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: im})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
# _ = True
|
||||
# while _:
|
||||
# _, frame = video.read()
|
||||
# cv2.imshow("Video", frame)
|
||||
# k = cv2.waitKey(10)
|
||||
# if k == ord(" "):
|
||||
# bbox = cv2.selectROI("frame", frame, False)
|
||||
# print(bbox)
|
||||
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
# roi = cv2.resize(roi, (48, 36))
|
||||
# cv2.imshow("roi", roi)
|
||||
# cv2.waitKey(0)
|
||||
# roi = roi.astype(np.float32)
|
||||
# roi /= 255.0
|
||||
# roi = roi.reshape([1, 36, 48, 3])
|
||||
# res = sess.run(y, feed_dict={x: roi})
|
||||
# res = res.reshape([forward.OUTPUT_NODES])
|
||||
# print(np.argmax(res))
|
||||
# elif k==ord("q"):
|
||||
# break
|
||||
# keep = True
|
||||
# while keep:
|
||||
# n = input()
|
||||
# im = cv2.imread(n)
|
||||
# im = cv2.resize(im, (48, 36))
|
||||
# cv2.imshow("im", im)
|
||||
# if cv2.waitKey(0) == ord("q"):
|
||||
# keep = False
|
||||
# im = im.astype(np.float32)
|
||||
# im /= 255.0
|
||||
# im = im.reshape([1, 36, 48, 3])
|
||||
# res = sess.run(y, feed_dict={x: im})
|
||||
# res = res.reshape([forward.OUTPUT_NODES])
|
||||
# print(np.argmax(res))
|
||||
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||
@@ -154,5 +156,7 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Loading data sets...")
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||
print("Finish!")
|
||||
train(dataset, show_bar=True)
|
||||
|
||||
Reference in New Issue
Block a user