2626import java .io .IOException ;
2727import java .io .InputStreamReader ;
2828import java .net .URI ;
29- import java .util .*;
29+ import java .util .ArrayList ;
30+ import java .util .List ;
3031
3132
3233public class KMeans {
@@ -37,8 +38,7 @@ public class KMeans {
3738 public static class CustomMapper
3839 extends Mapper <Object , Text , Text , Text > {
3940
40- List <List <Float >> centroids = null ;
41- Map <String , Integer > centroidIdMap = new HashMap <>();
41+ List <List <Double >> centroids = null ;
4242
4343 protected void setup (Context context ) throws IOException , InterruptedException {
4444 centroids = new ArrayList <>();
@@ -50,15 +50,12 @@ protected void setup(Context context) throws IOException, InterruptedException {
5050 FileSystem fs = FileSystem .get (context .getConfiguration ());
5151 Path getFilePath = new Path (cacheFiles [0 ].toString ());
5252 BufferedReader reader = new BufferedReader (new InputStreamReader (fs .open (getFilePath )));
53- int lineNumber = 0 ;
5453 while ((line = reader .readLine ()) != null ) {
55- lineNumber ++;
5654 String [] values = line .split ("," );
57- List <Float > c = new ArrayList <>();
55+ List <Double > c = new ArrayList <>();
5856 for (int i = 0 ; i < values .length ; i ++) {
59- c .add (Float . parseFloat (values [i ]));
57+ c .add (Double . parseDouble (values [i ]. trim () ));
6058 }
61- centroidIdMap .putIfAbsent (StringUtils .join (c , "," ), lineNumber );
6259 centroids .add (c );
6360
6461 }
@@ -69,19 +66,19 @@ protected void setup(Context context) throws IOException, InterruptedException {
6966 }
7067 }
7168
72- public float getEuclideanDistance (List <Float > arr1 , List <Float > arr2 ) {
73- float sum = 0 ;
69+ public double getEuclideanDistance (List <Double > arr1 , List <Double > arr2 ) {
70+ double sum = 0 ;
7471 for (int i = 0 ; i < arr1 .size (); i ++) {
7572 sum = sum + (arr1 .get (i ) - arr2 .get (i )) * (arr1 .get (i ) - arr2 .get (i ));
7673 }
77- return ( float ) Math .sqrt (sum );
74+ return Math .sqrt (sum );
7875 }
7976
80- public List <Float > getClosestCentroidForDataPoint (List <List <Float >> centroids , List <Float > dataPoint ) {
81- float leastDistance = Integer .MAX_VALUE ;
82- List <Float > closestCentroid = new ArrayList <>();
77+ public List <Double > getClosestCentroidForDataPoint (List <List <Double >> centroids , List <Double > dataPoint ) {
78+ double leastDistance = Integer .MAX_VALUE ;
79+ List <Double > closestCentroid = new ArrayList <>();
8380 for (int i = 0 ; i < centroids .size (); i ++) {
84- float distance = getEuclideanDistance (centroids .get (i ), dataPoint );
81+ double distance = getEuclideanDistance (centroids .get (i ), dataPoint );
8582 if (distance < leastDistance ) {
8683 closestCentroid = centroids .get (i );
8784 leastDistance = distance ;
@@ -94,15 +91,19 @@ public List<Float> getClosestCentroidForDataPoint(List<List<Float>> centroids, L
9491 public void map (Object key , Text value , Context context
9592 ) throws IOException , InterruptedException {
9693
97- List <Float > dataPointList = new ArrayList <>();
98- Float [] dataPoint = Arrays .stream (value .toString ().split ("," )).map (Float ::valueOf ).toArray (Float []::new );
99- Collections .addAll (dataPointList , dataPoint );
100- List <Float > closestCentroid = getClosestCentroidForDataPoint (centroids , dataPointList );
94+ List <Double > dataPointList = new ArrayList <>();
95+ String [] dataPointStr = value .toString ().split ("," );
96+ for (String s : dataPointStr )
97+ dataPointList .add (Double .parseDouble (s ));
98+ List <Double > closestCentroid = getClosestCentroidForDataPoint (centroids , dataPointList );
10199
102- String k = centroidIdMap .get (StringUtils .join (closestCentroid , "," )).toString ();
103100 // key - id of centroid that matches the data point
104101 // value - data point
105- context .write (new Text (k ), new Text (value ));
102+ String a = StringUtils .join (dataPointList , "," );
103+ if (a .split ("," ).length > 2 ) {
104+ System .out .println ("" );
105+ }
106+ context .write (new Text (StringUtils .join (closestCentroid , "," )), new Text (StringUtils .join (dataPointList , "," )));
106107 }
107108 }
108109
@@ -114,23 +115,25 @@ public void reduce(Text key, Iterable<Text> values,
114115 ) throws IOException , InterruptedException {
115116
116117 // input
117- // key - centroid id
118+ // key - centroid
118119 // value - list of data points belonging to the centroid
119-
120- List <Float > summedUpValues = new ArrayList <>();
120+ List <Integer > summedUpValues = new ArrayList <>();
121121 int count = 0 ;
122122 for (Text val : values ) {
123123 String [] dataPoint = val .toString ().split ("," );
124- for (int j = 0 ; j < dataPoint .length ; j ++) {
125- if (count == 0 )
126- summedUpValues .add (Float .parseFloat (dataPoint [j ]));
127- else
128- summedUpValues .set (j , summedUpValues .get (j ) + Float .parseFloat (dataPoint [j ]));
124+ for (int i = 0 ; i < dataPoint .length ; i ++) {
125+ if (count == 0 ) {
126+ summedUpValues .add ((int ) Double .parseDouble (dataPoint [i ]));
127+ } else {
128+ summedUpValues .set (i , (int ) ((summedUpValues .get (i ) + Double .parseDouble (dataPoint [i ]))));
129+ }
129130 }
130131 count ++;
132+
131133 }
134+
132135 // output
133- // key - centroid id
136+ // key - centroid
134137 // value - summed up values as one string, count
135138 context .write (key , new Text (StringUtils .join (summedUpValues , "," ) + "," + count ));
136139 }
@@ -144,37 +147,39 @@ public void reduce(Text key, Iterable<Text> values,
144147 ) throws IOException , InterruptedException {
145148
146149 // input
147- // key - centroid id
150+ // key - centroid
148151 // value - summed up values as one string, count
152+ int localCount = 0 ;
149153 int globalCount = 0 ;
150- int reducerCount = 0 ;
151- List < Float > summedUpValues = new ArrayList <>() ;
154+ List < Double > summedUpValues = new ArrayList <>() ;
155+ double [] arr = { 0 , 0 } ;
152156 for (Text val : values ) {
153157 String [] dataPoint = val .toString ().split ("," );
154- globalCount += Integer .parseInt (dataPoint [dataPoint .length - 1 ]);
155- for (int j = 0 ; j < dataPoint .length - 1 ; j ++) {
156- if (reducerCount == 0 )
157- summedUpValues .add (Float .parseFloat (dataPoint [j ]));
158- else
159- summedUpValues .set (j , summedUpValues .get (j ) + Float .parseFloat (dataPoint [j ]));
158+ globalCount += Double .parseDouble (dataPoint [dataPoint .length - 1 ]);
159+ for (int i = 0 ; i < dataPoint .length - 1 ; i ++) {
160+ if (localCount == 0 ) {
161+ summedUpValues .add (Double .parseDouble (dataPoint [i ]));
162+ } else {
163+ summedUpValues .set (i , summedUpValues .get (i ) + Double .parseDouble (dataPoint [i ]));
164+ }
160165 }
161- reducerCount ++;
166+ localCount ++;
167+
162168 }
163169 // finding the new average i.e the centroid
164170 for (int i = 0 ; i < summedUpValues .size (); i ++) {
165171 summedUpValues .set (i , summedUpValues .get (i ) / globalCount );
166172 }
167- // key - centroid id
173+
174+ // key - empty
168175 // value - new centroid
169- context .write (new Text ("" ), new Text (StringUtils . join ( summedUpValues , "," ) ));
176+ context .write (new Text ("" ), new Text (arr [ 0 ] + "," + arr [ 1 ] ));
170177 }
171178
172179 }
173180
174181
175182 public static void main (String [] args ) throws Exception {
176- // System.out.println();
177- // System.exit(1);
178183
179184 if (args .length < 4 ) {
180185 throw new Exception ("Pass all the required arguments. Input file, Centroids file path, Output filepath, number of iterations" );
@@ -183,8 +188,9 @@ public static void main(String[] args) throws Exception {
183188 String inputData = args [0 ];
184189 String inputCentroidsPath = args [1 ];
185190 String newCentroidsPath = args [2 ];
191+
186192 String centroidsFilename = "centroids.txt" ;
187- int numOfReducers = 5 ;
193+ int numOfReducers = 1 ;
188194 int numOfIterations = Integer .parseInt (args [3 ]);
189195 String hadoopHome = System .getenv ("HADOOP_HOME" );
190196 if (hadoopHome == null ) {
@@ -210,7 +216,6 @@ public static void main(String[] args) throws Exception {
210216
211217 FileInputFormat .addInputPath (job , new Path (inputData ));
212218 FileOutputFormat .setOutputPath (job , new Path (newCentroidsPath + "/" + i ));
213-
214219 // Add distributed cache file
215220 try {
216221 if (i == 0 )
@@ -232,7 +237,7 @@ public static void main(String[] args) throws Exception {
232237 GeneralUtilities .writeIterableToFileHDFS (intermediateCentroids , newCentroidsPath + "/" + i + "/" + centroidsFilename , fs );
233238
234239
235- // if i==0 which means its the first iteration, do not run the compare logic
240+ // if i==0 which means its the first iteration, do not run the compare logic
236241 if (i != 0 ) {
237242
238243 // compare logic, pick the single file that was saved in previous step and compare it to the single file that was
@@ -244,8 +249,6 @@ public static void main(String[] args) throws Exception {
244249 }
245250 }
246251 }
247- if (fs != null )
248- fs .close ();
249252 }
250253
251254}
0 commit comments