Institute of Formal and Applied Linguistics Wiki


Differences

This shows you the differences between two versions of the page.

spark:spark-introduction [2014/11/03 18:23]
straka
spark:spark-introduction [2014/11/03 20:37]
straka
Line 64: Line 64:
 ===== K-Means Example =====
 As an example of an iterative algorithm, consider the [[http://en.wikipedia.org/wiki/K-means_clustering|standard iterative K-Means algorithm]].
 +<file python>
 +import numpy as np
 +
 +def closestPoint(point, centers):   # Find the index of the center closest to the given point
 +    return min((np.sum((point - centers[i]) ** 2), i) for i in range(len(centers)))[1]
 +
 +lines = sc.textFile("/net/projects/hadoop/examples/inputs/points-small/points.txt", sc.defaultParallelism)
 +data = lines.map(lambda line: np.array([float(x) for x in line.split()])).cache()
 +
 +K = 50
 +epsilon = 1e-3
 +
 +centers = data.takeSample(False, K)       # Sample K random points
 +for i in range(5):                        # Perform at most 5 iterations
 +    old_centers = sc.broadcast(centers)
 +    centers = (data
 +               # For each point, find its closest center index.
 +               .map(lambda point: (closestPoint(point, old_centers.value), (point, 1)))
 +               # Sum points and counts in each cluster.
 +               .reduceByKey(lambda a, b: (a[0] + b[0], a[1] + b[1]))
 +               # Sort by cluster index.
 +               .sortByKey()
 +               # Compute the new centers by averaging points in clusters;
 +               # each element kv is (index, (sum, count)).
 +               .map(lambda kv: kv[1][0] / kv[1][1])
 +               .collect())
 +    # If the change in center positions is less than epsilon, stop.
 +    centers_change = sum(np.sqrt(np.sum((a - b)**2)) for (a, b) in zip(centers, old_centers.value))
 +    old_centers.unpersist()
 +    if centers_change < epsilon:
 +        break
 +
 +print("Final centers: " + str(centers))
 +</file>
 +The implementation starts by loading the data and caching it in memory using ''cache'', so that each iteration reuses the parsed points instead of reading the input file again. The standard iterative algorithm then runs in parallel over the cached data.
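To make the per-iteration data flow easier to follow, here is a minimal local sketch of a single K-Means step in plain NumPy, without Spark. The helper names ''closest_point'' and ''kmeans_step'' and the four sample points are hypothetical and chosen only for illustration; the dictionary stands in for what ''reduceByKey'' does across the cluster.

```python
import numpy as np

def closest_point(point, centers):
    # Index of the center with the smallest squared Euclidean distance.
    return min((np.sum((point - c) ** 2), i) for i, c in enumerate(centers))[1]

def kmeans_step(points, centers):
    # "map" + "reduceByKey": accumulate (sum of points, count) per center index.
    sums = {}
    for p in points:
        idx = closest_point(p, centers)
        s, n = sums.get(idx, (np.zeros_like(p), 0))
        sums[idx] = (s + p, n + 1)
    # "sortByKey" + "map": average each cluster, ordered by center index.
    return [sums[idx][0] / sums[idx][1] for idx in sorted(sums)]

# Hypothetical toy data: two obvious clusters around x = 0 and x = 10.
points = [np.array(p) for p in [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0), (10.0, 2.0)]]
centers = [np.array([1.0, 1.0]), np.array([9.0, 1.0])]
new_centers = kmeans_step(points, centers)
print(new_centers)
```

In this sketch, as in the Spark pipeline, a center that attracts no points produces no key and therefore does not appear in the result.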
 +
 +Note that explicitly broadcasting the ''centers'' object is not strictly needed: if we used ''old_centers = centers'' (and dropped the ''.value'' accesses and the ''unpersist'' call), the example would still work, but a copy of ''old_centers'' would be sent to //every distributed task// instead of once to every machine.
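A rough back-of-the-envelope sketch of that difference follows. The dimensionality, task count and machine count are hypothetical, and ''pickle'' is used only as a stand-in for estimating the size of one serialized copy; the bytes Spark actually transfers will differ.

```python
import pickle
import numpy as np

K, dim = 50, 10      # K matches the example; dim is a hypothetical dimensionality
num_tasks = 400      # hypothetical number of tasks per iteration
num_machines = 8     # hypothetical cluster size

centers = [np.zeros(dim) for _ in range(K)]
payload = len(pickle.dumps(centers))        # approximate bytes for one copy

per_task_bytes = payload * num_tasks        # plain closure capture: one copy per task
broadcast_bytes = payload * num_machines    # broadcast: one copy per machine

print("one copy: %d bytes" % payload)
print("per task: %d bytes, broadcast: %d bytes" % (per_task_bytes, broadcast_bytes))
```

Under these assumed numbers the ratio is simply tasks per machine, so the saving grows with the number of partitions processed on each machine.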
 +
 +For illustration, a Scala version of the example follows. It implements the same algorithm as the Python version (here with at most 10 iterations instead of 5) and uses ''breeze.linalg.Vector'', which provides linear algebra operations.
 +<file scala>
 +import breeze.linalg.Vector
 +
 +type Vector = breeze.linalg.Vector[Double]
 +type Vectors = Array[Vector]
 +
 +def closestPoint(point : Vector, centers : Vectors) : Int =
 +  centers.map(center => (center-point).norm(2)).zipWithIndex.min._2
 +
 +val lines = sc.textFile("/net/projects/hadoop/examples/inputs/points-small/points.txt", sc.defaultParallelism)
 +val data = lines.map(line => Vector(line.split("\\s+").map(_.toDouble))).cache()
 +
 +val K = 50
 +val epsilon = 1e-3
 +
 +var i = 0
 +var centers_change = Double.PositiveInfinity
 +var centers = data.takeSample(false, K)
 +while (i < 10 && centers_change > epsilon) {
 +  val old_centers = sc.broadcast(centers)
 +  centers = (data
 +             // For each point, find its closest center index.
 +             .map(point => (closestPoint(point, old_centers.value), (point, 1)))
 +             // Sum points and counts in each cluster.
 +             .reduceByKey((a, b) => (a._1+b._1, a._2+b._2))
 +             // Sort by cluster index.
 +             .sortByKey()
 +             // Compute the new centers by averaging corresponding points.
 +             .map({case (index, (sum, count)) => sum :/ count.toDouble})
 +             .collect())
 +
 +  // Compute change in center positions.
 +  centers_change = (centers zip old_centers.value).map({case (a,b) => (a-b).norm(2)}).sum
 +  old_centers.unpersist()
 +  i += 1
 +}
 +
 +print(centers.deep)
 +</file>
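The Scala loop folds the iteration cap and the convergence test into a single ''while'' condition, whereas the Python version uses ''for'' with ''break''. The same control flow can be sketched locally in Python as follows; the helpers ''centers_change'' and ''update'' and the toy data are hypothetical, and the update step is a plain NumPy stand-in for the Spark pipeline.

```python
import numpy as np

def centers_change(new_centers, old_centers):
    # Sum of Euclidean distances between corresponding new and old centers.
    return sum(np.sqrt(np.sum((a - b) ** 2)) for a, b in zip(new_centers, old_centers))

def update(points, centers):
    # Assign every point to its nearest center, then average each cluster.
    labels = [int(np.argmin([np.sum((p - c) ** 2) for c in centers])) for p in points]
    return [np.mean([p for p, l in zip(points, labels) if l == k], axis=0)
            for k in sorted(set(labels))]

# Hypothetical toy data and starting centers.
points = [np.array(p) for p in [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0), (10.0, 2.0)]]
centers = [np.array([1.0, 0.0]), np.array([8.0, 0.0])]
epsilon, max_iterations = 1e-3, 10

i, change = 0, float("inf")
while i < max_iterations and change > epsilon:
    new_centers = update(points, centers)
    change = centers_change(new_centers, centers)
    centers = new_centers
    i += 1

print(i, centers)
```

On this toy input the first pass moves both centers and the second pass leaves them unchanged, so the loop stops well before the iteration cap.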
 +
