更新CNN训练代码。

This commit is contained in:
xinyang
2019-07-02 20:56:46 +08:00
parent cb6cf4b2fc
commit 14d1544113
3 changed files with 20 additions and 20 deletions

View File

@@ -87,8 +87,8 @@ def train(dataset, show_bar=False):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
with tf.Session(config=config) as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
@@ -152,7 +152,7 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
print("Loading data sets...")
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
dataset = generate.DataSet("/home/xinyang/Desktop/dataset/box")
print("Finish!")
train(dataset, show_bar=True)
input("Press any key to end...")