diff --git a/contextily/plotting.py b/contextily/plotting.py index b5e62a50..b316b1e6 100644 --- a/contextily/plotting.py +++ b/contextily/plotting.py @@ -12,7 +12,7 @@ def add_basemap(ax, zoom=ZOOM, url=sources.ST_TERRAIN, interpolation=INTERPOLATION, attribution = ATTRIBUTION, - **extra_imshow_args): + reset_extent=True, **extra_imshow_args): """ Add a (web/local) basemap to `ax` ... @@ -38,6 +38,10 @@ def add_basemap(ax, zoom=ZOOM, url=sources.ST_TERRAIN, attribution : str [Optional. Defaults to standard `ATTRIBUTION`] Text to be added at the bottom of the axis. + reset_extent : Boolean + [Optiona. Default=True] If True, the extent of the + basemap added is reset to the original extent (xlim, + ylim) of `ax` **extra_imshow_args : dict Other parameters to be passed to `imshow`. @@ -68,11 +72,12 @@ def add_basemap(ax, zoom=ZOOM, url=sources.ST_TERRAIN, >>> plt.show() """ + xmin, xmax, ymin, ymax = ax.axis() + # If web source if url[:4] == 'http': # Extent - left, right = ax.get_xlim() - bottom, top = ax.get_ylim() + left, right, bottom, top = xmin, xmax, ymin, ymax # Zoom if isinstance(zoom, str) and (zoom.lower() == 'auto'): min_ll = _sm2ll(left, bottom) @@ -92,8 +97,13 @@ def add_basemap(ax, zoom=ZOOM, url=sources.ST_TERRAIN, # Plotting ax.imshow(image, extent=extent, interpolation=interpolation, **extra_imshow_args) + + if reset_extent is True: + ax.axis((xmin, xmax, ymin, ymax)) + if attribution: add_attribution(ax, attribution) + return ax def add_attribution(ax, att=ATTRIBUTION): diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 639d30a7..ee588591 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -154,11 +154,13 @@ def test_add_basemap(): ax.set_ylim(y1, y2) ax = ctx.add_basemap(ax, zoom=10) - ax_extent = (-11740727.544603072, -11662456.027639052, - 4852834.0517692715, 4891969.810251278) - assert_array_almost_equal(ax_extent, ax.images[0].get_extent()) - assert ax.images[0].get_array().sum() == 75853866 - assert ax.images[0].get_array().shape == (256, 512, 3) + ax_extent = (x1, x2, y1, y2) + + # ensure add_basemap did not change the axis limits of ax + assert ax.axis() == ax_extent + + assert ax.images[0].get_array().sum() == 75687792 + assert ax.images[0].get_array().shape == (256, 511, 3) assert_array_almost_equal(ax.images[0].get_array().mean(), 192.90635681152344)