Spark Scala UDF primitive type bug

I was working on an instrumentation framework for Scala UDFs in Spark when I noticed a subtle difference in the execution plan depending on whether or not I used wrappers. It looked as if a null-check predicate was added in one case but not in the other:
[code lang="scala"]
val f = (x: Long) => x
val udf0 = udf(f)
...
.withColumn("udf0", udf0(...))
...
// in explain
if (isnull(...)) null else UDF(...) AS udf0#111L
[/code]
vs
[code lang="scala"]
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf1 = udf(identity(f))
...
.withColumn("udf1", udf1(...))
...
// in explain
UDF(...) AS udf1#115L
[/code]
A quick look at the docs sheds light on the special case of UDFs based on functions with primitive input arguments:
Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.
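The quoted sentence is easy to check without Spark. The sketch below uses only the Scala standard library and assumes, as a model of how Spark calls the function, that it is invoked through an erased `Any => Any` with boxed arguments; `PrimitiveNullDemo` and `callErased` are hypothetical names. A primitive `Long` parameter has no null value, so a null that reaches it is silently unboxed to `0`, which matches the `0` that an unguarded UDF such as udf1 returns.

```scala
object PrimitiveNullDemo {
  // Same function as in the article: its parameter is a primitive Long.
  val f: Long => Long = (x: Long) => x

  // Same generic wrapper as in the article; T is erased to Object at run time.
  def identity[T, U](f: T => U): T => U = (t: T) => f(t)

  // Hypothetical helper: call a function through the erased Function1 interface
  // with a boxed argument (the cast is unchecked and never throws).
  def callErased(fn: AnyRef, arg: Any): Any = fn.asInstanceOf[Any => Any].apply(arg)

  def main(args: Array[String]): Unit = {
    // Scala unboxes null to the primitive default value.
    println(null.asInstanceOf[Long])       // 0
    // Both the plain function and the wrapped one turn a null input into 0.
    println(callErased(f, null))           // 0
    println(callErased(identity(f), null)) // 0
  }
}
```

In other words, once a null gets past Spark into a function with a primitive parameter, the function cannot tell it apart from a real `0`.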
In my case I had not really changed any types; I had only used a higher-order function, something like this:
[code lang="scala"]
val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf0 = udf(f)
val udf1 = udf(identity(f))
[/code]
At first sight, udf0 and udf1 look pretty much the same:
[code lang="scala"]
scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
identity: [T, U](f: T => U)T => U

scala> val udf0 = udf(f)
udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> val udf1 = udf(identity(f))
udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
[/code]
During execution, however, they behave differently for null input:
[code lang="scala"]
scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   show()
+---+----+----+
| id|udf0|udf1|
+---+----+----+
|  0|null|   0|
|  1|null|   0|
|  2|null|   0|
|  3|null|   0|
|  4|null|   0|
+---+----+----+

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   explain()
== Physical Plan ==
*Project [id#106L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#111L, UDF(UDF()) AS udf1#115L]
+- *Range (0, 5, step=1, splits=2)
[/code]
I tracked down why this happens in the Spark sources:
    • udf:
[code lang="scala"]
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
  val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption
  UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
}
[/code]
    • UserDefinedFunction:
[code lang="scala"]
case class UserDefinedFunction protected[sql] (
    f: AnyRef,
    dataType: DataType,
    inputTypes: Option[Seq[DataType]]) {
  ...
  def apply(exprs: Column*): Column = {
    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
  }
}
[/code]
    • ScalaUDF:
[code lang="scala"]
case class ScalaUDF(
    function: AnyRef,
    dataType: DataType,
    children: Seq[Expression],
    inputTypes: Seq[DataType] = Nil,
    udfName: Option[String] = None)
  extends Expression with ImplicitCastInputTypes with NonSQLExpression {
  ...
[/code]
    • HandleNullInputsForUDF from the Catalyst Analyzer (the TODO in this piece explains the mess with nullability: the rule simply doesn't work where I would expect it to):
[code lang="scala"]
/**
 * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
 * null check. When user defines a UDF with primitive parameters, there is no way to tell if the
 * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
 * and we should return null if the input is null.
 */
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case p if !p.resolved => p // Skip unresolved nodes.

    case p => p transformExpressionsUp {
      case udf @ ScalaUDF(func, _, inputs, _, _) =>
        val parameterTypes = ScalaReflection.getParameterTypes(func)
        assert(parameterTypes.length == inputs.length)

        val inputsNullCheck = parameterTypes.zip(inputs)
          // TODO: skip null handling for not-nullable primitive inputs after we can completely
          // trust the `nullable` information.
          // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
          .filter { case (cls, _) => cls.isPrimitive }
          .map { case (_, expr) => IsNull(expr) }
          .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
        inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
    }
  }
}
[/code]
    • And the final piece, ScalaReflection.getParameterTypes:
[code lang="scala"]
def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
  val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
  assert(methods.length == 1)
  methods.head.getParameterTypes
}
[/code]
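The rule above boils down to one decision per parameter: if its class is primitive, wrap the UDF in a null check. Here is a Spark-free model of that decision, written in plain Scala. `NullCheckRuleModel` and `withNullCheck` are hypothetical names, the parameter classes are passed in by hand as a stand-in for what `getParameterTypes` returns, and calling the body through `Any => Any` is an assumption about how the UDF is invoked.

```scala
object NullCheckRuleModel {
  // Hypothetical model of the HandleNullInputsForUDF rewrite.
  // `paramClasses` stands in for the classes reported by getParameterTypes;
  // `body` is the UDF called through the erased interface.
  // If any primitive-typed parameter receives null, return null
  // (the `if (isnull(...)) null else UDF(...)` shape); otherwise call the body.
  def withNullCheck(paramClasses: Seq[Class[_]], body: Seq[Any] => Any): Seq[Any] => Any = {
    val primitivePositions = paramClasses.zipWithIndex.collect {
      case (cls, i) if cls.isPrimitive => i
    }
    if (primitivePositions.isEmpty) body // no guard added: the plain `UDF(...)` shape
    else { (row: Seq[Any]) =>
      if (primitivePositions.exists(i => row(i) == null)) null
      else body(row)
    }
  }

  val f: Long => Long = (x: Long) => x

  // The UDF body as a single boxed argument passed through Any => Any.
  val erasedBody: Seq[Any] => Any = row => f.asInstanceOf[Any => Any].apply(row.head)

  def main(args: Array[String]): Unit = {
    // classOf[Long] is the primitive `long`, the class reported for f.
    val guarded = withNullCheck(Seq(classOf[Long]), erasedBody)
    // classOf[Object] is the class reported for the erased identity(f) wrapper.
    val unguarded = withNullCheck(Seq(classOf[Object]), erasedBody)
    println(guarded(Seq(null)))   // null
    println(unguarded(Seq(null))) // 0
  }
}
```

The same function body yields `null` or `0` depending only on the class passed in, so whether the guard appears depends entirely on what getParameterTypes reports for the parameter.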
As you can see, getParameterTypes relies on Java runtime class information, so it's no surprise that isPrimitive doesn't work the way we would expect: because of type erasure, the apply method of the generic wrapper takes an Object. In this case:
[code lang="scala"]
scala> ScalaReflection.getParameterTypes(f)
res1: Seq[Class[_]] = WrappedArray(long)

scala> ScalaReflection.getParameterTypes(identity(f))
res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
[/code]
Instead, it should use the TypeTag we already have in the udf declaration, like this:
[code lang="scala"]
scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
     |   typeTag[T].tpe.typeSymbol.asClass
     | }
myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol

scala> myGetParameterTypes(f)
res3: reflect.runtime.universe.ClassSymbol = class Long

scala> myGetParameterTypes(f).isPrimitive
res4: Boolean = true
[/code]
The workaround is quite ugly, though: it is to use specialization.
[code lang="scala"]
scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
identity2: [T, U](f: T => U)T => U

scala> val udf2 = udf(identity2(f))
udf2: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> ScalaReflection.getParameterTypes(identity2(f))
res10: Seq[Class[_]] = WrappedArray(long)
[/code]
As a result, I submitted Spark JIRA issue SPARK-23833.

Be careful when using UDFs that operate on primitive types if nullable data can be passed to them. There are many scenarios in which the behavior may differ. As a rule: if nullable data can be passed in, use boxed types or Option.
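The TypeTag-based check suggested above can be sketched without Spark. This sketch swaps in `ClassTag` (from the Scala standard library) for `TypeTag` to avoid a scala-reflect dependency; that swap is an assumption of the sketch, not Spark's code, and `StaticPrimitiveCheck` and `hasPrimitiveInput` are hypothetical names. Because the tag is resolved from the static type at the call site, the erased wrapper no longer hides the primitive parameter, while a boxed `java.lang.Long` parameter is correctly reported as non-primitive.

```scala
import scala.reflect.{ClassTag, classTag}

object StaticPrimitiveCheck {
  // Decide "is the input primitive?" from the static type A known at the call site,
  // the way udf[RT: TypeTag, A1: TypeTag] already knows A1, rather than from the
  // compiled apply method. ClassTag stands in for TypeTag in this sketch.
  def hasPrimitiveInput[A: ClassTag, R](fn: A => R): Boolean =
    classTag[A].runtimeClass.isPrimitive

  // Same generic wrapper as in the article.
  def identity[T, U](f: T => U): T => U = (t: T) => f(t)

  def main(args: Array[String]): Unit = {
    val f = (x: Long) => x
    println(hasPrimitiveInput(f))                        // true
    println(hasPrimitiveInput(identity(f)))              // true: A is still inferred as Long
    println(hasPrimitiveInput((x: java.lang.Long) => x)) // false: boxed parameter
  }
}
```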
