更新反陀螺和分类器参数。
This commit is contained in:
@@ -54,9 +54,9 @@ def save_para(folder, paras):
|
||||
save_bias(fp, paras[7])
|
||||
|
||||
|
||||
STEPS = 20000
|
||||
STEPS = 100000
|
||||
BATCH = 50
|
||||
LEARNING_RATE_BASE = 0.005
|
||||
LEARNING_RATE_BASE = 0.001
|
||||
LEARNING_RATE_DECAY = 0.99
|
||||
MOVING_AVERAGE_DECAY = 0.99
|
||||
|
||||
@@ -101,7 +101,7 @@ def train(dataset, show_bar=False):
|
||||
|
||||
_, loss_value, step = sess.run(
|
||||
[train_op, loss, global_step],
|
||||
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.5}
|
||||
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.3}
|
||||
)
|
||||
|
||||
if i % 100 == 0:
|
||||
@@ -206,9 +206,9 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import os
|
||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
import os
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut")
|
||||
train(dataset, show_bar=True)
|
||||
input("press enter to continue...")
|
||||
|
||||
@@ -38,7 +38,7 @@ CONV2_KERNAL_SIZE = 3
|
||||
CONV2_OUTPUT_CHANNELS = 12
|
||||
|
||||
# 第一层全连接宽度
|
||||
FC1_OUTPUT_NODES = 20
|
||||
FC1_OUTPUT_NODES = 30
|
||||
|
||||
# 第二层全连接宽度(输出标签类型数)
|
||||
FC2_OUTPUT_NODES = 15
|
||||
|
||||
Reference in New Issue
Block a user