vectormap and pvectormap

14 Oct 2010

So after attending Brian Goetz's talk and Rich Hickey's talk at JAOO Aarhus (err, I mean GOTO Aarhus), I was thinking about how to construct Clojure data structures in parallel. To start with something that wasn't too complex, I decided to try to create a parallel version of mapping a function over a vector, i.e., an eager function that would take a function and a vector as input and produce a mapped vector as output (instead of a seq). This would replace a pattern I've often used:
(into [] (map f vs))
with (vectormap f vs), which avoids the overhead of constructing a seq and then reconstructing a vector by conj'ing from that seq. My end goal was to produce a parallel version, e.g., (pvectormap f vs).

As Rich has pointed out, Clojure's persistent data structures are excellent candidates for parallel processing using divide and conquer, since they are trees which are already "sitting there divided!" Furthermore, immutability means no synchronization is needed. For example, remember that PersistentVector is a balanced 32-way tree consisting of size-32 arrays of Objects (Nodes in the tree, or the actual values stored in the vector). I decided to warm up by implementing vectormap, i.e., the serial version, first.
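For reference, here is a simplified sketch of PersistentVector's internal Node class, as I read Clojure's source (the real class also implements Serializable; the edit field is only used by transients, and is the null you'll see passed to new Node(null, ...) in the code below):

// Simplified sketch of PersistentVector.Node, following Clojure's source.
static class Node {
	final AtomicReference<Thread> edit; // used by transients; null for us
	final Object[] array; // up to 32 child Nodes, or leaf values at the bottom

	Node(AtomicReference<Thread> edit, Object[] array) {
		this.edit = edit;
		this.array = array;
	}
}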

Remember this?

Apparently Rich Hickey presented this piece of code at the JVM Language Summit:
static public Object ret1(Object ret, Object nil) {
    return ret;
}

public static int count(Object o){
    if(o instanceof Counted)
        return ((Counted) o).count();
    return countFrom(Util.ret1(o, o = null));
}
Has Rich gone mad? A two-argument static method which simply returns the first argument?? When I first read Ola's blog post I simply couldn't figure out why he would use that code... However, when I was writing the vectormap code I was thinking: suppose the input vector is really large -- in fact, so large that we don't have enough memory to hold both the input and the output vector. Then vectormap would produce an OutOfMemoryError. But suppose the calling code didn't actually need the input vector: what if we released references to the elements of the input vector as we constructed the corresponding mapped elements in the output vector (without destroying the input vector itself)? We would need to keep references only to the Node arrays we hadn't already processed, and null out our local variables to those we had.

This was when I realized that this is exactly what Rich's function can help with: when calling the countFrom function, he passes the seq referenced by the local variable o as an argument via the ret1 function. The side effect of using ret1 is that since Java is strict, both argument expressions to ret1 are evaluated (left to right), and consequently o is null'ed out before countFrom runs. No more holding on to the head :) I could use this to hold on to only the vector arrays I hadn't processed yet. For example:
private static Node mapNode(IFn f, Node node, int level) {
	if (node == null) { return null; }
	if (level == 0) {
		// Leaf level: map over the value array, releasing our reference
		// to the input node before mapArray runs.
		return new Node(null, mapArray(f, Util.ret1(node.array, node = null)));
	}
	// Interior level: copy the child array, then drop the input node so
	// only the not-yet-processed children keep the old subtree alive.
	Object[] newArr = new Object[node.array.length];
	System.arraycopy(node.array, 0, newArr, 0, node.array.length);
	node = null;
	level -= 5;
	for (int i = 0; i < newArr.length; i++) {
		// ret1 hands the child to the recursive call while nulling our slot.
		newArr[i] = mapNode(f, Util.ret1((Node) newArr[i], newArr[i] = null), level);
	}
	return new Node(null, newArr);
}
This code is available in my clj-ds project.
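The mapArray helper used by mapNode isn't shown above; it just applies f to each slot of a leaf array. Here is a plausible sketch, together with a hypothetical vectormap entry point. I'm assuming both live inside the vector class, whose cnt, shift, root, and tail fields follow Clojure's PersistentVector internals:

// Sketch of the mapArray helper: map f over one size-32 leaf array.
private static Object[] mapArray(IFn f, Object[] array) {
	Object[] newArr = new Object[array.length];
	for (int i = 0; i < array.length; i++) {
		newArr[i] = f.invoke(array[i]);
	}
	return newArr;
}

// Hypothetical entry point: map the tree and the tail; the element
// count and tree depth (shift) are unchanged by mapping.
public static PersistentVector vectormap(IFn f, PersistentVector v) {
	Node newRoot = mapNode(f, v.root, v.shift);
	Object[] newTail = mapArray(f, v.tail);
	return new PersistentVector(v.cnt, v.shift, newRoot, newTail);
}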

Parallelize with Fork/Join

The pvectormap function uses Fork/Join to parallelize the mapping. I'm not sure about the optimal granularity of the tasks, but I decided that processing a single size-32 array was too small a task, and went with processing 32 size-32 arrays instead. Starting at the root array of nodes, the code forks 32 tasks which recursively process each child of the root. This forking continues recursively until we hit the second-lowest level in the tree, which is processed sequentially using the mapNode function from above. This is implemented as a RecursiveTask in the Fork/Join framework.
static final class PMapTask extends RecursiveTask<Node> {

	private final IFn f;
	private final int shift;
	private final Node node;

	public PMapTask(IFn f, int shift, Node node) {
		this.f = f;
		this.shift = shift;
		this.node = node;
	}

	public Node compute() {
		if (node == null) {
			return null;
		}
		// At the second-lowest level, fall back to the sequential version.
		if (this.shift <= 5) {
			return mapNode(f, node, shift);
		}
		// Otherwise fork one subtask per child and process them in parallel.
		PMapTask[] tasks = new PMapTask[node.array.length];
		int childShift = shift - 5;
		for (int i = 0; i < tasks.length; i++) {
			tasks[i] = new PMapTask(f, childShift, (Node) node.array[i]);
		}
		invokeAll(tasks);
		Node[] nodes = new Node[node.array.length];
		try {
			for (int i = 0; i < tasks.length; i++) {
				nodes[i] = tasks[i].get();
			}
			return new Node(null, nodes);
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			throw new RuntimeException(e);
		} catch (ExecutionException e) {
			throw new RuntimeException(e);
		}
	}
}
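For completeness, a hypothetical pvectormap entry point: invoke the root task on a ForkJoinPool and map the tail sequentially, since the size-32 tail isn't worth forking. The pool field and the PersistentVector internals mirror the assumptions in the vectormap sketch above:

// Assumed to live inside the vector class, with
// java.util.concurrent.ForkJoinPool available.
private static final ForkJoinPool POOL = new ForkJoinPool();

public static PersistentVector pvectormap(IFn f, PersistentVector v) {
	// invoke() runs the root task on the pool and waits for its result.
	Node newRoot = POOL.invoke(new PMapTask(f, v.shift, v.root));
	Object[] newTail = mapArray(f, v.tail);
	return new PersistentVector(v.cnt, v.shift, newRoot, newTail);
}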
On my dual-core system, with a non-trivial function f that actually does some work, pvectormap is about twice as fast as vectormap: see PersistentVectorTest in clj-ds. Next stop: optimize and add vectormap and pvectormap to Clojure core :)