更新CNN参数

This commit is contained in:
xinyang
2019-07-25 15:06:23 +08:00
parent 453670308c
commit 504d1aca86
10 changed files with 115426 additions and 26831 deletions

View File

@@ -54,7 +54,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 100000
STEPS = 60000
BATCH = 50
LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY = 0.99
@@ -101,16 +101,31 @@ def train(dataset, show_bar=False):
_, loss_value, step = sess.run(
[train_op, loss, global_step],
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.3}
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.2}
)
if i % 100 == 0:
if i % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(10000)
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
bar.set_postfix({"loss": loss_value, "acc": acc})
if (i-1) % 100 == 0:
if (i-1) % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(5000)
test_acc, output = sess.run([accuracy, y], feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
output = np.argmax(output, axis=1)
real = np.argmax(test_labels, axis=1)
print("=============test-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
train_images, train_labels = dataset.sample_train_sets(5000)
train_acc, output = sess.run([accuracy, y], feed_dict={x: train_images, y_: train_labels, keep_rate:1.0})
output = np.argmax(output, axis=1)
real = np.argmax(train_labels, axis=1)
print("=============train-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
print("\n")
bar.set_postfix({"loss": loss_value, "train_acc": train_acc, "test_acc": test_acc})
vars_val = sess.run(vars)
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
@@ -206,9 +221,9 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut")
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_resize")
train(dataset, show_bar=True)
input("press enter to continue...")