fix bug
This commit is contained in:
@@ -70,8 +70,8 @@ def train(dataset, show_bar=False):
|
||||
nodes, vars = forward.forward(x, 0.01)
|
||||
y = nodes[-1]
|
||||
|
||||
# ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
|
||||
ce = tf.nn.weighted_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1), pos_weight=1)
|
||||
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
|
||||
# ce = tf.nn.weighted_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1), pos_weight=1)
|
||||
cem = tf.reduce_mean(ce)
|
||||
loss= cem + tf.add_n(tf.get_collection("losses"))
|
||||
|
||||
@@ -114,7 +114,7 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
vars_val = sess.run(vars)
|
||||
save_para("/home/xinyang/Desktop/RM_auto-aim/tools/para", vars_val)
|
||||
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
|
||||
print("save done!")
|
||||
# nodes_val = sess.run(nodes, feed_dict={x:test_images})
|
||||
# return vars_val, nodes_val
|
||||
@@ -204,5 +204,5 @@ def train(dataset, show_bar=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = generate.DataSet("/home/xinyang/Desktop/box_cut")
|
||||
dataset = generate.DataSet("/home/xinyang/Workspace/dataset/box_cut")
|
||||
train(dataset, show_bar=True)
|
||||
|
||||
Reference in New Issue
Block a user