graph cut算法

本文最后更新于 2022年3月30日 晚上

最大流最小割问题

可以参考这篇博客

graph cut思想

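能量函数

$$E(L) = \lambda R(L) + B(L),\qquad R(L) = \sum_{p \in \mathcal{P}} R_p(L_p),\qquad B(L) = \sum_{\{p,q\} \in \mathcal{N}} B_{p,q}\,\delta(L_p \neq L_q)$$

其中前一项 $R(L)$ 是区域项,后一项 $B(L)$ 是边界项,$\lambda \ge 0$ 用来平衡两者的相对重要性。优化的目的是找到使能量最低的标注 $L$。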

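首先我们需要一些种子点,标注它是前景还是背景,然后根据前景种子点和背景种子点分别构建颜色概率直方图。区域项定义为 $R_p(\text{fg}) = -\ln \Pr(I_p \mid \text{fg})$,$R_p(\text{bg}) = -\ln \Pr(I_p \mid \text{bg})$,即 p 节点属于前景的区域代价,等于 p 节点的颜色在前景直方图中概率的负对数(背景同理)。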

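而边界项的计算公式为

$$B_{p,q} \propto \exp\!\left(-\frac{(I_p - I_q)^2}{2\sigma^2}\right)\cdot\frac{1}{\operatorname{dist}(p,q)}$$

也就是说,随着颜色差异和距离的增大,$B_{p,q}$ 在减小。(下文代码中用 $1/(1+\lVert I_p - I_q\rVert^2)$ 作为近似。)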

graph cut的目的是区分出哪些像素是前景、哪些是背景,从而完成分割。因此我们可以采用最大流最小割的思想:将源点视为前景,将汇点视为背景,图像中的像素作为其他节点,通过求最小割把前景和背景分割开来。(注意:下面的代码实现中方向相反——背景种子点连到源点,前景种子点连到汇点,最终落在汇点一侧的像素被判为前景。)

定义完点之后还需要定义流量,流量分为端点到像素点之间的流量和像素点到像素点之间的流量:

  • 端点到像素点: 这一项对应前面公式中的R(L)。如果是前景种子点,则它到前景端点的流量为无穷大,到背景端点的流量为0;背景种子点反之。其他点到源点和汇点的流量按上面的区域项公式计算(下面的代码为简化起见把这些流量设为0,即尚未实现区域项)。
  • 像素点到像素点: 这一项是公式中的B(L),遵从上面公式

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os

import cv2
import numpy as np
import maxflow
import matplotlib.pyplot as plt
from medpy import metric
from PIL import Image, ImageDraw

# Mouse-interaction state shared by onMouseClick() and the __main__ loop.
left_mouse_down = False   # True while the left (foreground) button is held
right_mouse_down = False  # True while the right (background) button is held
foreground_index = 0      # counter of finished foreground strokes
background_index = 0      # counter of finished background strokes
foreground_lines = []     # foreground strokes, each a list of (x, y) points
background_lines = []     # background strokes, each a list of (x, y) points


class GraphMaker:
    """Interactive graph-cut segmenter built on PyMaxflow.

    Seeds (scribbles) are collected with add_seed(); create_graph() builds a
    4-connected pixel graph; cut_graph() runs max-flow/min-cut; save_image()
    and evaluate() write and score the result.

    Terminal convention used here: background seeds are tied to the SOURCE
    and foreground seeds to the SINK, so nodes that end up in the sink
    segment (get_segment(...) == 1) are reported as foreground.
    """

    foreground = 1
    background = 0
    segmented = 1
    default = 0.5            # graph value for pixels that carry no seed
    MAXIMUM = 1000000000     # stands in for an "infinite" t-link capacity

    def __init__(self, filename):
        self.image = None
        self.graph = None
        self.segment_overlay = None
        self.mask = None
        self.filename: str = filename
        self.load_image(filename)
        self.background_seeds = []   # list of (x, y) = (column, row)
        self.foreground_seeds = []   # list of (x, y) = (column, row)
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)
        self.nodes = []              # t-links: (node, source_cap, sink_cap)
        self.edges = []              # n-links: (node, neighbour, weight)

    def load_image(self, filename):
        """Load `filename` with OpenCV and reset all segmentation state.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (imread -> None).
        """
        self.filename = filename
        self.image = cv2.imread(filename)
        if self.image is None:
            raise FileNotFoundError("Could not read image: {}".format(filename))
        self.graph = None
        self.segment_overlay = np.zeros(self.image.shape[:2])
        self.mask = None

    def add_seed(self, x, y, type):
        """Record pixel (x, y) as a seed of the given type.

        x is the column and y the row (OpenCV mouse coordinates). Duplicate
        seeds and points outside the image are ignored.
        """
        if self.image is None:
            print('Please load an image before adding seeds.')
            return
        height, width = self.image.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            # Mouse callbacks may report coordinates outside the image (e.g.
            # while dragging past the window edge); indexing with them later
            # would raise or silently wrap around.
            return
        if type == self.background:
            seeds = self.background_seeds
        elif type == self.foreground:
            seeds = self.foreground_seeds
        else:
            return
        if (x, y) not in seeds:
            seeds.append((x, y))

    def create_graph(self):
        """Build the t-links and n-links from the current seeds.

        Does nothing (besides printing a hint) unless there is at least one
        foreground and one background seed.
        """
        if len(self.background_seeds) == 0 or len(self.foreground_seeds) == 0:
            print("Please enter at least one foreground and background seed.")
            return

        print("Making graph")
        print("Finding foreground and background averages")
        self.find_averages()

        print("Populating nodes and edges")
        self.populate_graph()

    def find_averages(self):
        """Mark seed pixels in self.graph and compute mean seed colours.

        self.graph becomes a (rows, cols) float array holding 0 at background
        seeds, 1 at foreground seeds and `default` everywhere else. The mean
        seed colours are stored but not yet used by populate_graph().
        """
        self.graph = np.full(self.image.shape[:2], self.default, dtype=float)
        print(self.graph.shape)
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)

        for x, y in self.background_seeds:
            # Seeds are (column, row); numpy indexing is [row, column].
            self.graph[y, x] = 0
            self.background_average += self.image[y, x]
        if self.background_seeds:
            self.background_average /= len(self.background_seeds)

        for x, y in self.foreground_seeds:
            self.graph[y, x] = 1
            self.foreground_average += self.image[y, x]
        if self.foreground_seeds:
            self.foreground_average /= len(self.foreground_seeds)

    def populate_graph(self):
        """Build self.nodes (t-links) and self.edges (n-links).

        nodes: (node_index, source_capacity, sink_capacity). Background seeds
        get an "infinite" source capacity, foreground seeds an "infinite" sink
        capacity, and every other pixel 0/0, i.e. no regional term R(L) is
        applied yet (a histogram-based -ln Pr term would go here).

        edges: (node_index, neighbour_index, weight) for every right and down
        neighbour pair, with weight 1 / (1 + ||I_p - I_q||^2).
        """
        self.nodes = []
        self.edges = []
        shape = self.image.shape
        for (y, x), value in np.ndenumerate(self.graph):
            index = self.get_node_num(x, y, shape)
            if value == self.background:
                self.nodes.append((index, self.MAXIMUM, 0))
            elif value == self.foreground:
                self.nodes.append((index, 0, self.MAXIMUM))
            else:
                self.nodes.append((index, 0, 0))

        # Work in float so colour differences cannot wrap around as uint8.
        image = self.image.astype(np.float64)
        rows, cols = self.graph.shape
        for y in range(rows):
            for x in range(cols):
                my_index = self.get_node_num(x, y, shape)
                # Checking each direction separately keeps the horizontal
                # edges of the last row and vertical edges of the last column.
                if x + 1 < cols:
                    self.edges.append((
                        my_index,
                        self.get_node_num(x + 1, y, shape),
                        self._boundary_weight(image[y, x], image[y, x + 1]),
                    ))
                if y + 1 < rows:
                    self.edges.append((
                        my_index,
                        self.get_node_num(x, y + 1, shape),
                        self._boundary_weight(image[y, x], image[y + 1, x]),
                    ))

    @staticmethod
    def _boundary_weight(p, q):
        """Return the n-link weight 1 / (1 + squared colour distance of p, q)."""
        diff = np.asarray(p, dtype=np.float64) - np.asarray(q, dtype=np.float64)
        return 1.0 / (1.0 + float(np.sum(diff * diff)))

    def cut_graph(self):
        """Run max-flow/min-cut on the graph built by create_graph().

        Pixels whose node lands in the sink segment (get_segment == 1) are set
        to 1 in segment_overlay and True in mask.
        """
        if not self.nodes:
            print("Please create the graph before cutting it.")
            return
        self.segment_overlay = np.zeros(self.image.shape[:2])
        self.mask = np.zeros_like(self.image, dtype=bool)
        g = maxflow.Graph[float](len(self.nodes), len(self.edges))
        nodelist = g.add_nodes(len(self.nodes))

        for index, to_source, to_sink in self.nodes:
            g.add_tedge(nodelist[index], to_source, to_sink)
        for p, q, weight in self.edges:
            # Same capacity in both directions: the n-link is undirected.
            g.add_edge(nodelist[p], nodelist[q], weight, weight)

        flow = g.maxflow()  # value of the minimum cut
        print("maximum flow is {}".format(flow))

        for index in range(len(self.nodes)):
            if g.get_segment(nodelist[index]) == 1:
                x, y = self.get_xy(index, self.image.shape)
                self.segment_overlay[y, x] = 1
                # Works for both colour (all channels) and grayscale masks.
                self.mask[y, x] = True

    def swap_overlay(self, overlay_num):
        self.current_overlay = overlay_num

    def save_image(self, outfilename):
        """Write the masked image and a scribble visualisation.

        The image restricted to the mask is written to `outfilename`. The
        original image with foreground strokes in (0, 0, 255) and background
        strokes in (255, 0, 0) is written to
        `<outfilename without extension>storke.jpg`.

        Returns:
            segment_overlay, or None if cut_graph() has not produced a mask.
        """
        if self.mask is None:
            print('Please segment the image before saving.')
            return
        print(outfilename)
        to_save = np.zeros_like(self.image)
        np.copyto(to_save, self.image, where=self.mask)
        cv2.imwrite(outfilename, to_save)

        save_stroke = self.image.copy()
        for foreground_line in foreground_lines:
            self.draw_polyline(save_stroke, foreground_line, (0, 0, 255))
        for background_line in background_lines:
            self.draw_polyline(save_stroke, background_line, (255, 0, 0))
        # NOTE: "storke" (sic) is kept so existing output file names don't change.
        cv2.imwrite(os.path.splitext(outfilename)[0] + "storke.jpg", save_stroke)
        return self.segment_overlay

    @staticmethod
    def evaluate(prediction_path, reference_path):
        """Print Dice, HD95, sensitivity, specificity and PPV of a prediction.

        Both files are read with cv2.imread and compared with medpy's binary
        metrics. The last printed value is the positive predictive value
        (precision), not accuracy.
        """
        reference = cv2.imread(reference_path)
        prediction = cv2.imread(prediction_path)
        if reference is None or prediction is None:
            print("Could not read {} or {}".format(prediction_path, reference_path))
            return
        dice = metric.binary.dc(prediction, reference)
        hd = metric.binary.hd95(prediction, reference)
        sensitivity = metric.binary.sensitivity(prediction, reference)
        specificity = metric.binary.specificity(prediction, reference)
        ppv = metric.positive_predictive_value(prediction, reference)
        print("{:.2f} {:.2f} {:.2f} {:.2f} {:.2f}".format(dice, hd, sensitivity, specificity, ppv))

    @staticmethod
    def get_node_num(x, y, array_shape):
        """Return the row-major node index of pixel (x=column, y=row)."""
        return y * array_shape[1] + x

    @staticmethod
    def get_xy(nodenum, array_shape):
        """Inverse of get_node_num: return (x=column, y=row) for a node index."""
        y, x = divmod(nodenum, array_shape[1])
        return x, y

    @staticmethod
    def draw_polyline(img, lines, color):
        """Draw consecutive points of `lines` onto `img` as 3px segments."""
        for start, end in zip(lines, lines[1:]):
            cv2.line(img, start, end, color=color, thickness=3)


def onMouseClick(event, x, y, flags, param):
    """OpenCV mouse callback that turns scribbles into graph-cut seeds.

    Left-drag paints foreground seeds and right-drag paints background seeds
    on `param`, the GraphMaker instance passed to cv2.setMouseCallback. The
    points of each stroke are also appended to foreground_lines /
    background_lines so that save_image() can draw them.
    """
    global left_mouse_down, right_mouse_down
    global foreground_index, background_index
    if event == cv2.EVENT_LBUTTONDOWN:
        # Start a new foreground stroke; the press point is a seed too, so a
        # click without dragging still registers.
        foreground_lines.append([(x, y)])
        param.add_seed(x, y, param.foreground)
        left_mouse_down = True
    elif event == cv2.EVENT_LBUTTONUP:
        # Count only strokes that were actually opened, so a stray button-up
        # (e.g. after pressing outside the window) cannot push the counter
        # past the end of foreground_lines.
        if left_mouse_down:
            foreground_index += 1
        left_mouse_down = False
    elif event == cv2.EVENT_RBUTTONDOWN:
        background_lines.append([(x, y)])
        param.add_seed(x, y, param.background)
        right_mouse_down = True
    elif event == cv2.EVENT_RBUTTONUP:
        if right_mouse_down:
            background_index += 1
        right_mouse_down = False
    elif event == cv2.EVENT_MOUSEMOVE:
        # Extend the most recent stroke instead of indexing by the counter.
        if left_mouse_down:
            param.add_seed(x, y, param.foreground)
            foreground_lines[-1].append((x, y))
        elif right_mouse_down:
            param.add_seed(x, y, param.background)
            background_lines[-1].append((x, y))


if __name__ == '__main__':
    # For every image under data/img: collect scribbles interactively, run the
    # graph cut, save the result to data/res and score it against data/mask.
    os.makedirs("data/res", exist_ok=True)  # cv2.imwrite fails silently otherwise
    for path, _dirs, file_list in os.walk("data/img"):
        for file in file_list:
            image_path = os.path.join(path, file)
            img = cv2.imread(image_path)
            if img is None:
                print("Skipping unreadable file: {}".format(image_path))
                continue
            foreground_index = 0
            background_index = 0
            foreground_lines.clear()
            background_lines.clear()
            marker = GraphMaker(image_path)
            cv2.imshow(file, img)
            cv2.setMouseCallback(file, onMouseClick, marker)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            marker.create_graph()
            if marker.graph is None:
                # create_graph() bailed out (missing seeds); nothing to cut.
                continue
            marker.cut_graph()
            result_path = os.path.join("data/res", file)
            marker.save_image(result_path)
            stem = os.path.splitext(file)[0]
            marker.evaluate(result_path, os.path.join("data/mask", stem + ".png"))


graph cut算法
https://www.xinhecuican.tech/post/c8d2fbab.html
作者
星河璀璨
发布于
2022年3月30日
许可协议