Pytorch中的四种经典Loss源码解析

【导语】笔者最近在OneFlow框架对齐实现Pytorch相关Loss代码,其中也涉及到部分源码解读,数学特殊操作等知识,于是想写篇文章简单总结一下。

关于Pytorch的Loss源码

了解过Pytorch的应该知道其历史包袱比较重,它吸收了Caffe2的底层代码,然后自己借用这部分底层代码来写各种OP的逻辑,最后再暴露出一层Python接口供用户使用。

因此第一次接触Pytorch源代码可能有点不太熟悉,基本上Pytorch大部分OP逻辑实现代码都放在 Aten/native下,我们这里主要是根据Loss.cpp来进行讲解

MarginRankingLoss

RankingLoss系列是来计算输入样本的距离,而不像MSELoss这种直接进行回归。其主要思想就是分为 MarginRanking

MarginRankingLoss公式

Margin 这个词是页边空白的意思,平常我们打印的时候,文本内容外面的空白就叫 Margin。

而在Loss中也是表达类似的意思,相当于是一个固定的范围当样本距离(即Loss)超过范围,即表示样本差异性足够了,不需要再计算Loss。

Ranking 则是排序,当target=1,则说明x1排名需要大于x2;当target=2,则说明x2排名需要大于x1

其源码逻辑也很简单,就是根据公式进行计算,最后根据reduction类型来进行 reduce_mean/sum

Pytorch的MarginRankingLoss代码

下面是对应的numpy实现代码

def np_margin_ranking_loss(input1, input2, target, margin, reduction):
    output = np.maximum(0, -target*(input1 - input2) + margin)
    if reduction == 'mean':
        return np.mean(output)
    elif reduction == 'sum':
        return np.sum(output)
    else:
        return output

TripletMarginLoss

TripletLoss最早是在 FaceNet 提出的,它是用于衡量不同人脸特征之间的距离,进而实现人脸识别和聚类

TripletLoss

而TripletMarginLoss则是结合了TripletLoss和MarginRankingLoss的思想,具体可参考 Learning local feature descriptors with triplets and shallow convolutional neural networks其公式如下

TripletMarginLoss公式

其中d是p范数函数

距离函数

范数的具体公式是

范数公式

该Loss针对不同样本配对,有以下三种情况

  1. 简单样本,即

此时 正样本距离anchor的距离d(ai, pi) + Margin仍然小于负样本距离anchor的距离d(ai, ni),该情况认为正样本距离足够小,不需要进行优化,因此Loss为0

  1. 难样本,即

此时 负样本距离anchor的距离d(ai, ni) 小于 正样本距离anchor的距离d(ai, pi),需要优化

  1. 半难样本,即

此时虽然 负样本距离anchor的距离d(ai, ni) 大于 正样本距离anchor的距离d(ai, pi),但是还不够大,没有超过 Margin,需要优化

此外论文作者还提出了 swap 这个概念,原因是我们公式里只考虑了anchor距离正类和负类的距离,而没有考虑正类和负类之间的距离,考虑以下情况

可能Anchor距离正样本和负样本的距离相同,但是负样本和正样本的距离很近,不利于模型区分,因此会做一个swap,即交换操作,在代码里体现的操作是取最小值。

## 伪代码if swap:  D(a, n) = min(D(a,n), D(p, n))

这样取了最小值后,在Loss计算公式中,Loss值会增大,进一步帮助区分负样本。

有了前面的铺垫,我们理解Pytorch的TripletMarginRankingLoss源码也非常简单

TripletMarginLoss源码

at::pairwise_distance是距离计算函数,首先计算出了anchor与正类和负类的距离。然后根据参数swap,来确定是否考虑正类和负类之间的距离。最后output就是按照公式进行计算,下面是numpy的对应代码

def np_triplet_margin_loss(anchor, postive, negative, margin, swap, reduction='mean', p=2, eps=1e-6):
    def _np_distance(input1, input2, p, eps):
     # Compute the distance (p-norm)
        np_pnorm = np.power(np.abs((input1 - input2 + eps)), p)
        np_pnorm = np.power(np.sum(np_pnorm, axis=-1), 1.0 / p)
        return np_pnorm

dist_pos = _np_distance(anchor, postive, p, eps)
    dist_neg = _np_distance(anchor, negative, p, eps)

if swap:
        dist_swap = _np_distance(postive, negative, p, eps)
        dist_neg = np.minimum(dist_neg, dist_swap)
    output = np.maximum(margin + dist_pos - dist_neg, 0)

