首页 > 代码库 > Photoshop中磁力套索的一种简陋实现(Python)
Photoshop中磁力套索的一种简陋实现(Python)
经常用Photoshop的人应该熟悉磁力套索(Magnetic Lasso)这个功能,就是人为引导下的抠图辅助工具。在研发领域一般不这么叫,通常管这种边缘提取的办法叫Intelligent Scissors或者Livewire。
本来是给一个图像分割项目算法评估时的Python框架,觉得有点意思,就稍稍拓展了一下,用PyQt加了个壳,在非常简陋的程度上模拟了一下的磁力套索功能。为什么简陋:1) 只实现了最基本的边缘查找。路径冷却,动态训练,鼠标位置修正都没有,更别提曲线闭合,抠图,Alpha Matting等等;2) 没考虑性能规范,只为了书写方便;3) 我对Qt了解很浅,至今不会写Signal-Slot,不知道GUI写得是否合理;4) 没调试。
基本算法
相关算法我并没有做很深入的调研,不过相信这类应用中影响力最大的算法是来源于[1],也是本文的主要参考,基本思想是把图片看成是一个无向图,相邻像素之间就可以计算出一个局部cost,于是就转化成了最短路径问题了,接下来就是基于Dijkstra算法产生路径,就是需要提取的边缘。主要涉及的算法有两部分:1) 相邻像素的cost计算;2) 最短路径算法。
边缘检测
计算相邻像素cost的最终目的还是为了寻找边缘,所以本质还是边缘检测。基本思想是,通过各种不同手段检测边缘,并且根据检测到的强度来求加权值,作为cost。从最短路径的角度来说,就是边缘越明显的地方,cost的值越小。[1]中的建议是用三种指标求加权:1) 边缘检测算子;2) 梯度强度(Gradient Magnitude);3) 梯度方向(Gradient Direction)。本文的方法和[1]有那么一些不一样,因为懒,用了OpenCV中的Canny算子检测边缘而不是Laplacian Zero-Crossing Operator。表达式如下:
\[l\left( p,q \right)={{w}_{E}}{{f}_{E}}\left( q \right)+{{w}_{G}}{{f}_{G}}\left( q \right)+{{w}_{D}}{{f}_{D}}\left( p,q \right)\]
Canny算子
基本思想是根据梯度信息,先检测出许多连通的像素,然后对于每一坨连通的像素只取其中最大值且连通的部分,将周围置零,得到初始的边缘(Edges),这个过程叫做Non-Maximum Suppression。然后用二阈值的办法将这些检测到的初始边缘分为Strong, Weak, and None三个等级,顾名思义,Strong就是很确定一定是边缘了,None就被舍弃,然后从Weak中挑选和Strong连通的作为保留的边缘,得到最后的结果,这个过程叫做Hysteresis Thresholding。这个算法太经典,更多细节一Google出来一大堆,我就不赘述了。公式如下:
\[{{f}_{E}}\left( q \right)=\left\{ \begin{matrix}
0;\text{ if }q\text{ is on a edge} \\
1;\text{ if }q\text{ is not on a edge} \\
\end{matrix} \right.\]
其实从权值的计算上和最大梯度有些重复,因为如果沿着最大梯度方向找出来的路径基本上就是边缘,这一项的作用我的理解主要应该是1) 避免梯度都很大的区域出现离明显边缘的偏离;2) 保证提取边缘的连续性,一定程度上来讲也是保证平滑。
梯度强度
就是梯度求模而已,x和y两个方向的梯度值平方相加在开方,公式如下:
\[{{I}_{G}}\left( q \right)=\sqrt{{{I}_{x}}\left( q \right)+{{I}_{y}}\left( q \right)}\]
因为要求cost,所以反向并归一化:
\[{{f}_{G}}\left( q \right)=1-\frac{{{I}_{G}}\left( q \right)}{\max \left( {{I}_{G}} \right)}\]
梯度方向
这一项其实是个平滑项,会给变化剧烈的边缘赋一个比较高的cost,让提取的边缘避免噪声的影响。具体公式如下:
\[{{f}_{D}}\left( p,q \right)=\frac{2}{3\pi }\left( \arccos \left( {{d}_{p}}\left( p,q \right) \right)+\arccos \left( {{d}_{q}}\left( p,q \right) \right) \right)\]
其中,
\[{{d}_{p}}\left( p,q \right)=\left\langle {{d}_{\bot }}\left( p \right),{{l}_{D}}\left( p,q \right) \right\rangle \]
\[{{d}_{q}}\left( p,q \right)=\left\langle {{l}_{D}}\left( p,q \right),{{d}_{\bot }}\left( q \right) \right\rangle \]
\[{{l}_{D}}\left( p,q \right)=\left\{ \begin{matrix}
q-p;\text{ if }\left\langle {{d}_{\bot }}\left( p \right),q-p \right\rangle \ge 0 \\
p-q;\text{ if }\left\langle {{d}_{\bot }}\left( p \right),q-p \right\rangle <0 \\
\end{matrix} \right.\]
\({{d}_{\bot }}\left( p \right)\)是取p的垂直方向,另外注意上式中符号的判断会将\({{d}_{\bot }}\left( p \right)\)和\({{l}_{D}}\left( p,q \right)\)的取值限制在π/2以内。
\[{{d}_{\bot }}\left( p \right)=\left( {{p}_{y}},-{{p}_{x}} \right)\]
斜对角方向的cost修正
在二维图像中,相邻的像素通常按照间隔欧式距离分为两种:1) 上下左右相邻,间隔为像素边长;2) 斜对角相邻,间隔为像素边长的\(\sqrt{2}\)倍。在计算局部cost的时候通常要把这种距离差异的影响考虑进去,比如下面这幅图:
2 | 3 | 4 |
5 | 6 | 6 |
7 | 8 | 9 |
如果不考虑像素位置的影响,那么查找最小cost的时候会认为左上角的cost=2最小。然而如果考虑到像素间距的影响,我们来看左上角方向,和中心的差异是6-2,做个线性插值的话,则左上角距中心单位距离上的值应该为\(6-\left( 6-2 \right)\times 1/\sqrt{2}\ =3.17>3\),所以正上方的才是最小cost的正确方向。
最短路径查找
在磁力套索中,一般的用法是先单击一个点,然后移动鼠标,在鼠标和一开始单击的点之间就会出现自动贴近边缘的线,这里我们定义一开始单击的像素点为种子点(seed),而磁力套索其实在考虑上部分提到的边缘相关cost的情况下查找种子点到当前鼠标的最短路径。如下图,红色的就是种子点,而移动鼠标时,最贴近边缘的种子点和鼠标坐标的连线就会实时显示,这也是为什么磁力套索也叫Livewire。
实现最短路径的办法很多,一般而言就是动态规划了,这里介绍的是基于Dijkstra算法的一种实现,基本思想是,给定种子点后,执行Dijkstra算法将图像的所有像素遍历,得到每个像素到种子点的最短路径。以下面这幅图为例,在一个cost矩阵中,利用Dijkstra算法遍历每一个元素后,每个元素都会指向一个相邻的元素,这样任意一个像素都能找到一条到seed的路径,比如右上角的42和39对应的像素,沿着箭头到了0。
算法如下:
输入: |
遍历的过程会优先经过cost最低的区域,如下图:
所有像素对应的到种子像素的最短路径都找到后,移动鼠标时就直接画出到seed的最短路径就可以了。
Python实现
算法部分直接调用了OpenCV的Canny函数和Sobel函数(求梯度),对于RGB的处理也很简陋,直接用梯度最大的值来近似。另外因为懒,cost map和path map都直接用了字典(dict),而记录展开过的像素则直接采用了集合(set)。GUI部分因为不会用QThread所以用了Python的threading,只有图像显示交互区域和状态栏提示,左键点击设置种子点,右键结束,已经提取的边缘为绿色线,正在提取的为蓝色线。
代码
算法部分
1 from __future__ import division 2 import cv2 3 import numpy as np 4 5 SQRT_0_5 = 0.70710678118654757 6 7 class Livewire(): 8 """ 9 A simple livewire implementation for verification using 10 1. Canny edge detector + gradient magnitude + gradient direction 11 2. Dijkstra algorithm 12 """ 13 14 def __init__(self, image): 15 self.image = image 16 self.x_lim = image.shape[0] 17 self.y_lim = image.shape[1] 18 # The values in cost matrix ranges from 0~1 19 self.cost_edges = 1 - cv2.Canny(image, 85, 170)/255.0 20 self.grad_x, self.grad_y, self.grad_mag = self._get_grad(image) 21 self.cost_grad_mag = 1 - self.grad_mag/np.max(self.grad_mag) 22 # Weight for (Canny edges, gradient magnitude, gradient direction) 23 self.weight = (0.425, 0.425, 0.15) 24 25 self.n_pixs = self.x_lim * self.y_lim 26 self.n_processed = 0 27 28 @classmethod 29 def _get_grad(cls, image): 30 """ 31 Return the gradient magnitude of the image using Sobel operator 32 """ 33 rgb = True if len(image.shape) > 2 else False 34 grad_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3) 35 grad_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3) 36 if rgb: 37 # A very rough approximation for quick verification... 38 grad_x = np.max(grad_x, axis=2) 39 grad_y = np.max(grad_y, axis=2) 40 41 grad_mag = np.sqrt(grad_x**2+grad_y**2) 42 grad_x /= grad_mag 43 grad_y /= grad_mag 44 45 return grad_x, grad_y, grad_mag 46 47 def _get_neighbors(self, p): 48 """ 49 Return 8 neighbors around the pixel p 50 """ 51 x, y = p 52 x0 = 0 if x == 0 else x - 1 53 x1 = self.x_lim if x == self.x_lim - 1 else x + 2 54 y0 = 0 if y == 0 else y - 1 55 y1 = self.y_lim if y == self.y_lim - 1 else y + 2 56 57 return [(x, y) for x in xrange(x0, x1) for y in xrange(y0, y1) if (x, y) != p] 58 59 def _get_grad_direction_cost(self, p, q): 60 """ 61 Calculate the gradient changes refer to the link direction 62 """ 63 dp = (self.grad_y[p[0]][p[1]], -self.grad_x[p[0]][p[1]]) 64 dq = (self.grad_y[q[0]][q[1]], -self.grad_x[q[0]][q[1]]) 65 66 l = np.array([q[0]-p[0], q[1]-p[1]], np.float) 67 if 0 not in l: 68 l *= SQRT_0_5 69 70 dp_l = np.dot(dp, l) 71 l_dq = np.dot(l, dq) 72 if dp_l < 0: 73 dp_l = -dp_l 74 l_dq = -l_dq 75 76 # 2/3pi * ... 77 return 0.212206590789 * (np.arccos(dp_l)+np.arccos(l_dq)) 78 79 def _local_cost(self, p, q): 80 """ 81 1. Calculate the Canny edges & gradient magnitude cost taking into account Euclidean distance 82 2. Combine with gradient direction 83 Assumption: p & q are neighbors 84 """ 85 diagnol = q[0] == p[0] or q[1] == p[1] 86 87 # c0, c1 and c2 are costs from Canny operator, gradient magnitude and gradient direction respectively 88 if diagnol: 89 c0 = self.cost_edges[p[0]][p[1]]-SQRT_0_5*(self.cost_edges[p[0]][p[1]]-self.cost_edges[q[0]][q[1]]) 90 c1 = self.cost_grad_mag[p[0]][p[1]]-SQRT_0_5*(self.cost_grad_mag[p[0]][p[1]]-self.cost_grad_mag[q[0]][q[1]]) 91 c2 = SQRT_0_5 * self._get_grad_direction_cost(p, q) 92 else: 93 c0 = self.cost_edges[q[0]][q[1]] 94 c1 = self.cost_grad_mag[q[0]][q[1]] 95 c2 = self._get_grad_direction_cost(p, q) 96 97 if np.isnan(c2): 98 c2 = 0.0 99 100 w0, w1, w2 = self.weight 101 cost_pq = w0*c0 + w1*c1 + w2*c2 102 103 return cost_pq * cost_pq 104 105 def get_path_matrix(self, seed): 106 """ 107 Get the back tracking matrix of the whole image from the cost matrix 108 """ 109 neighbors = [] # 8 neighbors of the pixel being processed 110 processed = set() # Processed point 111 cost = {seed: 0.0} # Accumulated cost, initialized with seed to itself 112 paths = {} 113 114 self.n_processed = 0 115 116 while cost: 117 # Expand the minimum cost point 118 p = min(cost, key=cost.get) 119 neighbors = self._get_neighbors(p) 120 processed.add(p) 121 122 # Record accumulated costs and back tracking point for newly expanded points 123 for q in [x for x in neighbors if x not in processed]: 124 temp_cost = cost[p] + self._local_cost(p, q) 125 if q in cost: 126 if temp_cost < cost[q]: 127 cost.pop(q) 128 else: 129 cost[q] = temp_cost 130 processed.add(q) 131 paths[q] = p 132 133 # Pop traversed points 134 cost.pop(p) 135 136 self.n_processed += 1 137 138 return paths
GUI部分
1 from __future__ import division 2 import time 3 import cv2 4 from PyQt4 import QtGui, QtCore 5 from threading import Thread 6 from livewire import Livewire 7 8 class ImageWin(QtGui.QWidget): 9 def __init__(self): 10 super(ImageWin, self).__init__() 11 self.setupUi() 12 self.active = False 13 self.seed_enabled = True 14 self.seed = None 15 self.path_map = {} 16 self.path = [] 17 18 def setupUi(self): 19 self.hbox = QtGui.QVBoxLayout(self) 20 21 # Load and initialize image 22 self.image_path = ‘‘ 23 while self.image_path == ‘‘: 24 self.image_path = QtGui.QFileDialog.getOpenFileName(self, ‘‘, ‘‘, ‘(*.bmp *.jpg *.png)‘) 25 self.image = QtGui.QPixmap(self.image_path) 26 self.cv2_image = cv2.imread(str(self.image_path)) 27 self.lw = Livewire(self.cv2_image) 28 self.w, self.h = self.image.width(), self.image.height() 29 30 self.canvas = QtGui.QLabel(self) 31 self.canvas.setMouseTracking(True) 32 self.canvas.setPixmap(self.image) 33 34 self.status_bar = QtGui.QStatusBar(self) 35 self.status_bar.showMessage(‘Left click to set a seed‘) 36 37 self.hbox.addWidget(self.canvas) 38 self.hbox.addWidget(self.status_bar) 39 self.setLayout(self.hbox) 40 41 def mousePressEvent(self, event): 42 if self.seed_enabled: 43 pos = event.pos() 44 x, y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y() 45 46 if x < 0: 47 x = 0 48 if x >= self.w: 49 x = self.w - 1 50 if y < 0: 51 y = 0 52 if y >= self.h: 53 y = self.h - 1 54 55 # Get the mouse cursor position 56 p = y, x 57 seed = self.seed 58 59 # Export bitmap 60 if event.buttons() == QtCore.Qt.MidButton: 61 filepath = QtGui.QFileDialog.getSaveFileName(self, ‘Save image audio to‘, ‘‘, ‘*.bmp\n*.jpg\n*.png‘) 62 image = self.image.copy() 63 64 draw = QtGui.QPainter() 65 draw.begin(image) 66 draw.setPen(QtCore.Qt.blue) 67 if self.path_map: 68 while p != seed: 69 draw.drawPoint(p[1], p[0]) 70 for q in self.lw._get_neighbors(p): 71 draw.drawPoint(q[1], q[0]) 72 p = self.path_map[p] 73 if self.path: 74 draw.setPen(QtCore.Qt.green) 75 for p in self.path: 76 draw.drawPoint(p[1], p[0]) 77 for q in self.lw._get_neighbors(p): 78 draw.drawPoint(q[1], q[0]) 79 draw.end() 80 81 image.save(filepath, quality=100) 82 83 else: 84 self.seed = p 85 86 if self.path_map: 87 while p != seed: 88 p = self.path_map[p] 89 self.path.append(p) 90 91 # Calculate path map 92 if event.buttons() == QtCore.Qt.LeftButton: 93 Thread(target=self._cal_path_matrix).start() 94 Thread(target=self._update_path_map_progress).start() 95 96 # Finish current task and reset 97 elif event.buttons() == QtCore.Qt.RightButton: 98 self.path_map = {} 99 self.status_bar.showMessage(‘Left click to set a seed‘) 100 self.active = False 101 102 def mouseMoveEvent(self, event): 103 if self.active and event.buttons() == QtCore.Qt.NoButton: 104 pos = event.pos() 105 x, y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y() 106 107 if x < 0 or x >= self.w or y < 0 or y >= self.h: 108 pass 109 else: 110 # Draw livewire 111 p = y, x 112 path = [] 113 while p != self.seed: 114 p = self.path_map[p] 115 path.append(p) 116 117 image = self.image.copy() 118 draw = QtGui.QPainter() 119 draw.begin(image) 120 draw.setPen(QtCore.Qt.blue) 121 for p in path: 122 draw.drawPoint(p[1], p[0]) 123 if self.path: 124 draw.setPen(QtCore.Qt.green) 125 for p in self.path: 126 draw.drawPoint(p[1], p[0]) 127 draw.end() 128 self.canvas.setPixmap(image) 129 130 def _cal_path_matrix(self): 131 self.seed_enabled = False 132 self.active = False 133 self.status_bar.showMessage(‘Calculating path map...‘) 134 path_matrix = self.lw.get_path_matrix(self.seed) 135 self.status_bar.showMessage(r‘Left: new seed / Right: finish‘) 136 self.seed_enabled = True 137 self.active = True 138 139 self.path_map = path_matrix 140 141 def _update_path_map_progress(self): 142 while not self.seed_enabled: 143 time.sleep(0.1) 144 message = ‘Calculating path map... {:.1f}%‘.format(self.lw.n_processed/self.lw.n_pixs*100.0) 145 self.status_bar.showMessage(message) 146 self.status_bar.showMessage(r‘Left: new seed / Right: finish‘)
主函数
1 import sys 2 from PyQt4 import QtGui 3 from gui import ImageWin 4 5 def main(): 6 app = QtGui.QApplication(sys.argv) 7 window = ImageWin() 8 window.setMouseTracking(True) 9 window.setWindowTitle(‘Livewire Demo‘) 10 window.show() 11 window.setFixedSize(window.size()) 12 sys.exit(app.exec_()) 13 14 if __name__ == ‘__main__‘: 15 main()
蛋疼地上传到了Github(传送门),欢迎fork。
效率的改进
因为这个代码的原型只是为了用C++开发之前的Python评估和验证,所以完全没考虑效率,执行速度是完全不行的,基本上400x400的图片就不能忍了……至于基于Python版本的效率提升我没有仔细想过,只是大概看来有这么几个比较明显的地方:
1) 取出当前最小cost像素操作
p = min(cost, key=cost.get)
这个虽然写起来很爽但显然是不行的,至少得用个min heap什么的。因为我是用dict同时表示待处理像素和cost了,也懒得想一下怎么和Python的heapq结合起来,所以直接用了粗暴省事的min()。
2) 梯度方向的计算
三角函数的计算应该是尽量避免的,另外原文可能是为了将值域扩展到>π所以把q-p也用上了,其实这一项本来权重就小,那怕直接用两个像素各自的梯度方向向量做点积然后归一化一下结果也是还行的。即使要用arccos,也可以考虑写个look-up table近似。当然我最后想说的是个人觉得其实这项真没那么必要,直接自适应spilne或者那怕三点均值平滑去噪效果就不错了。
3) 计算相邻像素的位置
如果两个像素相邻,则他们各自周围的8个相邻像素也会重合。的我的办法比较原始,可以考率不用模块化直接计算。
4) 替换部分数据结构
比如path map其实本质是给出每个像素在最短路径上的上一个像素,是个矩阵。其实可以考虑用线性的数据结构代替,不过如果真这样做一般来说都是在C/C++代码里了。
5) numpy
我印象中对numpy的调用顺序也会影响到效率,连续调用numpy的内置方法似乎会带来效率的整体提升,不过话还是说回来,实际应用中如果到了这一步,应该也属于C/C++代码范畴了。
6) 算法层面的改进
这块没有深入研究,第一感觉是实际应用中没必要一上来就计算整幅图像,可以根据seed位置做一些区块划分,鼠标本身也会留下轨迹,也或许可以考虑只在鼠标轨迹方向进行启发式搜索。另外计算路径的时候也许可以考虑借鉴有点类似于Image Pyramid的思想,没必要一上来就对全分辨率下的路径进行查找。由于后来做的项目没有采用这个算法,所以我也没有继续研究,虽然挺好奇的,其实有好多现成的代码,比如GIMP,不过没有精力去看了。
更多的改进
虽然都没做,大概介绍一下,都是考虑了实用性的改进。
路径冷却(Path Cooling)
用过Photoshop和GIMP磁力套索的人都知道,即使鼠标不点击图片,在移动过程中也会自动生成一些将抠图轨迹固定住的点,这些点其实就是新的种子点,而这种使用过程中自动生成新的种子点的方法叫Path cooling。这个方法的基本思路如下:随着鼠标移动过程中如果一定时间内一段路径都保持固定不变,那么就把这段路径中离种子最远的点设置为新的种子,其实背后隐藏的还是动态规划的思想,贝尔曼最优。这个名字也是比较形象的,路径冷却。
动态训练(Interactive Dynamic Training)
单纯的最短路径查找在使用的时候常常出现找到的边缘不是想要的边缘的问题,比如上图,绿色的线是上一段提取的边缘,蓝色的是当前正在提取的边缘。左图中,镜子外面Lena的帽子边缘是我们想要提取的,然而由于镜子里的Lena的帽子边缘的cost更低,所以实际提取出的蓝色线段如右图中贴到右边了。所以Interactive Dynamic Training的思想是,认为绿色的线段是正确提取的边缘,然后利用绿色线段作为训练数据来给当前提取边缘的cost函数附加一个修正值。
[1]中采用的方法是统计前一段边缘上点的梯度强度的直方图,然后按照梯度出现频率给当前图中的像素加权。举例来说如果绿色线段中的所有像素对应的梯度强度都是在50到100之间的话,那么可以将50到100以10为单位分为5个bin,统计每个bin里的出现频率,也就是直方图,然后对当前检测到的梯度强度,做个线性加权。比方说50~60区间内对应的像素最多有10个,那么把10作为最大值,并且对当前检测到的梯度强度处于50~60之间的像素均乘上系数1.0;如果训练数据中70~80之间有5个,那么cost加权系数为5/10=0.5,则对所有当前检测到的梯度强度处于70~80之间的像素均乘上系数0.5;如果训练数据中100以上没有,所以cost附加为0/10=0,则加权系数为0,这样即使检测到更强的边缘也不会偏离前一段边缘了。这是基本思想,当然实际的实现没有这么简单,除了边缘上的像素还要考虑垂直边缘上左边和右边的两个像素点,这样保证了边缘的pattern。另外随着鼠标越来越远离训练边缘,检测到的边缘的pattern可能会出现不一样,所以Training可能会起反作用,所以这种Training的作用范围也需要考虑到鼠标离种子点的距离,最后还要有一些平滑去噪的处理,具体都在[1]里有讲到,挺繁琐的(那会好像还没有SIFT),不详述了。
种子点位置的修正(Cursor Snap)
虽然这个算法可以自动找出种子点和鼠标之间最贴近边缘的路径,不过,人的手,常常抖,所以种子点未必能很好地设置到边缘上。所以可以在用户设置完种子点位置之后,自动在其坐标周围小范围内,比如7x7的区域内搜索cost最低的像素,作为真正的种子点位置,这个过程叫做Cursor snap。
参考文献:
[1] Mortensen, Eric N., and William A. Barrett. "Intelligent scissors for image composition." Proceedings of the 22nd annual conference on Computer graphics and interactive techniques. ACM, 1995.