Skip to content

Commit 4ff8d58

Browse files
committed
tested ok with opencl
1 parent 91ae45f commit 4ff8d58

File tree

1 file changed

+68
-6
lines changed

1 file changed

+68
-6
lines changed

MTM/__init__.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ def findLocalMax(corrMap, score_threshold=0.6):
1313
'''
1414
Get coordinates of the local maximas with values above a threshold in the image of the correlation map
1515
'''
16+
# Get back an array if UMat provided
17+
if isinstance(corrMap, cv2.UMat): corrMap = corrMap.get()
1618

1719
# IF depending on the shape of the correlation map
1820
if corrMap.shape == (1,1): ## Template size = Image size -> Correlation map is a single digit')
@@ -64,7 +66,49 @@ def computeScoreMap(template, image, method=cv2.TM_CCOEFF_NORMED):
6466
return cv2.matchTemplate(template, image, method)
6567

6668

67-
def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, searchBox=None):
69+
def checkTypes(listTemplates, image, useOpencl=False):
70+
'''
71+
Check that the templates and image have the same bitDepthand 8 or 32-bit'''
72+
templatesType = list( set( [template[1].dtype for template in listTemplates] ) ) # get a list of unique template types
73+
74+
if (image.dtype =="float64") or ("float64" in templatesType):
75+
raise ValueError("64-bit not supported, max 32-bit")
76+
77+
all8 = image.dtype=="uint8" and templatesType==["uint8"]
78+
all32 = image.dtype=="float32" and templatesType==["float32"]
79+
80+
if all8 or all32:
81+
82+
if useOpencl:
83+
listTemplates = [ (template[0], cv2.UMat(template[1]) ) for template in listTemplates ]
84+
image = cv2.UMat(image)
85+
86+
else:
87+
pass # images are either all 8-bit or all 32-bit and no need to convert to UMat
88+
89+
90+
else:
91+
# Create a lambda function for conversion
92+
if useOpencl:
93+
convert32 = lambda array: cv2.UMat( np.float32(array) )
94+
else:
95+
convert32 = lambda array: cv2.UMat(array)
96+
97+
# convert to 32-bit + UMat if necessary
98+
listTemplates = [ (template[0], convert32(template[1]) ) for template in listTemplates ]
99+
image = convert32(image)
100+
101+
return listTemplates, image
102+
103+
104+
105+
def findMatches(listTemplates,
106+
image,
107+
method=cv2.TM_CCOEFF_NORMED,
108+
N_object=float("inf"),
109+
score_threshold=0.5,
110+
searchBox=None,
111+
useOpencl=False):
68112
'''
69113
Find all possible templates locations provided a list of template to search and an image
70114
Parameters
@@ -98,17 +142,20 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
98142
image = image[yOffset:yOffset+searchHeight, xOffset:xOffset+searchWidth]
99143
else:
100144
xOffset=yOffset=0
145+
146+
listTemplates, image = checkTypes(listTemplates, image, useOpencl) # also convert to UMat if using opencl
101147

102148
listHit = []
103149
for templateName, template in listTemplates:
104150

105151
#print('\nSearch with template : ',templateName)
106152

107-
corrMap = computeScoreMap(template, image, method)
153+
corrMap = cv2.matchTemplate(template, image, method) # automatically run with opencl if provided a UMat
108154

109155
## Find possible location of the object
110156
if N_object==1: # Detect global Min/Max
111157
minVal, maxVal, minLoc, maxLoc = cv2.minMaxLoc(corrMap)
158+
if isinstance(corrMap, cv2.UMat): corrMap = corrMap.get()
112159

113160
if method==1:
114161
Peaks = [minLoc[::-1]] # opposite sorting than in the multiple detection
@@ -117,7 +164,9 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
117164
Peaks = [maxLoc[::-1]]
118165

119166

120-
else:# Detect local max or min
167+
else: # Detect local max or min
168+
if isinstance(corrMap, cv2.UMat): corrMap = corrMap.get()
169+
121170
if method==1: # Difference => look for local minima
122171
Peaks = findLocalMin(corrMap, score_threshold)
123172

@@ -130,7 +179,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
130179

131180
# Once every peak was detected for this given template
132181
## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
133-
182+
if isinstance(template, cv2.UMat): template = template.get() # get back to array from UMat
134183
height, width = template.shape[0:2] # slicing make sure it works for RGB too
135184

136185
for peak in Peaks :
@@ -143,7 +192,14 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
143192
return pd.DataFrame(listHit) # All possible hits before Non-Maxima Supression
144193

145194

146-
def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, maxOverlap=0.25, searchBox=None):
195+
def matchTemplates(listTemplates,
196+
image,
197+
method=cv2.TM_CCOEFF_NORMED,
198+
N_object=float("inf"),
199+
score_threshold=0.5,
200+
maxOverlap=0.25,
201+
searchBox=None,
202+
useOpencl=False):
147203
'''
148204
Search each template in the image, and return the best N_object location which offer the best score and which do not overlap
149205
Parameters
@@ -174,7 +230,13 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f
174230
if maxOverlap<0 or maxOverlap>1:
175231
raise ValueError("Maximal overlap between bounding box is in range [0-1]")
176232

177-
tableHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)
233+
tableHit = findMatches(listTemplates,
234+
image,
235+
method,
236+
N_object,
237+
score_threshold,
238+
searchBox,
239+
useOpencl)
178240

179241
if method == 1: bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortAscending=True)
180242

0 commit comments

Comments
 (0)