diff --git a/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py index 675d0f5..e247f5f 100644 --- a/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py @@ -1121,10 +1121,10 @@ def unit_positions_for_structure_plot(self) -> Dict[Unit, Tuple[int, int]]: positions = {} for depth, layer in enumerate(layers): number_of_nodes = len(layer) - positions_in_layer = np.linspace(0, maximum_layer_width, number_of_nodes, endpoint=False) + positions_in_layer = np.linspace(0., maximum_layer_width, number_of_nodes, endpoint=False) positions_in_layer += (maximum_layer_width - len(layer)) / (2 * len(layer)) for position, node in zip(positions_in_layer, layer): - positions[node] = (depth, position) + positions[node] = (float(depth), position) return positions @@ -1184,6 +1184,9 @@ def plot_structure(self, node_colors: Optional[Dict[Unit, str]] = None, node_siz # and make the Spines Visibility as False for pos in ['right', 'top', 'bottom', 'left']: plt.gca().spines[pos].set_visible(False) + xticks, xticklabels = plt.xticks() + xmin = (3 * xticks[0] - xticks[1]) / 2. + plt.xlim(xmin, max([x for x, _ in positions.values()]) + 1)