Skip to content

Commit

Permalink
use rasterio AWS session when reading files on s3 (#2197)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH authored Jul 26, 2024
1 parent a25fa34 commit 8ab5837
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 25 deletions.
23 changes: 12 additions & 11 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from urllib.parse import urlparse

import boto3
from everett.manager import ConfigurationMissingError
from tqdm.auto import tqdm

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,
Expand Down Expand Up @@ -102,16 +101,18 @@ class S3FileSystem(FileSystem):
"""

@staticmethod
def get_request_payer():
# Import here to avoid circular reference.
from rastervision.pipeline import rv_config_ as rv_config
try:
s3_config = rv_config.get_namespace_config(AWS_S3)
# 'None' needs the quotes because boto3 cannot handle None.
return ('requester' if s3_config(
'requester_pays', parser=bool, default='False') else 'None')
except ConfigurationMissingError:
return 'None'
def get_request_payer() -> str:
    """Determine the request-payer value to use for S3 requests.

    The ``AWS_REQUEST_PAYER`` environment variable is consulted first. If it
    is unset (or set to the string ``'None'``), the ``requester_pays`` option
    in the ``AWS_S3`` namespace of the RV config is consulted instead.

    Returns:
        str: The value of ``AWS_REQUEST_PAYER`` if set to something other
        than ``'None'``; otherwise ``'requester'`` if the RV config option is
        truthy; otherwise the string ``'None'`` (a string, not ``None``).
    """
    env_value = os.getenv('AWS_REQUEST_PAYER', 'None')
    if env_value != 'None':
        # The environment variable takes precedence over the RV config.
        return env_value

    # Import here to avoid circular reference.
    from rastervision.pipeline import rv_config_ as rv_config
    if rv_config.get_namespace_option(AWS_S3, 'requester_pays', as_bool=True):
        return 'requester'
    return 'None'

@staticmethod
def get_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

import numpy as np
import rasterio
import rasterio as rio

from rastervision.pipeline.file_system import download_if_needed, get_tmp_dir
from rastervision.core.box import Box
Expand All @@ -11,8 +11,8 @@
from rastervision.core.data.utils import (listify_uris, parse_array_slices_Nd)
from rastervision.core.data.utils.raster import fill_overflow
from rastervision.core.data.utils.rasterio import (
read_window, get_channel_order_from_dataset, download_and_build_vrt,
is_masked)
download_and_build_vrt, get_aws_session, get_channel_order_from_dataset,
is_masked, read_window)

if TYPE_CHECKING:
from rastervision.core.data import RasterTransformer
Expand Down Expand Up @@ -76,7 +76,11 @@ def __init__(self,

self.imagery_path = self.download_data(
self.tmp_dir, stream=self.allow_streaming)
self.image_dataset = rasterio.open(self.imagery_path)
self.session = None
if 's3://' in self.imagery_path.lower():
self.session = get_aws_session()
with rio.Env(session=self.session):
self.image_dataset = rio.open(self.imagery_path)

block_shapes = set(self.image_dataset.block_shapes)
if len(block_shapes) > 1:
Expand Down Expand Up @@ -165,7 +169,8 @@ def _get_chip(self,
bands=bands,
window=window.rasterio_format(),
is_masked=self.is_masked,
out_shape=out_shape)
out_shape=out_shape,
session=self.session)
chip = fill_overflow(self.bbox, window, chip)
return chip

Expand Down
33 changes: 25 additions & 8 deletions rastervision_core/rastervision/core/data/utils/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import rasterio.windows as rio_windows
from rasterio.transform import from_origin
from rasterio.enums import (ColorInterp, MaskFlags, Resampling)
from rasterio.session import AWSSession

from rastervision.pipeline.file_system.utils import (
file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy,
Expand All @@ -16,6 +17,7 @@

if TYPE_CHECKING:
from rasterio.io import DatasetReader
from rasterio.session import Session

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -168,7 +170,8 @@ def read_window(dataset: 'DatasetReader',
bands: int | Sequence[int] | None = None,
window: tuple[tuple[int, int], tuple[int, int]] | None = None,
is_masked: bool = False,
out_shape: tuple[int, ...] | None = None) -> np.ndarray:
out_shape: tuple[int, ...] | None = None,
session: 'Session | None' = None) -> np.ndarray:
"""Load a window of an image using Rasterio.
Args:
Expand All @@ -181,19 +184,21 @@ def read_window(dataset: 'DatasetReader',
Defaults to ``False``.
out_shape: (height, width) of the output chip. If ``None``, no
resizing is done. Defaults to ``None``.
session: Rasterio :class:`.Session`.
Returns:
np.ndarray: array of shape (height, width, channels).
"""
if bands is not None:
bands = tuple(bands)
im = dataset.read(
indexes=bands,
window=window,
boundless=True,
masked=is_masked,
out_shape=out_shape,
resampling=Resampling.bilinear)
with rio.Env(session=session):
im = dataset.read(
indexes=bands,
window=window,
boundless=True,
masked=is_masked,
out_shape=out_shape,
resampling=Resampling.bilinear)

if is_masked:
im = np.ma.filled(im, fill_value=0)
Expand Down Expand Up @@ -241,3 +246,15 @@ def is_masked(dataset: 'DatasetReader') -> bool:
mask_flags = dataset.mask_flag_enums
is_masked = any(m for m in mask_flags if m != MaskFlags.all_valid)
return is_masked


def get_aws_session() -> 'Session':
    """Build a rasterio AWS session from environment variables.

    The requester-pays setting is taken from
    ``S3FileSystem.get_request_payer()`` when the ``rastervision.aws_s3``
    plugin can be imported, and from the ``AWS_REQUEST_PAYER`` environment
    variable otherwise. In both cases requester-pays is enabled only if the
    value equals ``'requester'`` (case-insensitive).

    Returns:
        Session: Session created via ``AWSSession.from_environ()``.
    """
    try:
        from rastervision.aws_s3 import S3FileSystem
        request_payer = S3FileSystem.get_request_payer()
    except ModuleNotFoundError:
        request_payer = os.getenv('AWS_REQUEST_PAYER', '')
    # get_request_payer() signals "not requester-pays" with the non-empty
    # string 'None', which is truthy; compare explicitly so that a real bool
    # is handed to the session instead of an always-truthy string.
    requester_pays = request_payer.lower() == 'requester'
    session = AWSSession.from_environ(requester_pays=requester_pays)
    return session
15 changes: 15 additions & 0 deletions tests/aws_s3/test_s3_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ def test_get_client_unsigned(self):
s3 = S3FileSystem.get_client()
self.assertEqual(s3._client_config.signature_version, UNSIGNED)

@patch.dict('os.environ', AWS_REQUEST_PAYER='requester')
def test_get_request_payer_env(self):
    # With AWS_REQUEST_PAYER set in the environment, that value is expected
    # back from get_request_payer().
    self.assertEqual(S3FileSystem.get_request_payer(), 'requester')

@patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='yes')
def test_get_request_payer_rvconfig_env_true(self):
    # AWS_S3_REQUESTER_PAYS is presumably how the RV config option
    # `requester_pays` is supplied via the environment; a truthy value
    # should yield 'requester'.
    self.assertEqual(S3FileSystem.get_request_payer(), 'requester')

@patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='false')
def test_get_request_payer_rvconfig_env_false(self):
    # A falsy config value should yield the string 'None' (not None).
    self.assertEqual(S3FileSystem.get_request_payer(), 'None')


if __name__ == '__main__':
unittest.main()
14 changes: 13 additions & 1 deletion tests/core/data/utils/test_rasterio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch
from os.path import join

import numpy as np
Expand All @@ -7,7 +8,7 @@
from rastervision.pipeline.file_system.utils import get_tmp_dir
from rastervision.core.box import Box
from rastervision.core.data.utils.rasterio import (
crop_geotiff, write_geotiff_like_geojson, write_bbox)
crop_geotiff, get_aws_session, write_geotiff_like_geojson, write_bbox)
from rastervision.core.data import RasterioSource, GeoJSONVectorSource
from tests import data_file_path

Expand Down Expand Up @@ -63,6 +64,17 @@ def test_write_geotiff_like_geojson(self):
decimal=3)
self.assertEqual(rs.shape, (10, 10, 1))

@patch.dict('os.environ', AWS_REQUEST_PAYER='requester')
def test_get_aws_session(self):
    # AWS_REQUEST_PAYER=requester should produce a requester-pays session.
    self.assertTrue(get_aws_session().requester_pays)

@patch.dict('sys.modules', {'rastervision.aws_s3': None})
@patch.dict('os.environ', AWS_S3_REQUESTER_PAYS='false')
def test_get_aws_session_no_rv_aws_s3_and_not_requester_pays(self):
    # A None entry in sys.modules makes importing rastervision.aws_s3 fail,
    # which simulates the plugin not being installed.
    self.assertFalse(get_aws_session().requester_pays)
self.assertFalse(session.requester_pays)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8ab5837

Please sign in to comment.