Keras 自定义loss函数 focal loss + triplet loss

上一节中已经阐述清楚了,keras.Model的输入输出与loss的关系。

一、自定义loss损失函数

https://spaces.ac.cn/archives/4493/comment-page-1#comments
非常简单,其实和官方写的方法一样。比如MSE:

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

model.compile(optimizer=optim, loss=[mean_squared_error])

注意的是,损失函数def mean_squared_error(y_true, y_pred)中的两个参数是固定的,由Keras自动注入。第一个参数来自于model.fit(x=[],y=[])中的y中的第n个,代表的是真实标签。第二个参数来自于推理后model.outputs相应位置的输出。
同时,model.compile()方法中loss传入的是方法体名称,非方法的return。

二、自定义keras损失函数:focal loss

https://github.com/umbertogriffo/focal-loss-keras/blob/master/losses.py

为了传入超参数,使用了python的wrapper模式构建函数,函数实际返回的是内部函数的名称,符合上述定义。

三、自定义keras损失函数:triplet loss

https://stackoverflow.com/questions/53996020/keras-model-with-tf-contrib-losses-metric-learning-triplet-semihard-loss-asserti
https://github.com/rsalesc/TCC/blob/master/scpd/tf/keras/common.py

由于triplet loss的输入比较特殊,是label(非one-hot格式)与嵌入层向量,因此,对应的,我们在keras的数据输入阶段,提供的第二个label就得是非one-hot格式。同时,model构造中得定义嵌入层,并使用L2正则化,且作为model的一个output以方便loss中调用。

实例中,定义模型时,我们分开定义嵌入层的logits与激活函数,以提取出来嵌入层的值。

    conv3 = layers.Conv2D(5, 1)(maxpool)
    embed = Dense(128, activation=None, name="embedding", kernel_regularizer=regularizers.l2(0.01))(conv3)
    dense4 = layers.Activation(activation=keras.activations.relu)(embed)
    norm_x = Lambda(lambda x: K.l2_normalize(x, axis=1))(embed)
    dense5 = layers.Dense(10, activation='softmax')(dense4)
    model = keras.Model(inputs=[a], outputs=[dense5, norm_x])

当然,在输入数据的生成器中,也必须每次:
yield img,[one_hot_label, label]以对应。

之后即可构造自定义的triplet loss func:

def semi_hard(labels, embeddings):
    labels = K.squeeze(labels, axis=1)
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels, embeddings, margin=1.0)

最后在compile中调用即可:

model.compile(optimizer=optim, loss=[classify_loss, semi_hard], loss_weights=[1,0.1])

四、tensorflow中的triplet loss

网易云课堂-吴恩达深度学习的triplet loss章节
https://blog.csdn.net/weixin_40400177/article/details/105213578

https://blog.csdn.net/qq_36387683/article/details/83583099
https://zhuanlan.zhihu.com/p/121763855

Easy Triplets 显然不应加入训练,因为它的损失为0,加在loss里面会拉低loss的平均值。Hard Triplets 和 Semi-Hard Triplets 的选择则见仁见智,针对不同的任务需求,可以只选择Semi-Hard Triplets或者Hard Triplets,也可以两者混用。

如图中所示,其实最难分类的是
semi-hard triplets:d(a,p) < d(a,n) < d(a,p) + margin
我们试图找出这样的图片对来加以训练。

可以使用离线学习,每次训练先找到难分类的图片对,然后喂入网络,但是这样很麻烦,且网络结构同样不好设计。因此使用在线挖掘,即每次在一个batch即B个特征向量中,去挖掘出(a,p)和最难分类的(a,n)来计算loss并反向传播。

官方API:

def triplet_semihard_loss(labels, embeddings, margin=1.0):
  """Computes the triplet loss with semi-hard negative mining.

  The loss encourages the positive distances (between a pair of embeddings with
  the same labels) to be smaller than the minimum negative distance among
  which are at least greater than the positive distance plus the margin constant
  (called semi-hard negative) in the mini-batch. If no such negative exists,
  uses the largest negative distance instead.
  See: https://arxiv.org/abs/1503.03832.

  Args:
    labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
      multiclass integer labels.
    embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
      be l2 normalized.
    margin: Float, margin term in the loss definition.

  Returns:
    triplet_loss: tf.float32 scalar.
  """

官方解释的很清楚了,就是想让处于semi-hard区域的最小的d(a,n)尽量去远离>d(a,p)+margin,而由于该(a,n)处于semi-hard区域因此该d(a,n)必须至少>d(a,p)。若找不到这样的(a,n),则表明可能(a,n)比起(a,p)更小,因此使用最大的(a,n)代替。