This commit is contained in:
xinyang
2019-07-09 00:03:06 +08:00
parent 960769689b
commit 13344f3e71
12 changed files with 12501 additions and 9 deletions

View File

@@ -70,8 +70,8 @@ def train(dataset, show_bar=False):
nodes, vars = forward.forward(x, 0.01)
y = nodes[-1]
# ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
ce = tf.nn.weighted_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1), pos_weight=1)
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
# ce = tf.nn.weighted_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1), pos_weight=1)
cem = tf.reduce_mean(ce)
loss= cem + tf.add_n(tf.get_collection("losses"))
@@ -114,7 +114,7 @@ def train(dataset, show_bar=False):
vars_val = sess.run(vars)
save_para("/home/xinyang/Desktop/RM_auto-aim/tools/para", vars_val)
save_para("/home/xinyang/Workspace/RM_auto-aim/tools/para", vars_val)
print("save done!")
# nodes_val = sess.run(nodes, feed_dict={x:test_images})
# return vars_val, nodes_val
@@ -204,5 +204,5 @@ def train(dataset, show_bar=False):
if __name__ == "__main__":
dataset = generate.DataSet("/home/xinyang/Desktop/box_cut")
dataset = generate.DataSet("/home/xinyang/Workspace/dataset/box_cut")
train(dataset, show_bar=True)