更新CNN训练代码。
This commit is contained in:
@@ -87,8 +87,8 @@ def train(dataset, show_bar=False):
|
||||
|
||||
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||
|
||||
with tf.Session() as sess:
|
||||
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
|
||||
with tf.Session(config=config) as sess:
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
@@ -152,7 +152,7 @@ def train(dataset, show_bar=False):
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Loading data sets...")
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/DataSets")
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/dataset/box")
|
||||
print("Finish!")
|
||||
train(dataset, show_bar=True)
|
||||
input("Press any key to end...")
|
||||
|
||||
Reference in New Issue
Block a user