较大规模图片 使用phash去重

96
辰辰沉沉沉
0.1 2017.07.25 00:41* 字数 1230

起因

先说下为什么要做这个事。做的图片站的图片来源为很多美女图片站,自然地,会有很多重复的图片,而我的目标就是要把重复的图片找出来,剔除掉或者是做其他处理。

什么样的图片属于相同图片呢?因为会存在一些有水印的图片(如下图),或者是略微变形的图片(如1024 * 720 与1020 * 720的图片)

with_logo.jpeg

without_logo.jpeg

phash

phash全称是感知哈希算法(Perceptual hash algorithm),使用这玩意儿可以对每个图片生成一个值,如上面两个图分别是2582314446007581403 与 2582314446141799129 (转为二进制再比较),然后计算他们的hamming distance,简单的说就是数一数二进制之后有几位不同。整个处理流程有点像对文章去重时先算simhash再算hamming distance,很多东西都可以直接套用过来。

phash具体的实现可以很多地方都有了,可以搜到很多差不多的内容,在这我也就简单的记录下,具体可以去谷歌或者百度搜下。

  • 缩小尺寸 为了后边的步骤计算简单些
  • 简化色彩 将图片转化成灰度图像,进一步简化计算量
  • 计算DCT 计算图片的DCT变换,得到32*32的DCT系数矩阵。
  • 缩小DCT 虽然DCT的结果是32*32大小的矩阵,但我们只要保留左上角的8*8的矩阵,这部分呈现了图片中的最低频率。
  • 计算平均值 如同均值哈希一样,计算DCT的均值。
  • 计算hash值 根据8*8的DCT矩阵,设置0或1的64位的hash值,大于等于DCT均值的设为”1”,小于DCT均值的设为“0”。组合在一起,就构成了一个64位的整数,这就是这张图片的指纹。
    python 版本的实现
# -*- coding: utf-8 -*-

from compiler.ast import flatten
import cv2
import numpy as np

def pHash(imgfile):
    # 加载并调整图片为32x32灰度图片
    img = cv2.imread(imgfile, 0)
    img = cv2.resize(img, (32, 32), interpolation=cv2.INTER_CUBIC)

    # 创建二维列表
    h, w = img.shape[:2]
    vis0 = np.zeros((h, w), np.float32)
    vis0[:h, :w] = img  # 填充数据

    # 二维Dct变换
    vis1 = cv2.dct(cv2.dct(vis0))
    # 拿到左上角的8 * 8
    vis1 = vis1[0:8, 0:8]

    # 把二维list变成一维list
    img_list = flatten(vis1.tolist())

    # 计算均值
    avg = sum(img_list) * 1. / len(img_list)
    avg_list = ['0' if i < avg else '1' for i in img_list]

    # 得到哈希值
    return ''.join(['%x' % int(''.join(avg_list[x:x + 4]), 2) for x in range(0, 8 * 8, 4)])

这段代码是网上找来做测试用的,当时有个坑,他没有vis1 = vis1[0:8,0:8]这一步,然后出来的结果就很奇葩,而且指纹长的可怕(32 * 32位),准确率和召回率都低的惊人。这段代码也很简单,几乎和白话一样,就是把上面phash的流程给翻译了一遍。

然鹅,我并没有使用上面的python版的,出于两个原因,一是我上边说的坑,当时并没有发现,二是毕竟是python,虽说大部分计算的部分是用c写的(opencv),但还是觉得会慢。找到的是一个纯c的,来自 phash.org (没错,就是这么官方)。安装啥的网站里边都有,附上一个python调用的脚本。

class pHash(object):
    def __init__(self):
        self._lib = ctypes.CDLL('/opt/local/lib/libpHash.dylib', use_errno=True)

    def dct_imagehash(self, path):
        phash = ctypes.c_uint64()
        if self._lib.ph_dct_imagehash(path, ctypes.pointer(phash)):
            errno_ = ctypes.get_errno()
            err, err_msg = (errno.errorcode[errno_], os.strerror(errno_)) \
                if errno_ else ('none', 'errno was set to 0')
            print(('Failed to get image hash'
                   ' ({!r}): [{}] {}').format(path, err, err_msg), file=sys.stderr)
            return None
        return phash.value

    def hamming_distance(self, hash1, hash2):
        return self._lib.ph_hamming_distance(
            *map(ctypes.c_uint64, [hash1, hash2]))

非常贴心的还附赠了海明距的计算。
因为我的图片都是存在云端,为了速度更快,我会直接用云端图像处理把图片先缩小,压缩后再处理。我本机测试的结果是一千张图生成phash耗时1.5s,相当快了。(有个很惊悚的发现,上头那个python版本千张耗时0.7s...惊呆了...可能实现不太一样吧...)

大量数据hamming distance 计算

如标题所述,较大规模图片,我这边的大概是百万级别,但是即便是千万级别应该还是差不多的方式,亿级别的数据可能我的小破开发机就受不了了(没错...没用服务器...)

先说说海明距,咱们上边不是生成了一段64位的数呢?海明距就是数一数两个hash值有多少位的差异,一般小于5的都算近似,就是这么简单:)

假设有1000万已经处理完的phash值吧,现在来了一个新的phash,如何找出所有可能和他重复的图呢?
最简单粗暴的,直接遍历一次...即遍历1000万次....那么耗时大概...不用算了,肯定是个很夸张的值,不靠谱。

这边我采用的是一种内存换速度的方式,64位的的hash值,分为八组,每组八位。建立八个dict,每个dict代表一组,以每组的值作为key,value是一个list,存放key相同的hash值。查找的时候,把hash值分成八个,分别在八个map里边查找,如果有key相同的,取出key相同的所有hash值进行遍历。

说的相当的乱,下边是代码。

split_count = 8  # 每个64位的phash值分为八段,每段8位

def split(key, split_count):
    pre_length = 64 / split_count
    return [key[i * pre_length: (i + 1) * pre_length] for i in range(split_count)]

class ImageManager(object):
    def __init__(self):
        self.phash = pHash() # 就是上面那个pHash类
        self.phash_cache = [defaultdict(list) for i in range(split_count)] #
        self.init_phash_map()

    def init_phash_map(self):
        #我是把所有的phash存在sqlite里边,这边取出所有的Image
        for image in Image.select():
            self.add_to_image_cache(image)

    def add_to_image_cache(self, image):
        # 将hash值分割为8段
        key_split = split(bin(int(image.phash))[2:].rjust(64, '0'), split_count)
        for index, k in enumerate(key_split):
            self.phash_cache[index][k].append(image)

    def has_same(self, ori_image):
        phash = ori_image.phash
        key_split = split(bin(int(phash))[2:].rjust(64, '0'), split_count)
        result = set()
        for index, k in enumerate(key_split):
            if k in self.phash_cache[index]:
                for image in self.phash_cache[index][k]:
                    distance = self.distance(int(phash), int(image.phash))
                    if distance < 5 and ori_image.key != image.key:
                        result.add(image)
        if result:
            return True,list(result)
        return False,[]
图片站