反小陀螺前置代码已完成。
This commit is contained in:
@@ -54,16 +54,14 @@ def save_para(folder, paras):
|
||||
save_bias(fp, paras[7])
|
||||
|
||||
|
||||
STEPS = 10000
|
||||
BATCH = 30
|
||||
STEPS = 20000
|
||||
BATCH = 50
|
||||
LEARNING_RATE_BASE = 0.005
|
||||
LEARNING_RATE_DECAY = 0.99
|
||||
MOVING_AVERAGE_DECAY = 0.99
|
||||
|
||||
|
||||
def train(dataset, show_bar=False):
|
||||
test_images, test_labels = dataset.all_test_sets()
|
||||
|
||||
x = tf.placeholder(tf.float32, [None, generate.SRC_ROWS, generate.SRC_COLS, generate.SRC_CHANNELS])
|
||||
y_= tf.placeholder(tf.float32, [None, forward.OUTPUT_NODES])
|
||||
keep_rate = tf.placeholder(tf.float32)
|
||||
@@ -107,7 +105,8 @@ def train(dataset, show_bar=False):
|
||||
)
|
||||
|
||||
if i % 100 == 0:
|
||||
if i % 1000 == 0:
|
||||
if i % 500 == 0:
|
||||
test_images, test_labels = dataset.sample_test_sets(10000)
|
||||
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
|
||||
bar.set_postfix({"loss": loss_value, "acc": acc})
|
||||
|
||||
@@ -116,6 +115,9 @@ def train(dataset, show_bar=False):
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
|
||||
print("save done!")
|
||||
|
||||
# pred = sess.run(y, feed_dict={x: test_images, keep_rate:1.0})
|
||||
|
||||
# nodes_val = sess.run(nodes, feed_dict={x:test_images})
|
||||
# return vars_val, nodes_val
|
||||
DevList = mvsdk.CameraEnumerateDevice()
|
||||
@@ -204,9 +206,9 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
# import os
|
||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut")
|
||||
train(dataset, show_bar=True)
|
||||
input("press enter to continue...")
|
||||
|
||||
@@ -35,10 +35,10 @@ CONV1_OUTPUT_CHANNELS = 6
|
||||
CONV2_KERNAL_SIZE = 3
|
||||
|
||||
# 第二层卷积输出通道数
|
||||
CONV2_OUTPUT_CHANNELS = 10
|
||||
CONV2_OUTPUT_CHANNELS = 12
|
||||
|
||||
# 第一层全连接宽度
|
||||
FC1_OUTPUT_NODES = 16
|
||||
FC1_OUTPUT_NODES = 20
|
||||
|
||||
# 第二层全连接宽度(输出标签类型数)
|
||||
FC2_OUTPUT_NODES = 15
|
||||
|
||||
@@ -40,7 +40,7 @@ class DataSet:
|
||||
files = os.listdir(dir)
|
||||
for file in tqdm(files, postfix={"loading id": i}, dynamic_ncols=True):
|
||||
if file[-3:] == "jpg":
|
||||
if random.random() > 0.2:
|
||||
if random.random() > 0.7:
|
||||
self.train_samples.append(self.file2nparray("%s/%s" % (dir, file)))
|
||||
self.train_labels.append(self.id2label(i))
|
||||
else:
|
||||
@@ -61,6 +61,15 @@ class DataSet:
|
||||
labels.append(self.train_labels[id])
|
||||
return np.array(samples), np.array(labels)
|
||||
|
||||
def sample_test_sets(self, length):
|
||||
samples = []
|
||||
labels = []
|
||||
for i in range(length):
|
||||
id = random.randint(0, len(self.test_samples)-1)
|
||||
samples.append(self.test_samples[id])
|
||||
labels.append(self.test_labels[id])
|
||||
return np.array(samples), np.array(labels)
|
||||
|
||||
def all_train_sets(self):
|
||||
return self.train_samples[:], self.train_labels[:]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user