@@ -38,6 +38,7 @@ def setUp(self):
3838 self .test_df = test_utils .get_test_df ()
3939 self .test_region = 'us-central1'
4040 self .test_project = 'foo'
41+ self .test_wheel = '/my/path/wheel.whl'
4142
4243 @mock .patch ('tfrecorder.client.beam_pipeline' )
4344 def test_create_tfrecords_direct_runner (self , mock_beam ):
@@ -71,7 +72,8 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
7172 runner = 'DataflowRunner' ,
7273 output_dir = outdir ,
7374 region = self .test_region ,
74- project = self .test_project )
75+ project = self .test_project ,
76+ tfrecorder_wheel = self .test_wheel )
7577 self .assertEqual (r , expected )
7678
7779
@@ -84,6 +86,7 @@ def setUp(self):
8486 self .test_df = test_utils .get_test_df ()
8587 self .test_region = 'us-central1'
8688 self .test_project = 'foo'
89+ self .test_wheel = '/my/path/wheel.whl'
8790
8891 def test_valid_dataframe (self ):
8992 """Tests valid DataFrame input."""
@@ -126,7 +129,8 @@ def test_valid_runner(self):
126129 self .test_df ,
127130 runner = 'DirectRunner' ,
128131 project = self .test_project ,
129- region = self .test_region ))
132+ region = self .test_region ,
133+ tfrecorder_wheel = None ))
130134
131135 def test_invalid_runner (self ):
132136 """Tests invalid runner."""
@@ -135,7 +139,8 @@ def test_invalid_runner(self):
135139 self .test_df ,
136140 runner = 'FooRunner' ,
137141 project = self .test_project ,
138- region = self .test_region )
142+ region = self .test_region ,
143+ tfrecorder_wheel = None )
139144
140145 def test_local_path_with_dataflow_runner (self ):
141146 """Tests DataflowRunner conflict with local path."""
@@ -144,7 +149,8 @@ def test_local_path_with_dataflow_runner(self):
144149 self .df_test ,
145150 runner = 'DataflowRunner' ,
146151 project = self .test_project ,
147- region = self .test_region )
152+ region = self .test_region ,
153+ tfrecorder_wheel = self .test_wheel )
148154
149155 def test_gcs_path_with_dataflow_runner (self ):
150156 """Tests DataflowRunner with GCS path."""
@@ -155,7 +161,8 @@ def test_gcs_path_with_dataflow_runner(self):
155161 df2 ,
156162 runner = 'DataflowRunner' ,
157163 project = self .test_project ,
158- region = self .test_region ))
164+ region = self .test_region ,
165+ tfrecorder_wheel = self .test_wheel ))
159166
160167 def test_gcs_path_with_dataflow_runner_missing_param (self ):
161168 """Tests DataflowRunner with missing required parameter."""
@@ -168,11 +175,27 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
168175 df2 ,
169176 runner = 'DataflowRunner' ,
170177 project = p ,
171- region = r )
178+ region = r ,
179+ tfrecorder_wheel = self .test_wheel )
172180 self .assertTrue ('DataflowRunner requires valid `project` and `region`'
173181 in repr (context .exception ))
174182
175183
184+ def test_gcs_path_with_dataflow_runner_missing_wheel (self ):
185+ """Tests DataflowRunner with missing required whl path."""
186+ df2 = self .test_df .copy ()
187+ df2 [constants .IMAGE_URI_KEY ] = 'gs://' + df2 [constants .IMAGE_URI_KEY ]
188+ with self .assertRaises (AttributeError ) as context :
189+ client ._validate_runner (
190+ df2 ,
191+ runner = 'DataflowRunner' ,
192+ project = self .test_project ,
193+ region = self .test_region ,
194+ tfrecorder_wheel = None )
195+ self .assertTrue ('requires a tfrecorder whl file for remote execution.'
196+ in repr (context .exception ))
197+
198+
176199def _make_csv_tempfile (data : List [List [str ]]) -> tempfile .NamedTemporaryFile :
177200 """Returns `NamedTemporaryFile` representing an image CSV."""
178201
0 commit comments