diff --git a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py index 12cacfa6c..e45ac6c65 100644 --- a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py +++ b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py @@ -6,7 +6,6 @@ from urllib.parse import urlparse import boto3 -from everett.manager import ConfigurationMissingError from tqdm.auto import tqdm from rastervision.pipeline.file_system import (FileSystem, NotReadableError, @@ -102,16 +101,18 @@ class S3FileSystem(FileSystem): """ @staticmethod - def get_request_payer(): - # Import here to avoid circular reference. - from rastervision.pipeline import rv_config_ as rv_config - try: - s3_config = rv_config.get_namespace_config(AWS_S3) - # 'None' needs the quotes because boto3 cannot handle None. - return ('requester' if s3_config( - 'requester_pays', parser=bool, default='False') else 'None') - except ConfigurationMissingError: - return 'None' + def get_request_payer() -> str: + # attempt to get from environ + request_payer = os.getenv('AWS_REQUEST_PAYER', 'None') + # attempt to get from RV config + if request_payer == 'None': + # Import here to avoid circular reference. 
+ from rastervision.pipeline import rv_config_ as rv_config + requester_pays = rv_config.get_namespace_option( + AWS_S3, 'requester_pays', as_bool=True) + if requester_pays: + request_payer = 'requester' + return request_payer @staticmethod def get_session(): diff --git a/rastervision_core/rastervision/core/data/raster_source/rasterio_source.py b/rastervision_core/rastervision/core/data/raster_source/rasterio_source.py index 6c8d3fd29..a3ebb637e 100644 --- a/rastervision_core/rastervision/core/data/raster_source/rasterio_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/rasterio_source.py @@ -2,7 +2,7 @@ import logging import numpy as np -import rasterio +import rasterio as rio from rastervision.pipeline.file_system import download_if_needed, get_tmp_dir from rastervision.core.box import Box @@ -11,8 +11,8 @@ from rastervision.core.data.utils import (listify_uris, parse_array_slices_Nd) from rastervision.core.data.utils.raster import fill_overflow from rastervision.core.data.utils.rasterio import ( - read_window, get_channel_order_from_dataset, download_and_build_vrt, - is_masked) + download_and_build_vrt, get_aws_session, get_channel_order_from_dataset, + is_masked, read_window) if TYPE_CHECKING: from rastervision.core.data import RasterTransformer @@ -76,7 +76,11 @@ def __init__(self, self.imagery_path = self.download_data( self.tmp_dir, stream=self.allow_streaming) - self.image_dataset = rasterio.open(self.imagery_path) + self.session = None + if 's3://' in self.imagery_path.lower(): + self.session = get_aws_session() + with rio.Env(session=self.session): + self.image_dataset = rio.open(self.imagery_path) block_shapes = set(self.image_dataset.block_shapes) if len(block_shapes) > 1: @@ -165,7 +169,8 @@ def _get_chip(self, bands=bands, window=window.rasterio_format(), is_masked=self.is_masked, - out_shape=out_shape) + out_shape=out_shape, + session=self.session) chip = fill_overflow(self.bbox, window, chip) return chip diff --git 
a/rastervision_core/rastervision/core/data/utils/rasterio.py b/rastervision_core/rastervision/core/data/utils/rasterio.py index 00a3ac4df..4fb31f32a 100644 --- a/rastervision_core/rastervision/core/data/utils/rasterio.py +++ b/rastervision_core/rastervision/core/data/utils/rasterio.py @@ -8,6 +8,7 @@ import rasterio.windows as rio_windows from rasterio.transform import from_origin from rasterio.enums import (ColorInterp, MaskFlags, Resampling) +from rasterio.session import AWSSession from rastervision.pipeline.file_system.utils import ( file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy, @@ -16,6 +17,7 @@ if TYPE_CHECKING: from rasterio.io import DatasetReader + from rasterio.session import Session log = logging.getLogger(__name__) @@ -168,7 +170,8 @@ def read_window(dataset: 'DatasetReader', bands: int | Sequence[int] | None = None, window: tuple[tuple[int, int], tuple[int, int]] | None = None, is_masked: bool = False, - out_shape: tuple[int, ...] | None = None) -> np.ndarray: + out_shape: tuple[int, ...] | None = None, + session: 'Session | None' = None) -> np.ndarray: """Load a window of an image using Rasterio. Args: @@ -181,19 +184,21 @@ def read_window(dataset: 'DatasetReader', Defaults to ``False``. out_shape: (height, width) of the output chip. If ``None``, no resizing is done. Defaults to ``None``. + session: Rasterio :class:`.Session`. Returns: np.ndarray: array of shape (height, width, channels). 
""" if bands is not None: bands = tuple(bands) - im = dataset.read( - indexes=bands, - window=window, - boundless=True, - masked=is_masked, - out_shape=out_shape, - resampling=Resampling.bilinear) + with rio.Env(session=session): + im = dataset.read( + indexes=bands, + window=window, + boundless=True, + masked=is_masked, + out_shape=out_shape, + resampling=Resampling.bilinear) if is_masked: im = np.ma.filled(im, fill_value=0) @@ -241,3 +246,15 @@ def is_masked(dataset: 'DatasetReader') -> bool: mask_flags = dataset.mask_flag_enums is_masked = any(m for m in mask_flags if m != MaskFlags.all_valid) return is_masked + + +def get_aws_session() -> 'Session': + """Build a rasterio AWS session from environment variables.""" + try: + from rastervision.aws_s3 import S3FileSystem + requester_pays = S3FileSystem.get_request_payer() == 'requester' + except ModuleNotFoundError: + requester_pays = os.getenv('AWS_REQUEST_PAYER', + '').lower() == 'requester' + session = AWSSession.from_environ(requester_pays=requester_pays) + return session diff --git a/tests/aws_s3/test_s3_file_system.py b/tests/aws_s3/test_s3_file_system.py index d2e54283a..6970ead37 100644 --- a/tests/aws_s3/test_s3_file_system.py +++ b/tests/aws_s3/test_s3_file_system.py @@ -12,6 +12,21 @@ def test_get_client_unsigned(self): s3 = S3FileSystem.get_client() self.assertEqual(s3._client_config.signature_version, UNSIGNED) + @patch.dict('os.environ', AWS_REQUEST_PAYER='requester') + def test_get_request_payer_env(self): + request_payer = S3FileSystem.get_request_payer() + self.assertEqual(request_payer, 'requester') + + @patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='yes') + def test_get_request_payer_rvconfig_env_true(self): + request_payer = S3FileSystem.get_request_payer() + self.assertEqual(request_payer, 'requester') + + @patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='false') + def test_get_request_payer_rvconfig_env_false(self): + request_payer = S3FileSystem.get_request_payer() + self.assertEqual(request_payer, 'None') + 
if __name__ == '__main__': unittest.main() diff --git a/tests/core/data/utils/test_rasterio.py b/tests/core/data/utils/test_rasterio.py index 37e754eac..100514b95 100644 --- a/tests/core/data/utils/test_rasterio.py +++ b/tests/core/data/utils/test_rasterio.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch from os.path import join import numpy as np @@ -7,7 +8,7 @@ from rastervision.pipeline.file_system.utils import get_tmp_dir from rastervision.core.box import Box from rastervision.core.data.utils.rasterio import ( - crop_geotiff, write_geotiff_like_geojson, write_bbox) + crop_geotiff, get_aws_session, write_geotiff_like_geojson, write_bbox) from rastervision.core.data import RasterioSource, GeoJSONVectorSource from tests import data_file_path @@ -63,6 +64,17 @@ def test_write_geotiff_like_geojson(self): decimal=3) self.assertEqual(rs.shape, (10, 10, 1)) + @patch.dict('os.environ', AWS_REQUEST_PAYER='requester') + def test_get_aws_session(self): + session = get_aws_session() + self.assertTrue(session.requester_pays) + + @patch.dict('sys.modules', {'rastervision.aws_s3': None}) + @patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='false') + def test_get_aws_session_no_rv_aws_s3_and_not_requester_pays(self): + session = get_aws_session() + self.assertFalse(session.requester_pays) + if __name__ == '__main__': unittest.main()