Spark UDF memoization

Apr 16, 2018

Memoization is a powerful technique for improving the performance of repeated computations. Although it would be a handy feature, as of this writing Spark has no memoization or result cache for UDFs. Fortunately, it's something we can easily implement ourselves. All examples below are in Scala.

The problem

Imagine we have a relatively expensive function:

[code lang="scala"]
spark.udf.register("expensive", udf((x: Int) => { Thread.sleep(1); 1 }))
[/code]

And assume this function needs to be executed many times for a small set of arguments:

[code lang="scala"]
spark.range(50000).toDF()
  .withColumn("parent_id", col("id").mod(100))
  .repartition(col("parent_id"))
  .createTempView("myTable")

spark.sql("select id, expensive(parent_id) as one from myTable")
[/code]

Let's run some tests. I modified the function to increment an accumulator that counts its invocations, and ran the test on a small Dataproc cluster (Spark 2.2.0):

[code lang="scala"]
scala> val invocations = spark.sparkContext.longAccumulator("invocations")
invocations: org.apache.spark.util.LongAccumulator = LongAccumulator(id: 0, name: Some(invocations), value: 0)

scala> def timing[T](body: => T): T = {
     |   val t0 = System.nanoTime()
     |   invocations.reset()
     |   val res = body
     |   val t1 = System.nanoTime()
     |   println(s"invocations=${invocations.value}, time=${(t1 - t0) / 1e9}")
     |   res
     | }
timing: [T](body: => T)T

scala> def expensive(n: Int) = {
     |   Thread.sleep(1)
     |   invocations.add(1)
     |   1
     | }
expensive: (n: Int)Int

scala> spark.udf.register("expensive", udf((x: Int) => expensive(x)))
res0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

scala> spark.range(50000).toDF().
     |   withColumn("parent_id", col("id").mod(100)).
     |   createTempView("myTable")

scala> spark.sql("select id, expensive(parent_id) as one from myTable").createTempView("expensive_table")

scala> timing(spark.sql("select sum(one) from expensive_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+

invocations=50000, time=9.493999374
[/code]

The expensive function was called 50000 times, once per row, and the job took around 10 seconds.

Simple memoization

How can we improve this timing? We can memoize the function's results with the following simple code:

[code lang="scala"]
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}

def memo[T, U](f: T => U): T => U = {
  // Captured by the returned closure and created on first use,
  // so every UDF instance gets its own cache
  lazy val cache = new ConcurrentHashMap[T, U]()
  (t: T) => cache.computeIfAbsent(t, new JF[T, U] { def apply(key: T): U = f(key) })
}

spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))
[/code]

The cache is a lazy val captured by the closure, so it is only created when the UDF is first called, and there is one cache per UDF instance. Let's run more tests! (In the REPL session below, an implicit conversion `toJF` replaces the anonymous `JF` class.)

[code lang="scala"]
scala> import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentHashMap

scala> import java.util.function.{Function => JF}
import java.util.function.{Function=>JF}

scala> implicit def toJF[T, U](f: T => U): JF[T, U] = new JF[T, U] {
     |   def apply(t: T): U = f(t)
     | }
warning: there was one feature warning; re-run with -feature for details
toJF: [T, U](f: T => U)java.util.function.Function[T,U]

scala> def memo[T, U](f: T => U): T => U = {
     |   lazy val cache = new ConcurrentHashMap[T, U]()
     |   (t: T) => cache.computeIfAbsent(t, f)
     | }
memo: [T, U](f: T => U)T => U

scala> spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))
res4: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

scala> spark.sql("select id, memoized(parent_id) as one from myTable").createTempView("memoized_table")

scala> timing(spark.sql("select sum(one) from memoized_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+

invocations=600, time=0.378553505
[/code]

This time the function was invoked just 600 times instead of 50000 (600 rather than 100, presumably because several UDF instances each fill their own cache), and the query took 0.4 seconds instead of almost 10!
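The same idea can be checked without a cluster. The sketch below is plain Scala with no Spark involved, assuming 50000 inputs drawn from 100 distinct keys as in the example above; `MemoDemo`, `slowOne` and `calls` are hypothetical names for illustration. Because everything runs through a single cache in one JVM, the wrapped function should run once per distinct key (100 times), and a second call to `memo` starts over with an empty cache.

[code lang="scala"]
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import java.util.function.{Function => JF}

object MemoDemo {
  // Same shape as the memo above: the lazy cache is captured by the returned
  // closure, so every call to memo(...) yields a function with its own cache.
  def memo[T, U](f: T => U): T => U = {
    lazy val cache = new ConcurrentHashMap[T, U]()
    (t: T) => cache.computeIfAbsent(t, new JF[T, U] {
      def apply(key: T): U = f(key)
    })
  }

  // Hypothetical stand-in for `expensive`: counts how often it really runs.
  val calls = new AtomicLong(0)
  def slowOne(n: Int): Int = {
    calls.incrementAndGet()
    1
  }

  def main(args: Array[String]): Unit = {
    val inputs = (0 until 50000).map(_ % 100) // 100 distinct keys, like parent_id

    val memoized = memo((x: Int) => slowOne(x))
    val total = inputs.map(memoized).sum
    println(s"sum=$total, calls=${calls.get}") // sum=50000, calls=100

    // A second wrapper has its own empty cache and recomputes every key once.
    val another = memo((x: Int) => slowOne(x))
    inputs.foreach(another)
    println(s"calls after second wrapper=${calls.get}") // 200
  }
}
[/code]

The per-wrapper cache in this model is the same effect that, on a cluster, makes the count grow with the number of UDF instances rather than stay at the number of distinct arguments.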

Filtering and UDFs

Another pain point is filtering on a UDF result. Let's modify our example a little and add a filter. How many times will the UDF be invoked here?

[code lang="scala"]
scala> timing(spark.sql("select sum(one) from expensive_table where one * one = one").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+

invocations=200000, time=37.449562222
[/code]

It was executed four times for each input row, 200000 in total, and took 37 seconds! That's too much. This is most likely because the optimizer pushes the filter below the projection and inlines the aliased UDF call for every reference to `one`: the predicate `one * one = one` evaluates `expensive(parent_id)` three times, and the projection evaluates it once more. What if we use the memoized version?

[code lang="scala"]
scala> timing(spark.sql("select sum(one) from memoized_table where one * one = one").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+

invocations=600, time=0.34141222
[/code]

This time the filter makes no difference: we get the same 600 invocations and roughly the same time as the unfiltered memoized query.
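To see where the factor of four comes from, here is a plain-Scala model (not Spark code) of the plan after the UDF call has been inlined into the filter. The inlining itself is an assumption about how Spark 2.2 rewrites this query; `FilterModel` and `run` are hypothetical names. The model uses a single cache, so the memoized run reports 100 invocations rather than the 600 measured on the cluster.

[code lang="scala"]
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}

object FilterModel {
  def memo[T, U](f: T => U): T => U = {
    lazy val cache = new ConcurrentHashMap[T, U]()
    (t: T) => cache.computeIfAbsent(t, new JF[T, U] { def apply(k: T): U = f(k) })
  }

  // Models Filter(udf(x) * udf(x) == udf(x)) followed by Project(udf(x)).
  // Returns (sum of the projected column, number of real invocations).
  def run(rows: Seq[Int], useMemo: Boolean): (Long, Long) = {
    var invocations = 0L
    val raw: Int => Int = _ => { invocations += 1; 1 }
    val udf: Int => Int = if (useMemo) memo(raw) else raw
    val sum = rows
      .filter(x => udf(x) * udf(x) == udf(x)) // three calls per row in the predicate
      .map(x => udf(x).toLong)                // one more call per row in the projection
      .sum
    (sum, invocations)
  }

  def main(args: Array[String]): Unit = {
    val rows = (0 until 50000).map(_ % 100)
    println(run(rows, useMemo = false)) // (50000,200000)
    println(run(rows, useMemo = true))  // (50000,100)
  }
}
[/code]

With memoization, the number of references to the UDF in the plan no longer matters: every repeated call for a key already in the cache is a map lookup.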

Next steps

This is just a simple example and is not meant for production use. Before using it in production you would have to think about possible problems such as:
  • ConcurrentHashMap can't store null keys or values, so you would need to wrap them, for example into an Option;
  • What if there are not 100 but 100M different arguments? How should we limit the cache size, and which eviction strategy should we use: keep last, keep first, most used, LRU, ...?
  • What if an invocation of the function takes a long, unpredictable time or hangs? While computeIfAbsent is computing a key, other threads requesting the same key (and possibly other keys in the same hash bin) are blocked, even if a fresh invocation of the function would return instantly. Should the cache be bypassed after some reasonable timeout, or should you use an optimistic locking strategy?
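As one possible direction for these points, here is a sketch only, not a tested production design: a size-bounded LRU cache built on `java.util.LinkedHashMap` in access order and guarded by a lock, storing results as `Option` so that null results can be cached too (LinkedHashMap also accepts a null key). `BoundedMemo`, `boundedMemo` and `maxEntries` are hypothetical names. Unlike `computeIfAbsent`, the function runs outside the lock, so a slow or hanging call does not block lookups of other keys; the trade-off is that two threads may occasionally compute the same key twice. Timeouts are not handled.

[code lang="scala"]
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

object BoundedMemo {
  /** Hypothetical helper: memoizes f in an LRU cache holding at most maxEntries results. */
  def boundedMemo[T, U](maxEntries: Int)(f: T => U): T => U = {
    // accessOrder = true makes iteration order least-recently-used first
    lazy val cache = new JLinkedHashMap[T, Option[U]](16, 0.75f, true) {
      override def removeEldestEntry(eldest: JMap.Entry[T, Option[U]]): Boolean =
        size() > maxEntries // evict the LRU entry once the bound is exceeded
    }
    (t: T) => {
      // get() reorders an access-ordered map, so it must also run under the lock
      val hit = cache.synchronized(cache.get(t))
      if (hit != null) hit.getOrElse(null.asInstanceOf[U]) // None means a cached null
      else {
        val value = f(t) // computed outside the lock
        cache.synchronized(cache.put(t, Option(value)))
        value
      }
    }
  }

  def main(args: Array[String]): Unit = {
    var calls = 0
    val lookup = boundedMemo[Int, String](maxEntries = 2) { n =>
      calls += 1
      if (n < 0) null else s"value-$n" // null results are cached as None
    }
    Seq(1, 2, 1, 3, 2, -1, -1).foreach(lookup)
    println(s"calls=$calls") // calls=5
  }
}
[/code]

Which eviction policy fits best depends on the key distribution; LRU is used here only because LinkedHashMap supports it directly.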
