1717package com .google .cloud .examples .nio ;
1818
1919import com .google .common .base .Stopwatch ;
20+ import com .google .common .io .BaseEncoding ;
2021
21- import javax .xml .bind .annotation .adapters .HexBinaryAdapter ;
2222import java .io .IOException ;
2323import java .net .URI ;
2424import java .nio .ByteBuffer ;
2727import java .nio .file .Path ;
2828import java .nio .file .Paths ;
2929import java .security .MessageDigest ;
30+ import java .util .ArrayDeque ;
31+ import java .util .Queue ;
32+ import java .util .concurrent .Callable ;
33+ import java .util .concurrent .ExecutorService ;
34+ import java .util .concurrent .Executors ;
35+ import java .util .concurrent .Future ;
3036import java .util .concurrent .TimeUnit ;
3137
3238/**
3339 * ParallelCountBytes will read through the whole file given as input.
3440 *
3541 * <p>This example shows how to go through all the contents of a file,
36- * in order, using multithreaded NIO reads.It also reports how long it took.
42+ * in order, using multithreaded NIO reads.
43+ * It prints an MD5 hash and reports how long it took.
3744 *
3845 * <p>See the README for compilation instructions. Run this code with
3946 * {@code target/appassembler/bin/ParallelCountBytes <file>}
4047 */
4148public class ParallelCountBytes {
4249
43- private class BufWithLock {
44- public Object lock ;
45- public ByteBuffer buf ;
46- public boolean full ;
47- public Thread t ;
50+ /**
51+ * WorkUnit holds a buffer and the instructions for what to put in it.
52+ */
53+ private class WorkUnit implements Callable <WorkUnit > {
54+ public final ByteBuffer buf ;
55+ final SeekableByteChannel chan ;
56+ final int blockSize ;
57+ int blockIndex ;
4858
49- public BufWithLock (int size ) {
50- this .buf = ByteBuffer .allocate (size );
51- this .lock = new Object ();
59+ public WorkUnit (SeekableByteChannel chan , int blockSize , int blockIndex ) {
60+ this .chan = chan ;
61+ this .buf = ByteBuffer .allocate (blockSize );
62+ this .blockSize = blockSize ;
63+ this .blockIndex = blockIndex ;
64+ }
65+
66+ @ Override
67+ public WorkUnit call () throws IOException {
68+ int pos = blockSize * blockIndex ;
69+ if (pos > chan .size ()) {
70+ return this ;
71+ }
72+ chan .position (pos );
73+ // read until buffer is full, or EOF
74+ while (chan .read (buf ) > 0 ) {};
75+ return this ;
76+ }
77+
78+ public WorkUnit resetForIndex (int blockIndex ) {
79+ this .blockIndex = blockIndex ;
80+ buf .flip ();
81+ return this ;
5282 }
5383 }
5484
@@ -69,37 +99,6 @@ public void start(String[] args) throws IOException {
6999 }
70100 }
71101
72- private void stridedRead (SeekableByteChannel chan , int blockSize , int firstBlock , int stride , BufWithLock output ) {
73- try {
74- // stagger the threads a little bit.
75- Thread .sleep (250 * firstBlock );
76- long pos = firstBlock * blockSize ;
77- synchronized (output .lock ) {
78- while (true ) {
79- if (pos > chan .size ()) {
80- break ;
81- }
82- chan .position (pos );
83- // read until buffer is full, or EOF
84- while (chan .read (output .buf ) > 0 ) {};
85- output .full = true ;
86- output .lock .notifyAll ();
87- if (output .buf .hasRemaining ()) {
88- break ;
89- }
90- // wait for main thread to process it
91- while (output .full ) {
92- output .lock .wait ();
93- }
94- output .buf .flip ();
95- pos += stride * blockSize ;
96- }
97- }
98- } catch (InterruptedException | IOException o ) {
99- // this simple example doesn't handle errors, sorry.
100- }
101- }
102-
103102 /**
104103 * Print the length of the indicated file.
105104 *
@@ -109,49 +108,36 @@ private void stridedRead(SeekableByteChannel chan, int blockSize, int firstBlock
109108 private void countFile (String fname ) throws IOException {
110109 // large buffers pay off
111110 final int bufSize = 50 * 1024 * 1024 ;
111+ Queue <Future <WorkUnit >> work = new ArrayDeque <>();
112112 try {
113113 Path path = Paths .get (new URI (fname ));
114114 long size = Files .size (path );
115115 System .out .println (fname + ": " + size + " bytes." );
116- ByteBuffer buf = ByteBuffer .allocate (bufSize );
117- int nBlocks = (int )Math .ceil ( size / (double )bufSize );
118- int nThreads = nBlocks ;
116+ int nThreads = (int ) Math .ceil (size / (double ) bufSize );
119117 if (nThreads > 4 ) nThreads = 4 ;
120118 System .out .println ("Reading the whole file using " + nThreads + " threads..." );
121119 Stopwatch sw = Stopwatch .createStarted ();
122- final BufWithLock [] bufs = new BufWithLock [nThreads ];
123- for (int i = 0 ; i < nThreads ; i ++) {
124- bufs [i ] = new BufWithLock (bufSize );
125- final SeekableByteChannel chan = Files .newByteChannel (path );
126- final int finalNThreads = nThreads ;
127- final int finalI = i ;
128- bufs [i ].t = new Thread (new Runnable () {
129- @ Override
130- public void run () {
131- stridedRead (chan , bufSize , finalI , finalNThreads , bufs [finalI ]);
132- }
133- });
134- bufs [i ].t .start ();
135- }
136-
137120 long total = 0 ;
138121 MessageDigest md = MessageDigest .getInstance ("MD5" );
139- for (int block = 0 ; block < nBlocks ; block ++) {
140- BufWithLock bwl = bufs [block % bufs .length ];
141- synchronized (bwl .lock ) {
142- while (!bwl .full ) {
143- bwl .lock .wait ();
144- }
145- md .update (bwl .buf .array (), 0 , bwl .buf .position ());
146- total += bwl .buf .position ();
147- bwl .full = false ;
148- bwl .lock .notifyAll ();
122+
123+ ExecutorService exec = Executors .newFixedThreadPool (nThreads );
124+ int blockIndex ;
125+ for (blockIndex = 0 ; blockIndex < nThreads ; blockIndex ++) {
126+ work .add (exec .submit (new WorkUnit (Files .newByteChannel (path ), bufSize , blockIndex )));
127+ }
128+ while (true ) {
129+ WorkUnit full = work .remove ().get ();
130+ md .update (full .buf .array (), 0 , full .buf .position ());
131+ total += full .buf .position ();
132+ if (full .buf .hasRemaining ()) {
133+ break ;
149134 }
135+ work .add (exec .submit (full .resetForIndex (blockIndex ++)));
150136 }
151137
152138 long elapsed = sw .elapsed (TimeUnit .SECONDS );
153139 System .out .println ("Read all " + total + " bytes in " + elapsed + "s. " );
154- String hex = ( new HexBinaryAdapter ()). marshal (md .digest ());
140+ String hex = String . valueOf ( BaseEncoding . base16 (). encode (md .digest () ));
155141 System .out .println ("The MD5 is: 0x" + hex );
156142 if (total != size ) {
157143 System .out .println ("Wait, this doesn't match! We saw " + total + " bytes, " +
0 commit comments