倾斜四边形非极大值抑制(NMS)的计算思路

在做物体检测的时候常常会用到倾斜四边形(一般是矩形)的NMS问题,在允许使用OpenCV的环境下,可以直接调用cv2.dnn.NMSBoxesRotated函数。
但是在有些无法使用OpenCV的场合,只能靠自己实现这个功能了。
本文将会提供一个PyTorch版的NMSBoxesRotated函数,为了方便使用jit或onnx部署,函数中除PyTorch之外没有其他依赖(注意,这份nms代码在Python环境下速度很慢)。

文章分为两个部分,求倾斜四边形的重叠区域面积和NMS。

求重叠区域面积


求重叠区域面积的思路如下:

项目思路

求两条线段的交点

首先利用叉乘判断两条线段是否相交,然后对相交的线段计算交点。

def cross(a,b):
    '''平面向量的叉乘'''
    x1,y1 = a
    x2,y2 = b
    return x1 * y2 - x2 * y1
def line_cross(line1,line2):
    '''判断两条线段是否相交,并求交点'''
    a,b = line1
    c,d = line2
    # 两个三角形的面积同号或者其中一个为0(其中一条线段端点落在另一条线段上) ---> 不相交
    if cross(c - a,b - a) * cross(d - a,b - a) >= 0:
        return False
    if cross(b - c,d - c) * cross(a - c,d - c) >= 0:
        return False
    x1,y1 = a
    x2,y2 = b
    x3,y3 = c
    x4,y4 = d
    
    k = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) 
    if  k != 0:
        xp = ((x1*y2 - y1*x2) * (x3 - x4) - (x1 - x2) * (x3*y4 - y3*x4)) / k
        yp = ((x1*y2 - y1*x2) * (y3 - y4) - (y1 - y2) * (x3*y4 - y3*x4)) / k
    else:
        # 共线
        return False
    return xp,yp

为了验证上面的函数的正确性,可以使用下面的代码测试一下:

from itertools import combinations
lines = torch.randn((100,4)).view((-1,2,2))
comb = combinations(lines,r =2 )
plt.figure(figsize=(10,10))
for line in lines:
    plt.plot(line[:,0],line[:,1],color = 'r')
for line1,line2 in comb:
    r = line_cross(line1,line2)
    if r:
        plt.scatter(r[0],r[1],color = 'g')
线段交点示意图

整理点集顺序

整理顺序的思路是先找到所有顶点到中心点的连线;然后定义一个判断线段相对位置(顺时针位还是逆时针位)的函数,这里同样用到了叉乘法;最后根据这个函数实现一个快速排序,代码如下:

def compare(a,b,center):
    '''
    对比a-center线段是在b-center线段的顺时针方向(True)还是逆时针方向(False)
    1. 通过叉乘积判断,积为负则a-center在b-center的逆时针方向,否则a-center在b-center的顺时针方向;
    2. 如果a,b,center三点共线,则按距离排列,距离center较远的作为顺时针位。

    原理:
    det = a x b = a * b * sin(<a,b>)
    其中<a,b>为a和b之间的夹角,意义为a逆时针旋转到b的位置所需转过的角度
    所以如果det为正,说明a可以逆时针转到b的位置,说明a在b的顺时针方向
    如果det为负,说明a可以顺时针转到b的位置,说明a在b的逆时针方向

    '''
    det = cross(a - center, b - center)
    if det > 0:
        return True
    elif det < 0:
        return False
    else:
        d_a = torch.sum((a - center) ** 2)
        d_b = torch.sum((b - center) ** 2)
        if d_a > d_b:
            return True
        else:
            return False

def quick_sort(box,left,right,center = None):
    '''快速排序'''
    if center is None:
        center = torch.mean(box,dim = 0)
    if left < right:
        q = partition(box,left,right,center)
        quick_sort(box,left,q - 1,center)
        quick_sort(box,q + 1,right,center)

def partition(box,left,right,center = None):
    '''辅助快排,使用最后一个元素将'''
    x = box[right]
    i = left - 1
    for j in range(left,right):
        if compare(x,box[j],center):
            i += 1
            temp = box[i].clone()
            box[i] = box[j]
            box[j] = temp
            # torch.Tensor不能使用下面的方式进行元素交换
            # box[i],box[j] = box[j],box[i]
    temp = box[i + 1].clone()
    box[i + 1] = box[right]
    box[right] = temp
    return i + 1

同样的,我们可以再写一段代码验证一下效果:

empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
box = torch.rand((16,2)) * 800
cv2.polylines(empty,[box.data.numpy().astype(np.int32)],True,(0,255,0),2)
quick_sort(box,0,len(box) - 1)
cv2.polylines(empty, [box.data.numpy().astype(np.int32)], True, (255, 0, 0), 8)
plt.imshow(empty)
plt.show()

得到如下图像,红色的是整理之后的多边形框

整理多边形顺序

判断点是否在多边形内

这个函数是用来求凸四边形交集的,因为凸四边形的交集图形的顶点由三部分构成:

  1. box1内部的box2的顶点;
  2. box2内部的box1的顶点;
  3. box1和box2的交点。

判断代码如下:

def inside(point,polygon):
    '''
    判断点是否在多边形内部
    原理:
    射线法
    从point作一条水平线,如果与polygon的焦点数量为奇数,则在polygon内,否则在polygon外

    为了排除特殊情况
    只有在线段的一个端点在射线下方,另一个端点在射线上方或者射线上的时候,才认为线段与射线相交
    '''
    x0,y0  = point
    # 做一条从point到多边形最左端位置的水平(y保持不变)射线
    left_line = torch.Tensor([[x0,y0],[torch.min(polygon,dim = 0)[0][0].item() - 1,y0]])
    lines = [[polygon[i],polygon[i+1]] for i in range(len(polygon) - 1)] + [[polygon[-1],polygon[0]]]
    ins = False
    for line in lines:
        (x1,y1),(x2,y2) = line
        if min(y1,y2) < y0 and max(y1,y2) >= y0:
            c = line_cross(left_line,line)
            if c and c[0] <= x0:
                ins = not ins
    return ins

然后使用下面的代码再验证一下:

points = torch.rand(800,2) * 800
for p_ in points:
    p = p_.clone().long()
    r = inside(p,box)
    if r:
        cv2.circle(empty,(p[0].item(),p[1].item()),5,color = (0,0,0),thickness=5)
    else:
        cv2.circle(empty,(p[0].item(),p[1].item()),5,color = (255,0,255),thickness=5)
plt.imshow(empty)

就可以获得下面这个很花哨的图形了:

点在多边形内部

求两个四边形的重叠区域

!!!只适用于四边形的重叠区域只有一个的情况,例如两者都是凸四边形的情况

def intersection(box1,box2):
    '''
    判断两个框是否相交,如果相交,返回重叠区域的顶点
    1. 求box1在box2内部的点;
    2. 求box2在box1内部的点;
    3. 求box1和box2的交点;
    4. 所有点构成重叠区域的多边形点集;
    5. 顺时针排序
    '''
    quick_sort(box1,0,len(box1) - 1)
    quick_sort(box2,0,len(box2) - 1)
    # 求重叠区域
    # 整理成线段
    lines1 = [[box1[i],box1[i + 1]] for i in range(len(box1) - 1)] + [[box1[-1],box1[0]]]
    lines2 = [[box2[i],box2[i + 1]] for i in range(len(box2) - 1)] + [[box2[-1],box2[0]]]
    cross_points = []
    # 交点
    for l1 in lines1:
        for l2 in lines2:
            c = line_cross(l1,l2)
            if c:
                cross_points.append(torch.Tensor(c).view(1,-1))
    # 求box1在box2内部的点
    for b in box1:
        if inside(b,box2):
            cross_points.append(b.view(1,-1))
    for b in box2:
        if inside(b,box1):
            cross_points.append(b.view(1,-1))
    if len(cross_points) > 0:
        cross_points = torch.cat(cross_points,dim = 0)
        quick_sort(cross_points,0,len(cross_points) - 1)
        return cross_points
    else:
        return None

验证代码如下:


plt.figure(figsize=(18,10))
for i in range(4):
    box1 = torch.rand((4,2)) * 800
    box2 = torch.rand((4,2)) * 800
    empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
    quick_sort(box1,0,len(box1) - 1)
    quick_sort(box2,0,len(box2) - 1)
    cv2.polylines(empty, [box1.data.numpy().astype(np.int32)], True, (255, 0, 0), 4)
    cv2.polylines(empty, [box2.data.numpy().astype(np.int32)], True, (0, 255, 0), 4)
    cross_points = intersection(box1,box2)
    if cross_points is not None:
        cv2.polylines(empty, [cross_points.data.numpy().astype(np.int32)], True, (0, 0, 255), 4)
    plt.subplot(140 + i + 1)
    plt.imshow(empty)
四边形的重叠区域

计算多边形的面积

多边形面积也是利用叉乘来求的,这里利用了叉乘的集合意义以及叉乘的正负性。