if reduction == 'mean':
        return np.mean(output)
    elif reduction == 'sum':
        return np.sum(output)
    else:
        return output

这里比较容易踩坑的是p范数的计算,因为当p=2,根据范数的公式,如果输入有负数是不合法的, 比如

于是我们从distance函数开始找线索,发现它是调用at::norm

pairwise_distance

根据Pytorch的文档,它其实在计算的时候调用了abs绝对值,来避免最后负数出现,从而保证运算的合理性

Norm文档

KLDivLoss

该损失函数是计算KL散度(即相对熵),它可以用于衡量两个分布的差异

KL散度基本定义

当p和q分布越接近,则趋近于1,经过log运算后,loss值为0

当分布差异比较大,则损失值就比较高

Pytorch中计算公式中还不太一样

Pytorch的KLDivLoss公式

下面我们看看Pytorch对应的源码

KLDivLoss源码

首先可以观察到,除了常规的input,target,reduction,还有一个额外的参数 log_target,用于表示target是否已经经过log运算。根据这个参数,KLDivLoss进而分成两个函数 _kl_div_log_target_kl_div_non_log_target 实现。

_kl_div_log_target 的实现比较简单,就是按照公式进行计算

而  _kl_div_non_log_target  有些许不同,因为target的数值范围不确定,当为负数的时候,log运算时不合法的。因此Pytorch初始化了一个全0数组,然后在最后的loss计算中,在target小于0的地方填0,避免nan数值出现

下面是对应的numpy实现代码

def np_kldivloss(input, target, log_target, reduction='mean'):    if log_target:        output = np.exp(target)*(target - input)    else:        output_pos = target*(np.log(target) - input)        zeros = np.zeros_like(input)        output = np.where(target>0, output_pos, zeros)    if reduction == 'mean':        return np.mean(output)    elif reduction == 'sum':        return np.sum(output)    else:        return output

BCEWithLogitsLoss

熟悉二分类交叉熵损失函数BCELoss的应该知道,该函数输入的是个分类概率,范围在0~1之间,最后计算交叉熵。我们先看下该损失函数的参数

BCEWithLogitsLoss参数
  • weight 表示最后loss缩放权重
  • reduction 表示 最后是做mean, sum, none 操作
  • pos_weight 表示针对正样本的权重,即positive weight

下面是其计算公式其中 表示sigmoid运算

BCEWithLogitsLoss

BCEWithLogitsLoss 相当于 sigmoid + BCELoss,但实际上 Pytorch为了更好的数值稳定性,并不是这么做的,下面我们看看对应的源代码

Pytorch的BCEWithLogitsLoss源码

这段源代码其实看的不太直观,我们可以看下numpy对应的代码

def np_bce_with_logits_loss(np_input, np_target, np_weight, np_pos_weight, reduction='mean'):
    max_val = np.maximum(-np_input, 0)

if np_pos_weight.any():
        log_weight = ((np_pos_weight - 1) * np_target) + 1
        loss = (1 - np_target) * np_input
        loss_1 = np.log(np.exp(-max_val) + np.exp(-np_input - max_val)) + max_val
        loss += log_weight * loss_1
    else:
        loss = (1 - np_target) * np_input
        loss += max_val
        loss += np.log(np.exp(-max_val) + np.exp(-np_input - max_val))

output = loss * np_weight

if reduction == 'mean':
        return np.mean(output)
    elif reduction == 'sum':
        return np.sum(output)
    else:
        return output

因为涉及到了sigmoid运算,所以有以下公式

