Skip to content

Commit ba5b62d

Browse files
committed
add initial setup for gcs
1 parent 7df99e3 commit ba5b62d

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

tensorflow_io/core/plugins/gs/gcs_filesystem.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,11 @@ void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) {
12491249
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
12501250
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
12511251
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
1252+
ops->writable_file_ops->append = tf_writable_file::Append;
1253+
ops->writable_file_ops->tell = tf_writable_file::Tell;
1254+
ops->writable_file_ops->flush = tf_writable_file::Flush;
1255+
ops->writable_file_ops->sync = tf_writable_file::Sync;
1256+
ops->writable_file_ops->close = tf_writable_file::Close;
12521257

12531258
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
12541259
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));

tests/test_fs_plugins.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ def get_uri(key_name):
6161

6262
def read_file(key_name):
6363
response = client.get_object(Bucket=bucket_name, Key=key_name)
64-
print(bucket_name)
65-
print(key_name)
6664
return response["Body"].read()
6765

6866
def write_file(key_name, body):
@@ -71,7 +69,39 @@ def write_file(key_name, body):
7169
return get_uri, read_file, write_file
7270

7371

74-
MAP_URI_HELPER = {S3_URI: s3_init}
72+
def gcs_init(monkeypatch):
73+
from google.cloud import storage
74+
75+
bucket_name = os.environ.get("GCS_TEST_BUCKET")
76+
client = None
77+
bucket = None
78+
if bucket_name is None:
79+
# This means we are running against emulator.
80+
monkeypatch.setenv("STORAGE_EMULATOR_HOST", "http://localhost:9099")
81+
monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT", "http://localhost:9099")
82+
bucket_name = "tf-io-{}-{}".format(GCS_URI, int(time.time()))
83+
client = storage.Client.create_anonymous_client()
84+
client.project = "test_project"
85+
bucket = client.create_bucket(bucket_name)
86+
else:
87+
client = storage.Client()
88+
bucket = client.get_bucket(bucket)
89+
90+
def get_uri(key_name):
91+
return "{}://{}/{}".format(GCS_URI, bucket_name, key_name)
92+
93+
def read_file(key_name):
94+
blob = bucket.get_blob(key_name)
95+
return blob.download_as_bytes()
96+
97+
def write_file(key_name, body):
98+
blob = bucket.blob(key_name)
99+
blob.upload_from_string(body)
100+
101+
return get_uri, read_file, write_file
102+
103+
104+
MAP_URI_HELPER = {S3_URI: s3_init, GCS_URI: gcs_init}
75105

76106

77107
# ------------------------ URI CONDITION FOR EACH TEST ----------------------- #
@@ -138,13 +168,15 @@ def check_test_condition_and_setup_env(uri, envs, monkeypatch):
138168
pytest.param(HDFS_URI, marks=pytest.mark.skip(reason="TODO")),
139169
pytest.param(VIEWFS_URI, marks=pytest.mark.skip(reason="TODO")),
140170
pytest.param(HAR_URI, marks=pytest.mark.skip(reason="TODO")),
141-
pytest.param(GCS_URI, marks=pytest.mark.skip(reason="TODO")),
171+
GCS_URI,
142172
],
143173
)
144174
def uri_init(request):
145175
uri = request.param
146176
monkeypatch = pytest.MonkeyPatch()
147-
yield uri, *MAP_URI_HELPER[uri](monkeypatch)
177+
178+
get_uri, read_file, write_file = MAP_URI_HELPER[uri](monkeypatch)
179+
yield uri, get_uri, read_file, write_file
148180
monkeypatch.undo()
149181

150182

@@ -186,7 +218,7 @@ def test_io_write_file(uri_init, envs, monkeypatch):
186218
assert read_file(base_file_name) == body
187219

188220

189-
@pytest.mark.parametrize("envs", [("+all", None)])
221+
@pytest.mark.parametrize("envs", [("+all:-gse", None)])
190222
def test_gfile_GFile_readable(uri_init, envs, monkeypatch):
191223
uri, get_uri, _, write_file = uri_init
192224
check_test_condition_and_setup_env(uri, envs, monkeypatch)
@@ -208,7 +240,7 @@ def test_gfile_GFile_readable(uri_init, envs, monkeypatch):
208240
with pytest.raises(tf.errors.NotFoundError) as excinfo:
209241
fname_not_found = fname + "_not_found"
210242
with tf.io.gfile.GFile(fname_not_found, "rb") as f:
211-
pass
243+
_ = f.read()
212244
assert fname_not_found in str(excinfo.value)
213245

214246
# Read length
@@ -271,10 +303,3 @@ def test_gfile_GFile_writable(uri_init, envs, monkeypatch):
271303
f.write(base_body)
272304
f.flush()
273305
assert read_file(base_file_name) == body + base_body
274-
275-
# Notfound
276-
with pytest.raises(tf.errors.NotFoundError) as excinfo:
277-
fname_not_found = fname + "_not_found"
278-
with tf.io.gfile.GFile(fname_not_found, "rb") as f:
279-
pass
280-
assert fname_not_found in str(excinfo.value)

0 commit comments

Comments
 (0)