def polygon_area(polygon):
    '''
    求多边形面积
    https://blog.csdn.net/m0_37914500/article/details/78615284 使用向量叉乘计算多边形面积,前提是多边形所有点按顺序排列
    '''
    lines = [[polygon[i],polygon[i+1]] for i in range(len(polygon) - 1)] + [[polygon[-1],polygon[0]]]
    s_polygon = 0.0
    for line in lines:
        a,b = line
        s_tri = cross(a,b)
        s_polygon += s_tri
    return s_polygon / 2

计算IOU

IOU即交并比,也就是两个多边形的交集面积除以并集面积。

def intersection_of_union(box1,box2):
    '''
    iou = intersection(s_1,s_2) / (s_1 + s_2 - intersection(s_1,s_2))
    '''
    quick_sort(box1,0,len(box1) - 1)
    quick_sort(box2,0,len(box2) - 1)
    s_box1 = torch.abs(polygon_area(box1))
    s_box2 = torch.abs(polygon_area(box2))
    cross_points = intersection(box1,box2)
    if cross_points is not None:
        cv2.polylines(empty, [cross_points.data.numpy().astype(np.int32)], True, (0, 0, 255), 4)
        s_cross = torch.abs(polygon_area(cross_points))
    else:
        s_cross = torch.Tensor([[0]])
    iou = s_cross / (s_box1 + s_box2 - s_cross)
    return iou

计算结果如下:

plt.figure(figsize=(18,10))
for i in range(4):
    box1 = torch.rand((4,2)) * 800
    box2 = torch.rand((4,2)) * 800
    empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
    quick_sort(box1,0,len(box1) - 1)
    quick_sort(box2,0,len(box2) - 1)
#     s_box1 = torch.abs(polygon_area(box1))
#     s_box2 = torch.abs(polygon_area(box2))
    cv2.polylines(empty, [box1.data.numpy().astype(np.int32)], True, (255, 0, 0), 4)
    cv2.polylines(empty, [box2.data.numpy().astype(np.int32)], True, (0, 255, 0), 4)
    cross_points = intersection(box1,box2)
    if cross_points is not None:
        cv2.polylines(empty, [cross_points.data.numpy().astype(np.int32)], True, (0, 0, 255), 4)
#         s_cross = torch.abs(polygon_area(cross_points))
#     else:
#         s_cross = torch.Tensor([[0]])
    iou = intersection_of_union(box1,box2)
    print(iou.item())
    plt.subplot(140 + i + 1)
    plt.title("IOU : {}".format(iou.item()))
    plt.imshow(empty)

iou值展示

NMS

nms原理相信大家都比较了解了,分为如下几个步骤:

  1. 选择score最大的box;
  2. 删除与该box的iou超过nms_thresh的box;
  3. 从剩余的box中选择score最大的box,重复第二步。

def nms(boxes,scores,score_thresh = 0.95,nms_thresh = 0.1):
    indices = torch.where(scores > score_thresh)[0]
    if len(indices) <= 1:
        return boxes[indices]
    boxes = boxes[indices]
    scores = scores[indices]
    keep_indices = []
    # 从大到小
    order = torch.argsort(scores).flip(dims = [0])
    while order.shape[0] > 0:
        i = order[0]
        keep_indices.append(i)
        not_overlaps = []
        for j in range(len(order)):
            if order[j] != i:
                iou = intersection_of_union(boxes[i],boxes[order[j]])
                if iou < nms_thresh:
                    not_overlaps.append(j)
        order = order[not_overlaps]
    keep_boxes = boxes[[i.item() for i in keep_indices]]
    return keep_boxes

验证代码:

boxes = torch.rand((10,4,2)) * 800
empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
for i in range(len(boxes)):
    quick_sort(boxes[i],0,len(boxes[i]) - 1)
cv2.polylines(empty,boxes.data.numpy().astype(np.int32),True,(0,255,0),4)
plt.subplot(121)
plt.imshow(empty)
scores = torch.arange(10) + 1
keep_boxes = nms(boxes,scores)
# print("keep indices",keep_indices,boxes.shape)
# keep_boxes = boxes[[i.item() for i in keep_indices]]
empty = (np.ones((800,800,3)) * 255).astype(np.uint8)
cv2.polylines(empty,keep_boxes.data.numpy().astype(np.int32),True,(0,255,0),4)
plt.subplot(122)
plt.imshow(empty)

最终结果:

nms
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,458评论 4 363
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,454评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,171评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,062评论 0 207
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,440评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,661评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,906评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,609评论 0 200
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,379评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,600评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,085评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,409评论 2 254
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,072评论 3 237
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,088评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,860评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,704评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,608评论 2 270

推荐阅读更多精彩内容