更新CNN参数,使用CNN判断tracking是否跟丢。
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
#!/usr/bin/python3
|
||||
print("Preparing...")
|
||||
import tensorflow as tf
|
||||
from tqdm import tqdm
|
||||
import generate
|
||||
import forward
|
||||
import cv2
|
||||
import numpy as np
|
||||
print("Finish!")
|
||||
|
||||
def save_kernal(fp, val):
|
||||
print(val.shape[2], file=fp)
|
||||
@@ -91,7 +93,7 @@ def train(dataset, show_bar=False):
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
bar = tqdm(range(STEPS), ascii=True, dynamic_ncols=True)
|
||||
bar = tqdm(range(STEPS), dynamic_ncols=True)
|
||||
for i in bar:
|
||||
images_samples, labels_samples = dataset.sample_train_sets(BATCH)
|
||||
|
||||
@@ -111,41 +113,41 @@ def train(dataset, show_bar=False):
|
||||
# bar.cursor.restore()
|
||||
# bar.draw(value=i+1)
|
||||
|
||||
video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
_ = True
|
||||
while _:
|
||||
_, frame = video.read()
|
||||
cv2.imshow("Video", frame)
|
||||
k = cv2.waitKey(10)
|
||||
if k == ord(" "):
|
||||
bbox = cv2.selectROI("frame", frame, False)
|
||||
print(bbox)
|
||||
roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
roi = cv2.resize(roi, (48, 36))
|
||||
cv2.imshow("roi", roi)
|
||||
cv2.waitKey(0)
|
||||
roi = roi.astype(np.float32)
|
||||
roi /= 255.0
|
||||
roi = roi.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: roi})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
elif k==ord("q"):
|
||||
break
|
||||
keep = True
|
||||
while keep:
|
||||
n = input()
|
||||
im = cv2.imread(n)
|
||||
im = cv2.resize(im, (48, 36))
|
||||
cv2.imshow("im", im)
|
||||
if cv2.waitKey(0) == ord("q"):
|
||||
keep = False
|
||||
im = im.astype(np.float32)
|
||||
im /= 255.0
|
||||
im = im.reshape([1, 36, 48, 3])
|
||||
res = sess.run(y, feed_dict={x: im})
|
||||
res = res.reshape([forward.OUTPUT_NODES])
|
||||
print(np.argmax(res))
|
||||
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||
# _ = True
|
||||
# while _:
|
||||
# _, frame = video.read()
|
||||
# cv2.imshow("Video", frame)
|
||||
# k = cv2.waitKey(10)
|
||||
# if k == ord(" "):
|
||||
# bbox = cv2.selectROI("frame", frame, False)
|
||||
# print(bbox)
|
||||
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
|
||||
# roi = cv2.resize(roi, (48, 36))
|
||||
# cv2.imshow("roi", roi)
|
||||
# cv2.waitKey(0)
|
||||
# roi = roi.astype(np.float32)
|
||||
# roi /= 255.0
|
||||
# roi = roi.reshape([1, 36, 48, 3])
|
||||
# res = sess.run(y, feed_dict={x: roi})
|
||||
# res = res.reshape([forward.OUTPUT_NODES])
|
||||
# print(np.argmax(res))
|
||||
# elif k==ord("q"):
|
||||
# break
|
||||
# keep = True
|
||||
# while keep:
|
||||
# n = input()
|
||||
# im = cv2.imread(n)
|
||||
# im = cv2.resize(im, (48, 36))
|
||||
# cv2.imshow("im", im)
|
||||
# if cv2.waitKey(0) == ord("q"):
|
||||
# keep = False
|
||||
# im = im.astype(np.float32)
|
||||
# im /= 255.0
|
||||
# im = im.reshape([1, 36, 48, 3])
|
||||
# res = sess.run(y, feed_dict={x: im})
|
||||
# res = res.reshape([forward.OUTPUT_NODES])
|
||||
# print(np.argmax(res))
|
||||
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||
@@ -154,5 +156,7 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Loading data sets...")
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||
print("Finish!")
|
||||
train(dataset, show_bar=True)
|
||||
|
||||
@@ -3,6 +3,9 @@ import os
|
||||
import cv2
|
||||
import random
|
||||
from forward import OUTPUT_NODES
|
||||
import sys
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
# 原图像行数
|
||||
SRC_ROWS = 36
|
||||
|
||||
@@ -22,10 +25,14 @@ class DataSet:
|
||||
self.generate_data_sets(folder)
|
||||
|
||||
def file2nparray(self, name):
|
||||
image = cv2.imread(name)
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0
|
||||
try:
|
||||
image = cv2.imread(name)
|
||||
image = cv2.resize(image, (SRC_COLS, SRC_ROWS))
|
||||
image = image.astype(np.float32)
|
||||
return image / 255.0
|
||||
except:
|
||||
print(name)
|
||||
sys.exit(-1)
|
||||
|
||||
def id2label(self, id):
|
||||
a = np.zeros([OUTPUT_NODES])
|
||||
@@ -37,7 +44,7 @@ class DataSet:
|
||||
for i in range(OUTPUT_NODES):
|
||||
dir = "%s/%d" % (folder, i)
|
||||
files = os.listdir(dir)
|
||||
for file in files:
|
||||
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
|
||||
Reference in New Issue
Block a user