@@ -188,20 +188,24 @@ def view(self):
188
188
self .root .view ()
189
189
190
190
@CvDBaseTiming .timeit (level = 2 , prefix = "[API] " )
191
- def visualize (self , radius = 24 , width = 1200 , height = 800 , padding = 0.2 , plot_num = 30 , title = "CvDTree" ):
191
+ def visualize (self , radius = 24 , width = 1200 , height = 800 ,
192
+ height_padding_ratio = 0.2 , width_padding = 30 , title = "CvDTree" ):
192
193
self ._update_layers ()
193
- units = [len (layer ) for layer in self .layers ]
194
+ n_units = [len (layer ) for layer in self .layers ]
194
195
195
196
img = np .ones ((height , width , 3 ), np .uint8 ) * 255
196
- axis0_padding = int (height / (len (self .layers ) - 1 + 2 * padding )) * padding + plot_num
197
- axis0 = np .linspace (
198
- axis0_padding , height - axis0_padding , len (self .layers ), dtype = np .int )
199
- axis1_padding = plot_num
200
- axis1 = [np .linspace (axis1_padding , width - axis1_padding , unit + 2 , dtype = np .int )
201
- for unit in units ]
202
- axis1 = [axis [1 :- 1 ] for axis in axis1 ]
203
-
204
- for i , (y , xs ) in enumerate (zip (axis0 , axis1 )):
197
+ height_padding = int (
198
+ height / (len (self .layers ) - 1 + 2 * height_padding_ratio )
199
+ ) * height_padding_ratio + width_padding
200
+ height_axis = np .linspace (
201
+ height_padding , height - height_padding , len (self .layers ), dtype = np .int )
202
+ width_axis = [
203
+ np .linspace (width_padding , width - width_padding , unit + 2 , dtype = np .int )
204
+ for unit in n_units
205
+ ]
206
+ width_axis = [axis [1 :- 1 ] for axis in width_axis ]
207
+
208
+ for i , (y , xs ) in enumerate (zip (height_axis , width_axis )):
205
209
for j , x in enumerate (xs ):
206
210
if i == 0 :
207
211
cv2 .circle (img , (x , y ), radius , (225 , 100 , 125 ), 1 )
@@ -216,13 +220,13 @@ def visualize(self, radius=24, width=1200, height=800, padding=0.2, plot_num=30,
216
220
color = (0 , 255 , 0 )
217
221
cv2 .putText (img , text , (x - 7 * len (text )+ 2 , y + 3 ), cv2 .LINE_AA , 0.6 , color , 1 )
218
222
219
- for i , y in enumerate (axis0 ):
220
- if i == len (axis0 ) - 1 :
223
+ for i , y in enumerate (height_axis ):
224
+ if i == len (height_axis ) - 1 :
221
225
break
222
- for j , x in enumerate (axis1 [i ]):
223
- new_y = axis0 [i + 1 ]
226
+ for j , x in enumerate (width_axis [i ]):
227
+ new_y = height_axis [i + 1 ]
224
228
dy = new_y - y - 2 * radius
225
- for k , new_x in enumerate (axis1 [i + 1 ]):
229
+ for k , new_x in enumerate (width_axis [i + 1 ]):
226
230
dx = new_x - x
227
231
length = np .sqrt (dx ** 2 + dy ** 2 )
228
232
ratio = 0.5 - min (0.4 , 1.2 * 24 / length )
0 commit comments