更新CNN参数,使用CNN判断tracking是否跟丢。

This commit is contained in:
xinyang
2019-04-30 20:20:35 +08:00
parent e3098fe3fa
commit ee83a373d3
16 changed files with 12416 additions and 9793 deletions

View File

@@ -1,10 +1,12 @@
#!/usr/bin/python3
print("Preparing...")
import tensorflow as tf
from tqdm import tqdm
import generate
import forward
import cv2
import numpy as np
print("Finish!")
def save_kernal(fp, val):
print(val.shape[2], file=fp)
@@ -91,7 +93,7 @@ def train(dataset, show_bar=False):
init_op = tf.global_variables_initializer()
sess.run(init_op)
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
bar = tqdm(range(STEPS), dynamic_ncols=True)
for i in bar:
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
@@ -111,41 +113,41 @@ def train(dataset, show_bar=False):
# bar.cursor.restore()
# bar.draw(value=i+1)
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
_ = True
while _:
_, frame = video.read()
cv2.imshow("Video", frame)
k = cv2.waitKey(10)
if k == ord(" "):
bbox = cv2.selectROI("frame", frame, False)
print(bbox)
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
roi = cv2.resize(roi, (48, 36))
cv2.imshow("roi", roi)
cv2.waitKey(0)
roi = roi.astype(np.float32)
roi /= 255.0
roi = roi.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: roi})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
elif k==ord("q"):
break
keep = True
while keep:
n = input()
im = cv2.imread(n)
im = cv2.resize(im, (48, 36))
cv2.imshow("im", im)
if cv2.waitKey(0) == ord("q"):
keep = False
im = im.astype(np.float32)
im /= 255.0
im = im.reshape([1, 36, 48, 3])
res = sess.run(y, feed_dict={x: im})
res = res.reshape([forward.OUTPUT_NODES])
print(np.argmax(res))
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True
# while _:
# _, frame = video.read()
# cv2.imshow("Video", frame)
# k = cv2.waitKey(10)
# if k == ord(" "):
# bbox = cv2.selectROI("frame", frame, False)
# print(bbox)
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
# roi = cv2.resize(roi, (48, 36))
# cv2.imshow("roi", roi)
# cv2.waitKey(0)
# roi = roi.astype(np.float32)
# roi /= 255.0
# roi = roi.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: roi})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
# elif k==ord("q"):
# break
# keep = True
# while keep:
# n = input()
# im = cv2.imread(n)
# im = cv2.resize(im, (48, 36))
# cv2.imshow("im", im)
# if cv2.waitKey(0) == ord("q"):
# keep = False
# im = im.astype(np.float32)
# im /= 255.0
# im = im.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: im})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
@@ -154,5 +156,7 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
print("Loading data sets...")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
print("Finish!")
train(dataset, show_bar=True)

View File

@@ -3,6 +3,9 @@ import os
import cv2
import random
from forward import OUTPUT_NODES
import sys
import os
from tqdm import tqdm
# 原图像行数
SRC_ROWS = 36
@@ -22,10 +25,14 @@ class DataSet:
self.generate_data_sets(folder)
def file2nparray(self, name):
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
try:
image = cv2.imread(name)
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
image = image.astype(np.float32)
return image / 255.0
except:
print(name)
sys.exit(-1)
def id2label(self, id):
a = np.zeros([OUTPUT_NODES])
@@ -37,7 +44,7 @@ class DataSet:
for i in range(OUTPUT_NODES):
dir = "%s/%d" % (folder, i)
files = os.listdir(dir)
for file in files:
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
if file[-3:] == "jpg":
if random.random() > 0.2:
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))