减少CNN训练时,测试集的内存占用
This commit is contained in:
@@ -5,6 +5,7 @@ from tqdm import tqdm
|
|||||||
import generate
|
import generate
|
||||||
import forward
|
import forward
|
||||||
import cv2
|
import cv2
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
print("Finish!")
|
print("Finish!")
|
||||||
|
|
||||||
@@ -61,8 +62,6 @@ MOVING_AVERAGE_DECAY = 0.99
|
|||||||
|
|
||||||
|
|
||||||
def train(dataset, show_bar=False):
|
def train(dataset, show_bar=False):
|
||||||
test_images, test_labels = dataset.all_test_sets()
|
|
||||||
|
|
||||||
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
|
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
|
||||||
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
|
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
|
||||||
nodes, vars = forward.forward(x, 0.001)
|
nodes, vars = forward.forward(x, 0.001)
|
||||||
@@ -104,15 +103,10 @@ def train(dataset, show_bar=False):
|
|||||||
|
|
||||||
if i % 100 == 0:
|
if i % 100 == 0:
|
||||||
if i % 1000 == 0:
|
if i % 1000 == 0:
|
||||||
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})
|
test_samples, test_labels = dataset.sample_test_sets(5000)
|
||||||
|
acc = sess.run(accuracy, feed_dict={x: test_samples, y_: test_labels})
|
||||||
bar.set_postfix({"loss": loss_value, "acc": acc})
|
bar.set_postfix({"loss": loss_value, "acc": acc})
|
||||||
|
|
||||||
|
|
||||||
# if show_bar:
|
|
||||||
# bar.title = "step: %d, loss: %f, acc: %f" % (step, loss_value, acc)
|
|
||||||
# bar.cursor.restore()
|
|
||||||
# bar.draw(value=i+1)
|
|
||||||
|
|
||||||
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
# video = cv2.VideoCapture("/home/xinyang/Desktop/Video.mp4")
|
||||||
# _ = True
|
# _ = True
|
||||||
# while _:
|
# while _:
|
||||||
@@ -149,10 +143,11 @@ def train(dataset, show_bar=False):
|
|||||||
# res = res.reshape([forward.OUTPUT_NODES])
|
# res = res.reshape([forward.OUTPUT_NODES])
|
||||||
# print(np.argmax(res))
|
# print(np.argmax(res))
|
||||||
|
|
||||||
|
test_samples, test_labels = dataset.sample_test_sets(100)
|
||||||
vars_val = sess.run(vars)
|
vars_val = sess.run(vars)
|
||||||
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
save_para("/home/xinyang/Desktop/AutoAim/tools/para", vars_val)
|
||||||
nodes_val = sess.run(nodes, feed_dict={x:test_images})
|
nodes_val = sess.run(nodes, feed_dict={x:test_samples})
|
||||||
return vars_val, nodes_val
|
return vars_val, nodes_val, test_samples
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -160,3 +155,4 @@ if __name__ == "__main__":
|
|||||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets/box")
|
||||||
print("Finish!")
|
print("Finish!")
|
||||||
train(dataset, show_bar=True)
|
train(dataset, show_bar=True)
|
||||||
|
input("Press any key to end...")
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ CONV2_OUTPUT_CHANNELS = 10
|
|||||||
FC1_OUTPUT_NODES = 16
|
FC1_OUTPUT_NODES = 16
|
||||||
|
|
||||||
# 第二层全连接宽度(输出标签类型数)
|
# 第二层全连接宽度(输出标签类型数)
|
||||||
FC2_OUTPUT_NODES = 4
|
FC2_OUTPUT_NODES = 8
|
||||||
|
|
||||||
# 输出标签类型数
|
# 输出标签类型数
|
||||||
OUTPUT_NODES = FC2_OUTPUT_NODES
|
OUTPUT_NODES = FC2_OUTPUT_NODES
|
||||||
|
|||||||
@@ -67,6 +67,15 @@ class DataSet:
|
|||||||
labels.append(self.train_labels[id])
|
labels.append(self.train_labels[id])
|
||||||
return np.array(samples), np.array(labels)
|
return np.array(samples), np.array(labels)
|
||||||
|
|
||||||
|
def sample_test_sets(self, length):
|
||||||
|
samples = []
|
||||||
|
labels = []
|
||||||
|
for i in range(length):
|
||||||
|
id = random.randint(0, len(self.test_samples)-1)
|
||||||
|
samples.append(self.test_samples[id])
|
||||||
|
labels.append(self.test_labels[id])
|
||||||
|
return np.array(samples), np.array(labels)
|
||||||
|
|
||||||
def all_train_sets(self):
|
def all_train_sets(self):
|
||||||
return self.train_samples[:], self.train_labels[:]
|
return self.train_samples[:], self.train_labels[:]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user