@@ -38,6 +38,7 @@ def setUp(self):
3838    self .test_df  =  test_utils .get_test_df ()
3939    self .test_region  =  'us-central1' 
4040    self .test_project  =  'foo' 
41+     self .test_wheel  =  '/my/path/wheel.whl' 
4142
4243  @mock .patch ('tfrecorder.client.beam_pipeline' ) 
4344  def  test_create_tfrecords_direct_runner (self , mock_beam ):
@@ -71,7 +72,8 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
7172        runner = 'DataflowRunner' ,
7273        output_dir = outdir ,
7374        region = self .test_region ,
74-         project = self .test_project )
75+         project = self .test_project ,
76+         tfrecorder_wheel = self .test_wheel )
7577    self .assertEqual (r , expected )
7678
7779
@@ -84,6 +86,7 @@ def setUp(self):
8486    self .test_df  =  test_utils .get_test_df ()
8587    self .test_region  =  'us-central1' 
8688    self .test_project  =  'foo' 
89+     self .test_wheel  =  '/my/path/wheel.whl' 
8790
8891  def  test_valid_dataframe (self ):
8992    """Tests valid DataFrame input.""" 
@@ -126,7 +129,8 @@ def test_valid_runner(self):
126129        self .test_df ,
127130        runner = 'DirectRunner' ,
128131        project = self .test_project ,
129-         region = self .test_region ))
132+         region = self .test_region ,
133+         tfrecorder_wheel = None ))
130134
131135  def  test_invalid_runner (self ):
132136    """Tests invalid runner.""" 
@@ -135,7 +139,8 @@ def test_invalid_runner(self):
135139          self .test_df ,
136140          runner = 'FooRunner' ,
137141          project = self .test_project ,
138-           region = self .test_region )
142+           region = self .test_region ,
143+           tfrecorder_wheel = None )
139144
140145  def  test_local_path_with_dataflow_runner (self ):
141146    """Tests DataflowRunner conflict with local path.""" 
@@ -144,7 +149,8 @@ def test_local_path_with_dataflow_runner(self):
144149          self .df_test ,
145150          runner = 'DataflowRunner' ,
146151          project = self .test_project ,
147-           region = self .test_region )
152+           region = self .test_region ,
153+           tfrecorder_wheel = self .test_wheel )
148154
149155  def  test_gcs_path_with_dataflow_runner (self ):
150156    """Tests DataflowRunner with GCS path.""" 
@@ -155,7 +161,8 @@ def test_gcs_path_with_dataflow_runner(self):
155161            df2 ,
156162            runner = 'DataflowRunner' ,
157163            project = self .test_project ,
158-             region = self .test_region ))
164+             region = self .test_region ,
165+             tfrecorder_wheel = self .test_wheel ))
159166
160167  def  test_gcs_path_with_dataflow_runner_missing_param (self ):
161168    """Tests DataflowRunner with missing required parameter.""" 
@@ -168,11 +175,27 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
168175            df2 ,
169176            runner = 'DataflowRunner' ,
170177            project = p ,
171-             region = r )
178+             region = r ,
179+             tfrecorder_wheel = self .test_wheel )
172180      self .assertTrue ('DataflowRunner requires valid `project` and `region`' 
173181                      in  repr (context .exception ))
174182
175183
184+   def  test_gcs_path_with_dataflow_runner_missing_wheel (self ):
185+     """Tests DataflowRunner with missing required whl path.""" 
186+     df2  =  self .test_df .copy ()
187+     df2 [constants .IMAGE_URI_KEY ] =  'gs://'  +  df2 [constants .IMAGE_URI_KEY ]
188+     with  self .assertRaises (AttributeError ) as  context :
189+       client ._validate_runner (
190+           df2 ,
191+           runner = 'DataflowRunner' ,
192+           project = self .test_project ,
193+           region = self .test_region ,
194+           tfrecorder_wheel = None )
195+       self .assertTrue ('requires a tfrecorder whl file for remote execution.' 
196+                       in  repr (context .exception ))
197+ 
198+ 
176199def  _make_csv_tempfile (data : List [List [str ]]) ->  tempfile .NamedTemporaryFile :
177200  """Returns `NamedTemporaryFile` representing an image CSV.""" 
178201
0 commit comments