Skip to content

Commit ef362be

Browse files
committed
Update c_CvDTree
1 parent 90874fc commit ef362be

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

c_CvDTree/TestTree.py

+1
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,6 @@ def main(visualize=True):
110110

111111
tree.show_timing_log()
112112

113+
113114
if __name__ == '__main__':
114115
main(False)

c_CvDTree/Tree.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -188,20 +188,24 @@ def view(self):
188188
self.root.view()
189189

190190
@CvDBaseTiming.timeit(level=2, prefix="[API] ")
191-
def visualize(self, radius=24, width=1200, height=800, padding=0.2, plot_num=30, title="CvDTree"):
191+
def visualize(self, radius=24, width=1200, height=800,
192+
height_padding_ratio=0.2, width_padding=30, title="CvDTree"):
192193
self._update_layers()
193-
units = [len(layer) for layer in self.layers]
194+
n_units = [len(layer) for layer in self.layers]
194195

195196
img = np.ones((height, width, 3), np.uint8) * 255
196-
axis0_padding = int(height / (len(self.layers) - 1 + 2 * padding)) * padding + plot_num
197-
axis0 = np.linspace(
198-
axis0_padding, height - axis0_padding, len(self.layers), dtype=np.int)
199-
axis1_padding = plot_num
200-
axis1 = [np.linspace(axis1_padding, width - axis1_padding, unit + 2, dtype=np.int)
201-
for unit in units]
202-
axis1 = [axis[1:-1] for axis in axis1]
203-
204-
for i, (y, xs) in enumerate(zip(axis0, axis1)):
197+
height_padding = int(
198+
height / (len(self.layers) - 1 + 2 * height_padding_ratio)
199+
) * height_padding_ratio + width_padding
200+
height_axis = np.linspace(
201+
height_padding, height - height_padding, len(self.layers), dtype=np.int)
202+
width_axis = [
203+
np.linspace(width_padding, width - width_padding, unit + 2, dtype=np.int)
204+
for unit in n_units
205+
]
206+
width_axis = [axis[1:-1] for axis in width_axis]
207+
208+
for i, (y, xs) in enumerate(zip(height_axis, width_axis)):
205209
for j, x in enumerate(xs):
206210
if i == 0:
207211
cv2.circle(img, (x, y), radius, (225, 100, 125), 1)
@@ -216,13 +220,13 @@ def visualize(self, radius=24, width=1200, height=800, padding=0.2, plot_num=30,
216220
color = (0, 255, 0)
217221
cv2.putText(img, text, (x-7*len(text)+2, y+3), cv2.LINE_AA, 0.6, color, 1)
218222

219-
for i, y in enumerate(axis0):
220-
if i == len(axis0) - 1:
223+
for i, y in enumerate(height_axis):
224+
if i == len(height_axis) - 1:
221225
break
222-
for j, x in enumerate(axis1[i]):
223-
new_y = axis0[i + 1]
226+
for j, x in enumerate(width_axis[i]):
227+
new_y = height_axis[i + 1]
224228
dy = new_y - y - 2 * radius
225-
for k, new_x in enumerate(axis1[i + 1]):
229+
for k, new_x in enumerate(width_axis[i + 1]):
226230
dx = new_x - x
227231
length = np.sqrt(dx**2+dy**2)
228232
ratio = 0.5 - min(0.4, 1.2 * 24/length)

0 commit comments

Comments
 (0)