
Serializable Custom Transformer with Spark 2.0 (Scala)

25 Jun 2019

Reading time ~1 minute

Introduction

Spark's built-in Transformers work out of the box in most cases. However, serializing a trained model that contains custom transformers is a bit painful, as the process is poorly documented.

The following sections explain how to create a serializable custom Spark ML transformer.

Why

By default, a custom Transformer does not implement the read/write interfaces that Spark ML persistence relies on. That is why saving or loading a Pipeline that contains custom transformers fails with errors.
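As a rough sketch of the failure mode (the transformer class, DataFrame, and path below are placeholders, not from this post):

import org.apache.spark.ml.Pipeline

// `SomeCustomTransformer` stands for any hand-rolled Transformer that does not
// mix in DefaultParamsWritable; `trainingData` is an arbitrary DataFrame.
val pipeline = new Pipeline().setStages(Array(new SomeCustomTransformer()))
val model = pipeline.fit(trainingData)

// Fails: Pipeline persistence requires every stage to implement MLWritable,
// so this throws instead of writing the model to disk.
model.save("/tmp/custom-transformer-model")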

How

To save and load the trained model properly, the custom transformers must be made writable and readable.

The workaround is to have the custom transformer class mix in DefaultParamsWritable and its companion object mix in DefaultParamsReadable.

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ Param, ParamMap }
import org.apache.spark.ml.util.{ DefaultParamsReadable, DefaultParamsWritable, Identifiable }
import org.apache.spark.sql.{ DataFrame, Dataset }
import org.apache.spark.sql.types.StructType

// Mixing in DefaultParamsWritable gives the transformer a default MLWriter,
// so it can be saved as part of a Pipeline / PipelineModel.
class ColRenameTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("ColRenameTransformer"))
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol: String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: Dataset[_]): DataFrame = {
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")

    // Drop any pre-existing column with the target name, then rename input -> output.
    dataset.drop(outCol).withColumnRenamed(inCol, outCol)
  }

  override def copy(extra: ParamMap): ColRenameTransformer = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = schema
}

// The companion object mixes in DefaultParamsReadable, which provides the
// matching MLReader used when the persisted pipeline is loaded back.
object ColRenameTransformer extends DefaultParamsReadable[ColRenameTransformer] {
  override def load(path: String): ColRenameTransformer = super.load(path)
}

To apply the custom transformer:

new ColRenameTransformer()
  .setInputCol(...)
  .setOutputCol(...)

Now the custom transformer ColRenameTransformer works smoothly with model.save(...) and PipelineModel.load(...).
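For example, here is a minimal end-to-end sketch (the column names, toy data, and save path are illustrative):

import org.apache.spark.ml.{ Pipeline, PipelineModel }
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("col-rename-demo").getOrCreate()
import spark.implicits._

// Toy DataFrame with a column to rename.
val df = Seq((1, "a"), (2, "b")).toDF("raw", "label")

val renamer = new ColRenameTransformer()
  .setInputCol("raw")
  .setOutputCol("renamed")

// Fit, persist, and reload the whole pipeline, custom stage included.
val model = new Pipeline().setStages(Array(renamer)).fit(df)
model.write.overwrite().save("/tmp/col-rename-model")

val restored = PipelineModel.load("/tmp/col-rename-model")
restored.transform(df).show()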



Tags: spark, scala, machine-learning, pipeline, transformer, serialization