import java.io.*; import java.util.*; import org.apache.commons.logging.*; import org.apache.hadoop.conf.*; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; import org.apache.hadoop.mapreduce.*; import org.apache.hadoop.mapreduce.lib.allreduce.*; import org.apache.hadoop.mapreduce.lib.input.*; import org.apache.hadoop.mapreduce.lib.output.*; import org.apache.hadoop.util.*; public class KMeans extends Configured implements Tool { public static final Log LOG = LogFactory.getLog(KMeans.class); public static class Point { public double[] p; public Point(String str) { String[] parts = str.split("\\s+"); p = new double[parts.length]; for (int i = 0; i < parts.length; i++) p[i] = Double.valueOf(parts[i]); } public double distanceTo(Point other) { double result = 0; for (int i = 0; i < p.length; i++) result += (p[i] - other.p[i]) * (p[i] - other.p[i]); return result; } public String toString() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < p.length; i++) { if (i > 0) sb.append(' '); sb.append(String.valueOf(p[i])); } return sb.toString(); } } public static class TheMapper extends AllReduceMapper{ ArrayList points = new ArrayList(); int dim = 0; public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { Point point = new Point(value.toString()); if (dim == 0) dim = point.p.length; else assert(point.p.length == dim); points.add(point); } public void cooperate(Context context, boolean writeResults) throws IOException, InterruptedException { if (dim == 0) return; // Read initial clusters. int clusters_num = context.getConfiguration().getInt("clusters.num", -1); String file = context.getConfiguration().get("clusters.file", null); Point[] clusters = new Point[clusters_num]; double[][] coords = new double[clusters_num][]; BufferedReader in = new BufferedReader(new FileReader(file)); for (int i = 0; i < clusters_num; i++) { clusters[i] = new Point(in.readLine()); assert(clusters[i].p.length == dim); coords[i] = new double[dim + 1]; } in.close(); for (int iter = 0; iter < 100; iter++) { //0. Clear coords. for (int i = 0; i < coords.length; i++) for (int j = 0; j < coords[i].length; j++) coords[i][j] = 0; //1. Add points to their clusters. for (Point point : points) { int besti = 0; double best = point.distanceTo(clusters[0]); for (int i = 1; i < clusters.length; i++) { double act = point.distanceTo(clusters[i]); if (act < best) { besti = i; best = act; } } coords[besti][0]++; for (int i = 0; i < dim; i++) coords[besti][i+1] += point.p[i]; } //2. AllReduce allReduce(context, coords); //3. Compute new clusters double change = 0; for (int i = 0; i < clusters.length; i++) if (coords[i][0] != 0) { double distance = 0; for (int j = 0; j < dim; j++) { double newpj = coords[i][j+1] / coords[i][0]; distance += (newpj - clusters[i].p[j]) * (newpj - clusters[i].p[j]); clusters[i].p[j] = newpj; } change += Math.sqrt(distance); } if (writeResults) LOG.info(String.format("Finished iteration %d with change %f.", iter, change)); //4. Synchronize on change value. change = allReduce(context, change, REDUCE_MAX); if (writeResults) LOG.info(String.format("Change synchronized to %f.", change)); if (change < 1e-3) break; } if (writeResults) { LOG.info("Done iterating"); for (int i = 0; i < clusters.length; i++) context.write(new IntWritable(i), new Text(clusters[i].toString())); } } } // Job configuration public int run(String[] args) throws Exception { if (args.length < 2) { System.err.printf("Usage: %s.jar in-path out-path", this.getClass().getName()); return 1; } Job job = new Job(getConf(), this.getClass().getName()); // Check we have enough information for AllReduce. if (job.getConfiguration().getInt("clusters.num", -1) == -1) throw new IOException("Missing number of clusters. Use -Dclusters.num=number."); if (job.getConfiguration().getInt("clusters.num", -1) <= 0) throw new IOException("At least one cluster in -Dclusters.num=number must be specified."); if (job.getConfiguration().get("clusters.file", null) == null) throw new IOException("Missing file with clusters. Use -Dclusters.file=path."); job.setJarByClass(this.getClass()); job.setMapperClass(TheMapper.class); AllReduce.init(job); job.setOutputKeyClass(IntWritable.class); job.setOutputValueClass(Text.class); job.setInputFormatClass(TextInputFormat.class); FileInputFormat.addInputPath(job, new Path(args[0])); FileOutputFormat.setOutputPath(job, new Path(args[1])); return job.waitForCompletion(true) ? 0 : 1; } // Main method public static void main(String[] args) throws Exception { int res = ToolRunner.run(new KMeans(), args); System.exit(res); } }