修改了摄像头读取方式

This commit is contained in:
xinyang
2019-04-27 16:16:53 +08:00
parent 9cfd26cc23
commit 4e47b38d7d
15 changed files with 272 additions and 255 deletions

View File

@@ -2,7 +2,8 @@ import tensorflow as tf
from progressive.bar import Bar
import generate
import forward
import cv2
import numpy as np
def save_kernal(fp, val):
print(val.shape[2], file=fp)
@@ -49,7 +50,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 30000
STEPS = 20000
BATCH = 10
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
@@ -59,9 +60,9 @@ MOVING_AVERAGE_DECAY = 0.99
def train(dataset, show_bar=False):
test_images, test_labels = dataset.all_test_sets()
x = tf.placeholder(tf.float32, [None, forward.SRC_ROWS, forward.SRC_COLS, forward.SRC_CHANNELS])
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
nodes, vars = forward.forward(0.001)
nodes, vars = forward.forward(x, 0.001)
y = nodes[-1]
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
@@ -72,7 +73,7 @@ def train(dataset, show_bar=False):
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
len(dataset.train_sets) / BATCH,
len(dataset.train_samples) / BATCH,
LEARNING_RATE_DECAY,
staircase=False)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
@@ -112,12 +113,31 @@ def train(dataset, show_bar=False):
bar.cursor.restore()
bar.draw(value=i+1)
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
# _ = True
# while _:
# _, frame = video.read()
# cv2.imshow("Video", frame)
# if cv2.waitKey(10) == 113:
# bbox = cv2.selectROI("frame", frame, False)
# print(bbox)
# roi = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
# roi = cv2.resize(roi, (48, 36))
# cv2.imshow("roi", roi)
# cv2.waitKey(0)
# roi = roi.astype(np.float32)
# roi /= 255.0
# roi = roi.reshape([1, 36, 48, 3])
# res = sess.run(y, feed_dict={x: roi})
# res = res.reshape([forward.OUTPUT_NODES])
# print(np.argmax(res))
vars_val = sess.run(vars)
save_para("paras", vars_val)
# nodes_val = sess.run(nodes, feed_dict={x:test})
# return vars_val, nodes_val
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
nodes_val = sess.run(nodes, feed_dict={x:test_images})
return vars_val, nodes_val
if __name__ == "__main__":
dataset = generate.DataSet("images")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
train(dataset, show_bar=True)