1818import static org .hamcrest .Matchers .containsInAnyOrder ;
1919import static org .hamcrest .Matchers .emptyIterable ;
2020import static org .hamcrest .Matchers .equalTo ;
21+ import static org .hamcrest .Matchers .is ;
2122import static org .junit .Assert .assertThat ;
2223import static org .mockito .Mockito .mock ;
2324import static org .mockito .Mockito .when ;
2425
26+ import com .google .cloud .dataflow .sdk .coders .BigEndianLongCoder ;
27+ import com .google .cloud .dataflow .sdk .coders .Coder ;
2528import com .google .cloud .dataflow .sdk .io .BoundedSource ;
29+ import com .google .cloud .dataflow .sdk .io .BoundedSource .BoundedReader ;
2630import com .google .cloud .dataflow .sdk .io .CountingSource ;
2731import com .google .cloud .dataflow .sdk .io .Read ;
2832import com .google .cloud .dataflow .sdk .io .Read .Bounded ;
33+ import com .google .cloud .dataflow .sdk .options .PipelineOptions ;
34+ import com .google .cloud .dataflow .sdk .runners .inprocess .InProcessPipelineRunner .CommittedBundle ;
2935import com .google .cloud .dataflow .sdk .runners .inprocess .InProcessPipelineRunner .UncommittedBundle ;
3036import com .google .cloud .dataflow .sdk .testing .TestPipeline ;
37+ import com .google .cloud .dataflow .sdk .transforms .AppliedPTransform ;
3138import com .google .cloud .dataflow .sdk .transforms .windowing .BoundedWindow ;
3239import com .google .cloud .dataflow .sdk .util .WindowedValue ;
3340import com .google .cloud .dataflow .sdk .values .PCollection ;
41+ import com .google .common .collect .ImmutableList ;
3442
43+ import org .joda .time .Instant ;
3544import org .junit .Before ;
3645import org .junit .Test ;
3746import org .junit .runner .RunWith ;
3847import org .junit .runners .JUnit4 ;
48+ import org .mockito .Mock ;
49+
50+ import java .io .IOException ;
51+ import java .util .Arrays ;
52+ import java .util .List ;
53+ import java .util .NoSuchElementException ;
3954
4055/**
4156 * Tests for {@link BoundedReadEvaluatorFactory}.
@@ -45,7 +60,7 @@ public class BoundedReadEvaluatorFactoryTest {
4560 private BoundedSource <Long > source ;
4661 private PCollection <Long > longs ;
4762 private TransformEvaluatorFactory factory ;
48- private InProcessEvaluationContext context ;
63+ @ Mock private InProcessEvaluationContext context ;
4964
5065 @ Before
5166 public void setup () {
@@ -146,6 +161,125 @@ public void boundedSourceEvaluatorSimultaneousEvaluations() throws Exception {
146161 gw (1L ), gw (2L ), gw (4L ), gw (8L ), gw (9L ), gw (7L ), gw (6L ), gw (5L ), gw (3L ), gw (0L )));
147162 }
148163
164+ @ Test
165+ public void boundedSourceEvaluatorClosesReader () throws Exception {
166+ TestSource <Long > source = new TestSource <>(BigEndianLongCoder .of (), 1L , 2L , 3L );
167+
168+ TestPipeline p = TestPipeline .create ();
169+ PCollection <Long > pcollection = p .apply (Read .from (source ));
170+ AppliedPTransform <?, ?, ?> sourceTransform = pcollection .getProducingTransformInternal ();
171+
172+ UncommittedBundle <Long > output = InProcessBundle .unkeyed (longs );
173+ when (context .createRootBundle (pcollection )).thenReturn (output );
174+
175+ TransformEvaluator <?> evaluator = factory .forApplication (sourceTransform , null , context );
176+ evaluator .finishBundle ();
177+ CommittedBundle <Long > committed = output .commit (Instant .now ());
178+ assertThat (committed .getElements (), containsInAnyOrder (gw (2L ), gw (3L ), gw (1L )));
179+ assertThat (TestSource .readerClosed , is (true ));
180+ }
181+
182+ @ Test
183+ public void boundedSourceEvaluatorNoElementsClosesReader () throws Exception {
184+ TestSource <Long > source = new TestSource <>(BigEndianLongCoder .of ());
185+
186+ TestPipeline p = TestPipeline .create ();
187+ PCollection <Long > pcollection = p .apply (Read .from (source ));
188+ AppliedPTransform <?, ?, ?> sourceTransform = pcollection .getProducingTransformInternal ();
189+
190+ UncommittedBundle <Long > output = InProcessBundle .unkeyed (longs );
191+ when (context .createRootBundle (pcollection )).thenReturn (output );
192+
193+ TransformEvaluator <?> evaluator = factory .forApplication (sourceTransform , null , context );
194+ evaluator .finishBundle ();
195+ CommittedBundle <Long > committed = output .commit (Instant .now ());
196+ assertThat (committed .getElements (), emptyIterable ());
197+ assertThat (TestSource .readerClosed , is (true ));
198+ }
199+
200+ private static class TestSource <T > extends BoundedSource <T > {
201+ private static boolean readerClosed ;
202+ private final Coder <T > coder ;
203+ private final T [] elems ;
204+
205+ public TestSource (Coder <T > coder , T ... elems ) {
206+ this .elems = elems ;
207+ this .coder = coder ;
208+ readerClosed = false ;
209+ }
210+
211+ @ Override
212+ public List <? extends BoundedSource <T >> splitIntoBundles (
213+ long desiredBundleSizeBytes , PipelineOptions options ) throws Exception {
214+ return ImmutableList .of (this );
215+ }
216+
217+ @ Override
218+ public long getEstimatedSizeBytes (PipelineOptions options ) throws Exception {
219+ return 0 ;
220+ }
221+
222+ @ Override
223+ public boolean producesSortedKeys (PipelineOptions options ) throws Exception {
224+ return false ;
225+ }
226+
227+ @ Override
228+ public BoundedSource .BoundedReader <T > createReader (PipelineOptions options ) throws IOException {
229+ return new TestReader <>(this , elems );
230+ }
231+
232+ @ Override
233+ public void validate () {
234+ }
235+
236+ @ Override
237+ public Coder <T > getDefaultOutputCoder () {
238+ return coder ;
239+ }
240+ }
241+
242+ private static class TestReader <T > extends BoundedReader <T > {
243+ private final BoundedSource <T > source ;
244+ private final List <T > elems ;
245+ private int index ;
246+
247+ public TestReader (BoundedSource <T > source , T ... elems ) {
248+ this .source = source ;
249+ this .elems = Arrays .asList (elems );
250+ this .index = -1 ;
251+ }
252+
253+ @ Override
254+ public BoundedSource <T > getCurrentSource () {
255+ return source ;
256+ }
257+
258+ @ Override
259+ public boolean start () throws IOException {
260+ return advance ();
261+ }
262+
263+ @ Override
264+ public boolean advance () throws IOException {
265+ if (elems .size () > index + 1 ) {
266+ index ++;
267+ return true ;
268+ }
269+ return false ;
270+ }
271+
272+ @ Override
273+ public T getCurrent () throws NoSuchElementException {
274+ return elems .get (index );
275+ }
276+
277+ @ Override
278+ public void close () throws IOException {
279+ TestSource .readerClosed = true ;
280+ }
281+ }
282+
149283 private static WindowedValue <Long > gw (Long elem ) {
150284 return WindowedValue .valueInGlobalWindow (elem );
151285 }
0 commit comments