Skip to content

Commit 72f583d

Browse files
committed
add test for filesystem plugins
1 parent 778ceac commit 72f583d

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

tests/test_fs_plugins.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
# use this file except in compliance with the License. You may obtain a copy of
5+
# the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
# ==============================================================================
15+
"""Tests for file system plugins"""
16+
17+
import os
18+
import time
19+
20+
import pytest
21+
import tensorflow as tf
22+
import tensorflow_io as tfio # pylint: disable=unused-import
23+
24+
S3_URI = "s3e"
25+
AZ_URI = "az"
26+
HDFS_URI = "hdfse"
27+
VIEWFS_URI = "viewfse"
28+
HAR_URI = "hare"
29+
GCS_URI = "gse"
30+
31+
32+
def setup_env(uri, envs, monkeypatch):
33+
# when `envs is None`, it is the default case without
34+
# additional envs and should run with every plugins.
35+
additional_env = {}
36+
if envs is not None:
37+
uri_env, additional_env = envs
38+
if uri != uri_env:
39+
pytest.skip()
40+
41+
# ------------------------------------ s3 ------------------------------------ #
42+
if uri == S3_URI:
43+
monkeypatch.setenv("AWS_REGION", "us-east-1")
44+
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "ACCESS_KEY")
45+
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "SECRET_KEY")
46+
monkeypatch.setenv("S3_ENDPOINT", "localhost:4566")
47+
monkeypatch.setenv("S3_USE_HTTPS", "0")
48+
monkeypatch.setenv("S3_VERIFY_SSL", "0")
49+
50+
for key, value in additional_env.items():
51+
monkeypatch.setenv(key, value)
52+
53+
54+
@pytest.fixture(
55+
scope="session",
56+
autouse=True,
57+
params=[
58+
S3_URI,
59+
pytest.param(AZ_URI, marks=pytest.mark.skip(reason="TODO")),
60+
pytest.param(HDFS_URI, marks=pytest.mark.skip(reason="TODO")),
61+
pytest.param(VIEWFS_URI, marks=pytest.mark.skip(reason="TODO")),
62+
pytest.param(HAR_URI, marks=pytest.mark.skip(reason="TODO")),
63+
pytest.param(GCS_URI, marks=pytest.mark.skip(reason="TODO")),
64+
],
65+
)
66+
def uri_builder(request):
67+
uri = request.param
68+
bucket_name = "{}{}e/tf-io-test".format(uri, time.time())
69+
70+
def get_uri(object_name):
71+
return "{}://{}/{}".format(uri, bucket_name, object_name)
72+
73+
tf.io.gfile.makedirs("{}://{}/".format(uri, bucket_name))
74+
return uri, bucket_name, get_uri
75+
76+
77+
# `@pytest.mark.parametrize("envs", [parameter])` is used to control which filesystem
78+
# should run and the environments during that test. if we pass a parameter `None`, that
79+
# test will run for all filesystems. In addition, parameter should be a tuple `('uri'
80+
# , env_dict )` where `uri` is the uri of the filesystem which we want to run with
81+
# the enviroment variables from `env_dict`.
82+
83+
84+
@pytest.mark.parametrize("envs", [None])
85+
def test_init(uri_builder, envs, monkeypatch):
86+
uri, _, get_uri = uri_builder
87+
setup_env(uri, envs, monkeypatch)
88+
assert tf.io.gfile.exists(get_uri("")) is True
89+
90+
91+
@pytest.mark.parametrize(
92+
"envs", [None, (S3_URI, {"S3_DISABLE_MULTI_PART_DOWNLOAD": "1"})]
93+
)
94+
def test_write_read_file(uri_builder, envs, monkeypatch):
95+
"""Test write/read file."""
96+
uri, _, get_uri = uri_builder
97+
setup_env(uri, envs, monkeypatch)
98+
99+
# Setup and check preconditions.
100+
file_name = get_uri("writereadfile")
101+
if tf.io.gfile.exists(file_name):
102+
tf.io.gfile.remove(file_name)
103+
104+
# Write data.
105+
with tf.io.gfile.GFile(file_name, "w") as w:
106+
w.write("Hello\n, world!")
107+
108+
# Read data.
109+
with tf.io.gfile.GFile(file_name, "r") as r:
110+
file_read = r.read()
111+
assert file_read == "Hello\n, world!"

0 commit comments

Comments
 (0)