Skip to content

Improvements for S3 environment variables #1343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 25 additions & 40 deletions tensorflow_io/core/plugins/s3/s3_filesystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
absl::MutexLock l(&cfg_lock);

if (!init) {
const char* endpoint = getenv("S3_ENDPOINT");
if (endpoint) cfg.endpointOverride = Aws::String(endpoint);
const char* region = getenv("AWS_REGION");
// TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
if (!region) region = getenv("S3_REGION");
Expand Down Expand Up @@ -168,20 +166,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
cfg.region = profiles["default"].GetRegion();
}
}
const char* use_https = getenv("S3_USE_HTTPS");
if (use_https) {
if (use_https[0] == '0')
cfg.scheme = Aws::Http::Scheme::HTTP;
else
cfg.scheme = Aws::Http::Scheme::HTTPS;
}
const char* verify_ssl = getenv("S3_VERIFY_SSL");
if (verify_ssl) {
if (verify_ssl[0] == '0')
cfg.verifySSL = false;
else
cfg.verifySSL = true;
}
// if these timeouts are low, you may see an error when
// uploading/downloading large files: Unable to connect to endpoint
int64_t timeout;
Expand Down Expand Up @@ -241,6 +225,13 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
}
});

int temp_value;
if (absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value))
s3_file->use_multi_part_download = (temp_value != 1);

const char* endpoint = getenv("S3_ENDPOINT");
if (endpoint) s3_file->s3_client->OverrideEndpoint(endpoint);
}
}

Expand All @@ -263,15 +254,26 @@ static void GetTransferManager(

absl::MutexLock l(&s3_file->initialization_lock);

if (s3_file->transfer_managers[direction].get() == nullptr) {
if (s3_file->transfer_managers.count(direction) == 0) {
uint64_t temp_value;
if (direction == Aws::Transfer::TransferDirection::UPLOAD) {
if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"),
&temp_value))
temp_value = kS3MultiPartUploadChunkSize;
} else if (direction == Aws::Transfer::TransferDirection::DOWNLOAD) {
if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"),
&temp_value))
temp_value = kS3MultiPartDownloadChunkSize;
}
s3_file->multi_part_chunk_sizes.emplace(direction, temp_value);

Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get());
config.s3Client = s3_file->s3_client;
config.bufferSize = s3_file->multi_part_chunk_sizes[direction];
config.bufferSize = temp_value;
// must be larger than pool size * multi part chunk size
config.transferBufferMaxHeapSize =
(kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction];
s3_file->transfer_managers[direction] =
Aws::Transfer::TransferManager::Create(config);
config.transferBufferMaxHeapSize = (kExecutorPoolSize + 1) * temp_value;
s3_file->transfer_managers.emplace(
direction, Aws::Transfer::TransferManager::Create(config));
}
}

Expand Down Expand Up @@ -529,24 +531,7 @@ S3File::S3File()
transfer_managers(),
multi_part_chunk_sizes(),
use_multi_part_download(false), // TODO: change to true after fix
initialization_lock() {
uint64_t temp_value;
multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value)
? temp_value
: kS3MultiPartUploadChunkSize;
multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] =
absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value)
? temp_value
: kS3MultiPartDownloadChunkSize;
use_multi_part_download =
absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value)
? (temp_value != 1)
: use_multi_part_download;
transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr);
transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD,
nullptr);
}
initialization_lock() {}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new S3File();
TF_SetStatus(status, TF_OK, "");
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_io/core/plugins/s3/s3_filesystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ typedef struct S3File {
std::shared_ptr<Aws::S3::S3Client> s3_client;
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
// We need 2 `TransferManager`, for multipart upload/download.
Aws::Map<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>
Aws::UnorderedMap<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>
transfer_managers;
// Sizes to split objects during multipart upload/download.
Aws::Map<Aws::Transfer::TransferDirection, uint64_t> multi_part_chunk_sizes;
Aws::UnorderedMap<Aws::Transfer::TransferDirection, uint64_t>
multi_part_chunk_sizes;
bool use_multi_part_download;
absl::Mutex initialization_lock;
S3File();
Expand Down
4 changes: 1 addition & 3 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def test_read_file():
response = client.get_object(Bucket=bucket_name, Key=key_name)
assert response["Body"].read() == body

os.environ["S3_ENDPOINT"] = "localhost:4566"
os.environ["S3_USE_HTTPS"] = "0"
os.environ["S3_VERIFY_SSL"] = "0"
os.environ["S3_ENDPOINT"] = "http://localhost:4566"

content = tf.io.read_file("s3://{}/{}".format(bucket_name, key_name))
assert content == body