更新CNN参数

This commit is contained in:
xinyang
2019-07-25 15:06:23 +08:00
parent 453670308c
commit 504d1aca86
10 changed files with 115426 additions and 26831 deletions

View File

@@ -54,7 +54,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7]) save_bias(fp, paras[7])
STEPS = 100000 STEPS = 60000
BATCH = 50 BATCH = 50
LEARNING_RATE_BASE = 0.001 LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY = 0.99 LEARNING_RATE_DECAY = 0.99
@@ -101,16 +101,31 @@ def train(dataset, show_bar=False):
_, loss_value, step = sess.run( _, loss_value, step = sess.run(
[train_op, loss, global_step], [train_op, loss, global_step],
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.3} feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.2}
) )
if i % 100 == 0: if (i-1) % 100 == 0:
if i % 500 == 0: if (i-1) % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(10000) test_images, test_labels = dataset.sample_test_sets(5000)
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_rate:1.0}) test_acc, output = sess.run([accuracy, y], feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
bar.set_postfix({"loss": loss_value, "acc": acc}) output = np.argmax(output, axis=1)
real = np.argmax(test_labels, axis=1)
print("=============test-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
train_images, train_labels = dataset.sample_train_sets(5000)
train_acc, output = sess.run([accuracy, y], feed_dict={x: train_images, y_: train_labels, keep_rate:1.0})
output = np.argmax(output, axis=1)
real = np.argmax(train_labels, axis=1)
print("=============train-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
print("\n")
bar.set_postfix({"loss": loss_value, "train_acc": train_acc, "test_acc": test_acc})
vars_val = sess.run(vars) vars_val = sess.run(vars)
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val) save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
@@ -206,9 +221,9 @@ def train(dataset, show_bar=False):
if __name__ == "__main__": if __name__ == "__main__":
import os # import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut") dataset = generate.DataSet("/home/xinyang/Workspace/box_resize")
train(dataset, show_bar=True) train(dataset, show_bar=True)
input("press enter to continue...") input("press enter to continue...")

View File

@@ -29,16 +29,16 @@ def max_pool_2x2(x):
CONV1_KERNAL_SIZE = 5 CONV1_KERNAL_SIZE = 5
# 第一层卷积输出通道数 # 第一层卷积输出通道数
CONV1_OUTPUT_CHANNELS = 6 CONV1_OUTPUT_CHANNELS = 8
# 第二层卷积核大小 # 第二层卷积核大小
CONV2_KERNAL_SIZE = 3 CONV2_KERNAL_SIZE = 3
# 第二层卷积输出通道数 # 第二层卷积输出通道数
CONV2_OUTPUT_CHANNELS = 12 CONV2_OUTPUT_CHANNELS = 16
# 第一层全连接宽度 # 第一层全连接宽度
FC1_OUTPUT_NODES = 30 FC1_OUTPUT_NODES = 100
# 第二层全连接宽度(输出标签类型数) # 第二层全连接宽度(输出标签类型数)
FC2_OUTPUT_NODES = 15 FC2_OUTPUT_NODES = 15

View File

@@ -1,7 +1,9 @@
6 8
-0.07201801 -0.19897088
2.4805095 2.2773967
-0.03871701 -0.07212669
-0.0016837252 -0.22893764
-0.0016489439 -0.022769619
-0.11918187 0.9122422
1.3221853
-0.21709026

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,17 @@
12 16
0.011819358 0.032270256
2.0857728 2.3110154
-0.00558426 0.4078679
0.571528 0.30516425
-0.5205912 -0.027841559
1.1178641 0.80007833
0.12018313 0.10999302
0.35316235 0.83167756
2.4709818 -0.09058099
-0.030677324 0.67901427
-0.07035873 0.2736649
2.3538203 0.08881351
0.10975416
0.25698876
0.076515704
-0.017468728

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +1,101 @@
30 100
-0.004537243 -0.39040318
-0.7359988 -0.031096617
-0.014530569 -0.06425226
1.9556539 0.24911235
-0.005490738 -0.002578787
-0.0075725853 -0.086705275
1.0653863 -0.0322658
-0.42296687 -0.017816741
-0.33600682 -0.11621032
-0.7384473 0.21196772
0.7910003 -0.004639828
-0.00820433 -0.023076132
-0.025636531 -0.50997764
-0.02046858 -0.04299724
-0.0068890513 -0.01989839
-0.26979277 -0.011238396
-0.004196582 -0.003221448
-0.7438407 -0.019384952
-0.5907334 -0.0007764693
-0.38981074 -0.015599826
0.13757768 0.16373938
-0.027325256 -0.0027049272
1.1313668 -0.18095633
0.5019343 -0.050923813
1.462373 0.12674743
0.6535602 -0.064153716
0.8464444 -0.028386148
0.059315335 -0.059802737
-0.030320408 -0.036068685
-0.329629 -0.004065791
-0.03783843
-0.16458924
-0.0328307
-0.032716025
-0.020594684
-0.042352736
-0.084991984
-0.028080234
-0.001538593
-0.10711875
-0.024680987
-0.008004385
-0.5063542
-0.09158748
-0.08181085
-0.22574262
-0.075171836
0.28233245
-0.024944687
-0.0029645876
-0.041441295
-0.08904015
0.30993482
-0.06328518
-0.0075723003
-0.005151164
-0.0021952058
-0.013833341
-0.023337327
-0.01824665
-0.025177158
-0.067239
-0.02126352
0.11769418
-0.64603645
-0.014887376
-0.14686602
-0.020528413
-0.018256638
-0.0017088759
-0.018110225
-0.003289471
1.0441891
0.30619404
-0.001282074
-0.09424017
-0.24455559
-0.026046017
-0.004658401
-0.022633847
-0.022873487
0.4393057
-0.033948973
-0.042779494
-0.0059623853
0.6859317
-0.19052452
-0.020080235
-0.010588832
0.012147919
-0.002949453
0.41500625
-0.16353038
-0.023607356
-0.38747007
-0.014272043
-0.0033837124
-0.1627222
-0.055108108
0.74174875

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,16 @@
15 15
1.4582382 0.7158619
0.23248857 0.074666105
0.04220094 -0.20262705
-0.46074286 -0.32018387
-1.0683035 -0.35891113
-0.18858828 -0.021645868
0.7637116 0.2373232
0.50826854 0.17633525
-0.99831194 -0.4311263
-0.7008356 -0.2179705
-0.7342338 -0.21293324
0.08691424 0.20182535
-0.80010396 -0.8007338
-0.11393161 -0.02198011
0.1028082 0.30222887

File diff suppressed because it is too large Load Diff