更新CNN参数

This commit is contained in:
xinyang
2019-07-25 15:06:23 +08:00
parent 453670308c
commit 504d1aca86
10 changed files with 115426 additions and 26831 deletions

View File

@@ -54,7 +54,7 @@ def save_para(folder, paras):
save_bias(fp, paras[7])
STEPS = 100000
STEPS = 60000
BATCH = 50
LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY = 0.99
@@ -101,16 +101,31 @@ def train(dataset, show_bar=False):
_, loss_value, step = sess.run(
[train_op, loss, global_step],
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.3}
feed_dict={x: images_samples, y_: labels_samples, keep_rate:0.2}
)
if i % 100 == 0:
if i % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(10000)
acc = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
bar.set_postfix({"loss": loss_value, "acc": acc})
if (i-1) % 100 == 0:
if (i-1) % 500 == 0:
test_images, test_labels = dataset.sample_test_sets(5000)
test_acc, output = sess.run([accuracy, y], feed_dict={x: test_images, y_: test_labels, keep_rate:1.0})
output = np.argmax(output, axis=1)
real = np.argmax(test_labels, axis=1)
print("=============test-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
train_images, train_labels = dataset.sample_train_sets(5000)
train_acc, output = sess.run([accuracy, y], feed_dict={x: train_images, y_: train_labels, keep_rate:1.0})
output = np.argmax(output, axis=1)
real = np.argmax(train_labels, axis=1)
print("=============train-set===============")
for n in range(forward.OUTPUT_NODES):
print("label: %d, precise: %f, recall: %f" %
(n, np.mean(real[output==n]==n), np.mean(output[real==n]==n)))
print("\n")
bar.set_postfix({"loss": loss_value, "train_acc": train_acc, "test_acc": test_acc})
vars_val = sess.run(vars)
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
@@ -206,9 +221,9 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_cut")
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
dataset = generate.DataSet("/home/xinyang/Workspace/box_resize")
train(dataset, show_bar=True)
input("press enter to continue...")

View File

@@ -29,16 +29,16 @@ def max_pool_2x2(x):
CONV1_KERNAL_SIZE = 5
# 第一层卷积输出通道数
CONV1_OUTPUT_CHANNELS = 6
CONV1_OUTPUT_CHANNELS = 8
# 第二层卷积核大小
CONV2_KERNAL_SIZE = 3
# 第二层卷积输出通道数
CONV2_OUTPUT_CHANNELS = 12
CONV2_OUTPUT_CHANNELS = 16
# 第一层全连接宽度
FC1_OUTPUT_NODES = 30
FC1_OUTPUT_NODES = 100
# 第二层全连接宽度(输出标签类型数)
FC2_OUTPUT_NODES = 15

View File

@@ -1,7 +1,9 @@
6
-0.07201801
2.4805095
-0.03871701
-0.0016837252
-0.0016489439
-0.11918187
8
-0.19897088
2.2773967
-0.07212669
-0.22893764
-0.022769619
0.9122422
1.3221853
-0.21709026

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,17 @@
12
0.011819358
2.0857728
-0.00558426
0.571528
-0.5205912
1.1178641
0.12018313
0.35316235
2.4709818
-0.030677324
-0.07035873
2.3538203
16
0.032270256
2.3110154
0.4078679
0.30516425
-0.027841559
0.80007833
0.10999302
0.83167756
-0.09058099
0.67901427
0.2736649
0.08881351
0.10975416
0.25698876
0.076515704
-0.017468728

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +1,101 @@
30
-0.004537243
-0.7359988
-0.014530569
1.9556539
-0.005490738
-0.0075725853
1.0653863
-0.42296687
-0.33600682
-0.7384473
0.7910003
-0.00820433
-0.025636531
-0.02046858
-0.0068890513
-0.26979277
-0.004196582
-0.7438407
-0.5907334
-0.38981074
0.13757768
-0.027325256
1.1313668
0.5019343
1.462373
0.6535602
0.8464444
0.059315335
-0.030320408
-0.329629
100
-0.39040318
-0.031096617
-0.06425226
0.24911235
-0.002578787
-0.086705275
-0.0322658
-0.017816741
-0.11621032
0.21196772
-0.004639828
-0.023076132
-0.50997764
-0.04299724
-0.01989839
-0.011238396
-0.003221448
-0.019384952
-0.0007764693
-0.015599826
0.16373938
-0.0027049272
-0.18095633
-0.050923813
0.12674743
-0.064153716
-0.028386148
-0.059802737
-0.036068685
-0.004065791
-0.03783843
-0.16458924
-0.0328307
-0.032716025
-0.020594684
-0.042352736
-0.084991984
-0.028080234
-0.001538593
-0.10711875
-0.024680987
-0.008004385
-0.5063542
-0.09158748
-0.08181085
-0.22574262
-0.075171836
0.28233245
-0.024944687
-0.0029645876
-0.041441295
-0.08904015
0.30993482
-0.06328518
-0.0075723003
-0.005151164
-0.0021952058
-0.013833341
-0.023337327
-0.01824665
-0.025177158
-0.067239
-0.02126352
0.11769418
-0.64603645
-0.014887376
-0.14686602
-0.020528413
-0.018256638
-0.0017088759
-0.018110225
-0.003289471
1.0441891
0.30619404
-0.001282074
-0.09424017
-0.24455559
-0.026046017
-0.004658401
-0.022633847
-0.022873487
0.4393057
-0.033948973
-0.042779494
-0.0059623853
0.6859317
-0.19052452
-0.020080235
-0.010588832
0.012147919
-0.002949453
0.41500625
-0.16353038
-0.023607356
-0.38747007
-0.014272043
-0.0033837124
-0.1627222
-0.055108108
0.74174875

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,16 @@
15
1.4582382
0.23248857
0.04220094
-0.46074286
-1.0683035
-0.18858828
0.7637116
0.50826854
-0.99831194
-0.7008356
-0.7342338
0.08691424
-0.80010396
-0.11393161
0.1028082
0.7158619
0.074666105
-0.20262705
-0.32018387
-0.35891113
-0.021645868
0.2373232
0.17633525
-0.4311263
-0.2179705
-0.21293324
0.20182535
-0.8007338
-0.02198011
0.30222887

File diff suppressed because it is too large Load Diff