Spark Scala UDF primitive type bug
I was working on an instrumentation framework for Scala UDFs in Spark when I noticed a subtle difference in the execution plan depending on whether I used wrappers or not. It looked as if a predicate checking for nulls was added in one case but not in the other:

[code lang="scala"]
val f = (x: Long) => x
val udf0 = udf(f)
...
.withColumn("udf0", udf0(...))
...
// in explain
if (isnull(...)) null else UDF(...) AS udf0#111L
[/code]

vs

[code lang="scala"]
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf1 = udf(identity(f))
...
.withColumn("udf1", udf1(...))
...
// in explain
UDF(...) AS udf1#115L
[/code]

A quick look at the docs sheds light on the special case of UDFs based on functions with primitive input arguments:
Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.

In my case I had not really changed any types. I had only used a higher-order function, something like this:

[code lang="scala"]
val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf0 = udf(f)
val udf1 = udf(identity(f))
[/code]

At first sight, udf0 and udf1 look pretty much the same:

[code lang="scala"]
scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
identity: [T, U](f: T => U)T => U

scala> val udf0 = udf(f)
udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> val udf1 = udf(identity(f))
udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
[/code]

During execution, however, they behave differently for null input:

[code lang="scala"]
scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   show()
+---+----+----+
| id|udf0|udf1|
+---+----+----+
|  0|null|   0|
|  1|null|   0|
|  2|null|   0|
|  3|null|   0|
|  4|null|   0|
+---+----+----+

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   explain()
== Physical Plan ==
*Project [id#106L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#111L, UDF(UDF()) AS udf1#115L]
+- *Range (0, 5, step=1, splits=2)
[/code]

I tracked down why this happens in the Spark sources:
- udf
[code lang="scala"]
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
  val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption
  UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
}
[/code]
- UserDefinedFunction
[code lang="scala"]
case class UserDefinedFunction protected[sql] (
    f: AnyRef,
    dataType: DataType,
    inputTypes: Option[Seq[DataType]]) {
  ...
  def apply(exprs: Column*): Column = {
    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
  }
}
[/code]
- ScalaUDF
[code lang="scala"]
case class ScalaUDF(
    function: AnyRef,
    dataType: DataType,
    children: Seq[Expression],
    inputTypes: Seq[DataType] = Nil,
    udfName: Option[String] = None)
  extends Expression with ImplicitCastInputTypes with NonSQLExpression {
  ...
[/code]
- HandleNullInputsForUDF from the Catalyst Analyzer (the TODO in this piece explains the mess with nullability: it simply doesn't work where I would expect it to):
[code lang="scala"]
/**
 * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
 * null check. When user defines a UDF with primitive parameters, there is no way to tell if the
 * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
 * and we should return null if the input is null.
 */
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case p if !p.resolved => p // Skip unresolved nodes.

    case p => p transformExpressionsUp {

      case udf @ ScalaUDF(func, _, inputs, _, _) =>
        val parameterTypes = ScalaReflection.getParameterTypes(func)
        assert(parameterTypes.length == inputs.length)

        val inputsNullCheck = parameterTypes.zip(inputs)
          // TODO: skip null handling for not-nullable primitive inputs after we can completely
          // trust the `nullable` information.
          // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
          .filter { case (cls, _) => cls.isPrimitive }
          .map { case (_, expr) => IsNull(expr) }
          .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
        inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
    }
  }
}
[/code]
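- And the final piece, ScalaReflection.getParameterTypes:
[code lang="scala"]
def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
  val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
  assert(methods.length == 1)
  methods.head.getParameterTypes
}
[/code]

Putting it together: the null check is decided by reflecting on the function object's non-bridge apply method, not by the TypeTag-based inputTypes, which are identical for both UDFs. For f that method takes a primitive long, so the analyzer wraps the call in if (isnull(...)) null. identity(f) returns a generic closure whose apply takes an Object after erasure. As a result no check is added, the null is passed straight to f, and unboxing turns it into 0.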
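Both UserDefinedFunction values print the same input types because udf() works from the static type of its argument, and statically identity(f) is just another Long => Long. The following minimal sketch shows this without Spark. It uses scala.reflect.ClassTag as a stand-in for the TypeTag that udf() actually takes, and the helper inputClass is a hypothetical name, not a Spark API.

[code lang="scala"]
import scala.reflect.ClassTag

object StaticInputTypes {
  // Same definitions as in the post.
  val f = (x: Long) => x
  def identity[T, U](f: T => U): T => U = (t: T) => f(t)

  // Hypothetical helper: like udf(), it only sees the static parameter type A.
  // ClassTag is used here as a stand-in for the TypeTag that Spark's udf() takes.
  def inputClass[A, R](g: A => R)(implicit ct: ClassTag[A]): Class[_] = ct.runtimeClass

  def main(args: Array[String]): Unit = {
    println(inputClass(f))           // long
    println(inputClass(identity(f))) // long
  }
}
[/code]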
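The 0 in the udf1 column comes from Scala itself. A Long => Long has no way to represent a missing value, so a raw null that reaches it through an erased call is unboxed to 0L. The sketch below assumes that Spark invokes the UDF through an untyped Any => Any, which the hypothetical helper callErased models with a cast. The f and identity definitions are the ones from the post.

[code lang="scala"]
object NullToZero {
  // Same definitions as in the post.
  val f = (x: Long) => x
  def identity[T, U](f: T => U): T => U = (t: T) => f(t)

  // Hypothetical helper (assumption): models Spark calling the UDF as an
  // untyped Any => Any, so the argument goes through the erased apply(Object).
  def callErased(g: Long => Long, input: Any): Any =
    g.asInstanceOf[Any => Any](input)

  def main(args: Array[String]): Unit = {
    val missing: java.lang.Long = null
    // Neither function can observe the null: unboxing it yields 0L.
    println(callErased(f, missing))           // 0
    println(callErased(identity(f), missing)) // 0
  }
}
[/code]

Without the If wrapper that the analyzer adds, both functions behave identically. Only that wrapper makes udf0 return null.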
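The null check can also be modelled in plain Scala. In this sketch, parameterTypes copies the body of getParameterTypes shown above, with require in place of assert and a List result. evalWithNullCheck mimics the If(IsNull(input), null, udf) wrapper, but only for a single argument. To avoid depending on how a particular Scala version encodes lambdas, the sketch uses two hand-written, hypothetical classes. PrimitiveIdentity has an apply that takes a primitive long, and GenericIdentity is a class version of the generic identity wrapper.

[code lang="scala"]
object NullCheckModel {
  // Mirrors ScalaReflection.getParameterTypes: the single public, non-bridge `apply`.
  def parameterTypes(func: AnyRef): List[Class[_]] = {
    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
    require(methods.length == 1, s"expected one apply method, found ${methods.length}")
    methods.head.getParameterTypes.toList
  }

  // Hypothetical, single-argument model of If(IsNull(input), null, udf):
  // the check is added only when the reflected parameter type is primitive.
  def evalWithNullCheck(func: AnyRef, input: Any): Any = {
    val needsCheck = parameterTypes(func).exists(_.isPrimitive)
    if (needsCheck && input == null) null
    else func.asInstanceOf[Any => Any](input)
  }

  // Hypothetical: declares apply(x: Long); the compiler adds an erased bridge for apply(Object).
  class PrimitiveIdentity extends (Long => Long) {
    def apply(x: Long): Long = x
  }

  // Hypothetical: generic wrapper whose only non-bridge apply takes an Object after erasure.
  class GenericIdentity[T, U](f: T => U) extends (T => U) {
    def apply(t: T): U = f(t)
  }

  def main(args: Array[String]): Unit = {
    val plain = new PrimitiveIdentity
    val wrapped = new GenericIdentity[Long, Long](plain)
    println(parameterTypes(plain))            // List(long)
    println(parameterTypes(wrapped))          // List(class java.lang.Object)
    println(evalWithNullCheck(plain, null))   // null, like udf0
    println(evalWithNullCheck(wrapped, null)) // 0, like udf1
  }
}
[/code]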