计算中,如果x过大或过小,会导致指数运算出现上溢或下溢,因此我们可以用 log-sum-exp 的技巧来避免数值溢出,具体可以看下面公式推导(特此感谢德澎!

公式推导

总结

看源代码没有想象中那么难,只要破除迷信,敢于尝试,你也能揭开源码的神秘面纱~

(0)

相关推荐

  • 万字长文,60分钟闪电战

    大家好,我是 Jack. 本文是翻译自官方版教程:DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ,一份 60 分钟带你快速入门 PyTorch 的官方教程. ...

  • Pytorch实战:使用RNN网络对姓名进行分类

    项目地址:https://github.com/spro/practical-pytorch 项目作者: spro 翻译: 大邓 注意:文章末尾有jupyter notebook获取方式 本文我们构建 ...

  • pytorch转keras

    pytorch与keras的区别模型输入:区别pytorchkerasAPItorch.tensorInput形状NCHWNHWC#pytorch #批次, 通道, 高, 宽a = torch.ran ...

  • mmdetection最小复刻版(六):FCOS深入可视化分析

    AI编辑:深度眸 0 概要 论文名称:FCOS: A simple and strong anchor-free object detector 论文地址:https://arxiv.org/pdf/ ...

  • 损失函数技术总结及Pytorch使用示例

    作者丨仿佛若有光 来源丨CV技术指南 编辑丨极市平台 极市导读 本文对损失函数的类别和应用场景,常见的损失函数,常见损失函数的表达式,特性,应用场景和使用示例作了详细的总结. 前言 一直想写损失函数的 ...

  • Tensorflow中卷积的padding操作

    目录 Tensorflow中padding为valid的情况 Tensorflow中padding为same的情况 和Pytorch的padding简单对比 实验对比 实验1 实验2 实验3 实验4 ...

  • 【cntk速成】cntk图像分类从模型自定义到测试

    欢迎来到专栏<2小时玩转开源框架系列>,这是我们第七篇,前面已经说过了caffe,tensorflow,pytorch,mxnet,keras,paddlepaddle. 今天说cntk, ...

  • 超快速!清除测序数据中的rRNAs — RiboDetector

    导言 RiboDetector 是基友老邓的作品,主要特点是超快速,超准确去除测序数据中的rRNA.目前该软件应是审稿修回阶段.正好,一个师妹的课题要做微生物数据分析,便让其学习一下,流程部署起来,方 ...

  • 机器学习「Pytorch 」笔记六:初始化与 18 种损失函数的源码解析

    机器学习「Pytorch 」笔记六:初始化与 18 种损失函数的源码解析

  • 【Pytorch 】笔记六:初始化与 18 种损失函数的源码解析

    ❝ 阿泽推荐:这是 Miracle 同学 Pytorch 系列的第六篇,共有十篇. 1.写在前面 疫情在家的这段时间,想系统的学习一遍 Pytorch 基础知识,因为我发现虽然直接 Pytorch 实 ...

  • 绩效管理中的四个经典误区,你中招了吗

    目前,越来越多的医院开始实施绩效管理,或请咨询公司,或自己动手,基本上都有了绩效管理的概念.但是,似乎许多医院都遇到了一个同样的问题,人力资源部花费大量的时间和心血,制定了绩效管理方案.在方案的推销上 ...

  • 一位资深交易员总结四种经典“止损”铁律,从未被套原因揭秘!

    通常而言,止损是指股民在买入股票后因判断和选择失误而发生亏损的情况下,为避免造成更大的亏损,及时卖出手中所持有的股票.止损的主要目的是为了防止损失的进一步扩大,并将资金保存下来,以便下次重新介入,弥补 ...

  • 教育研究中存在四种人

    当下的教育研究脱离教育实践,背离教育理论为实践服务的宗旨,既有国内大的背景因素,也有学者们的浮躁心态.我们为什么出不了像苏霍姆林斯基那样的教育家?为什么现在没有陶行知.晏阳初那样的教育理论的倡导者和实 ...

  • 改变教育科学研究中的四种现象

    社会科学报/2009 年/5月/7 日/第 005 版学科探讨 改革开放以来,我国教育科学研究取得了举世瞩目的成就,但于中也出现了一些新问题,对此我们应该保持一份清醒思考. 改变教育科学研究中的四种现 ...

  • 『解酒』中药方: 四种解酒中药

    葛花:为葛藤上未开放的花蕾.其性味甘.平,入脾.胃经,具有解酒醒脾之功效.主要用于饮酒过度出现的头痛.头昏.烦渴.饱胀.呕吐酸水等,可取葛花10-15克,水煎服. 乌梅:性味酸.平,入肝.脾.肺.大肠 ...

  • 知识管理实施中的四种角色:依靠谁,为了谁?

    文/田志刚 摘自<卓越密码:如何成为专家> · 从知到行:企业知识管理实施在线培训 · 企业知识管理实施是一个系统工程,涉及到企业的战略.组织.人员.流程.方法.系统.制度等多个维度.在这 ...

  • 期货交易中的四种盈利模式

    在期货投机交易中,交易盈利模式有很多,比如日内交易.短线交易.波段交易.中线交易.长线交易.到底哪个好,哪个不好,这个问题不是他人能够回答清楚的,因为,同样的操作模式,不同的人不同的时期所产生的结果是 ...