import tensorflow as tf
import scipy.misc as msc
'''
    对于语义分割的边缘线,白色的为255,进行去除
'''
def remove_ignore_label(gt ,output=None ,pred=None):
    '''   去除label为255的值,进行交叉熵的计算
    gt: not one-hot
    output: a distriution of all labels, and is scaled to macth the size of gt
    NOTE the result is a flatted tensor
    and all label which is bigger that or equal to self.category_num is void label
    '''
    gt = tf.reshape(gt ,shape=[-1])  # (180000,)  把矩阵  转化为向量
    indices = tf.squeeze(tf.where(tf.less(gt, 21)) ,axis=1) #除去边缘线  判断是否小于 255
    #tf.less(gt, 21)  找到所以小于21的label,相当于除去边缘线,  某位置< 21  返回True, 否则返回False
    #tf.where(tf.less(gt, 21))  返回这个位置的index,在为True的位置
    #tf.squeeze 压缩为1的维度
    gt = tf.gather(gt ,indices)
    #  根据indices  取出这个位置的值
    if output is not None:
        output = tf.reshape(output, shape=[-1, 21])   #转化为21维度的特征,每个特征,相当于一张图片
        output = tf.gather(output ,indices) # output 输出也是  [小于21的索引值(相当与一张图片除为255的所以值) , 21]
        return gt ,output
    elif pred is not None:
        pred = tf.reshape(pred, shape=[-1])
        pred = tf.gather(pred, indices)
        return gt ,pred
label = tf.truncated_normal(shape=(3,281,500),stddev=0.1)  # 输入图片的label (b, w, h)
output = tf.truncated_normal(shape=(3,281,500,21),stddev=0.1) # 网络输出图片的概率 (b, w, h, c)  #21代表类别
label,output = remove_ignore_label(label,output)
label = tf.cast(label, tf.int32)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=output))
#这里进行计算交叉熵
with tf.Session() as sess:
    print(sess.run(loss))