Introduction
Spark ML's built-in Transformers work out of the box in most cases. However, serializing a trained model that contains custom transformers is a little painful, because the process is poorly documented.
The following sections explain how to create a serializable custom Spark ML transformer.
Why
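A custom Transformer does not implement MLWritable or MLReadable unless you add them. Without them it has no way to write itself to disk or read itself back. That is why saving or loading a Pipeline (or a fitted PipelineModel) that contains a custom transformer fails.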
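One way to find the stages that cause this error is to check each stage for the MLWritable trait. In the Spark versions I am familiar with, calling `write` on a Pipeline or PipelineModel throws an UnsupportedOperationException as soon as one stage is not MLWritable. The sketch below is a hypothetical helper. `WritableCheck` and `PlainTransformer` are illustrative names, not Spark APIs.

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{Identifiable, MLWritable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical custom transformer that does NOT mix in DefaultParamsWritable,
// so it is not an MLWritable and would break Pipeline.write.
class PlainTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("PlainTransformer"))
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): PlainTransformer = defaultCopy(extra)
}

// Hypothetical helper (not part of Spark): describes every stage that is not MLWritable.
object WritableCheck {
  def nonWritableStages(pipeline: Pipeline): Seq[String] =
    pipeline.getStages.toSeq.collect {
      case stage if !stage.isInstanceOf[MLWritable] =>
        s"${stage.getClass.getSimpleName} (uid=${stage.uid})"
    }

  def main(args: Array[String]): Unit = {
    // Building a Pipeline does not need a SparkSession.
    val pipeline = new Pipeline().setStages(Array[PipelineStage](
      new StringIndexer().setInputCol("color").setOutputCol("colorIndex"), // built-in, writable
      new PlainTransformer()                                               // custom, not writable
    ))
    WritableCheck.nonWritableStages(pipeline).foreach(println)
  }
}
```

Running this check before `save` reports the offending stage by its uid. Otherwise the failure only appears at write time.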
How
To save and load a trained model correctly, every custom transformer in it must be both writable and readable.
The simplest approach is to mix DefaultParamsWritable into the custom transformer class and to make its companion object extend DefaultParamsReadable.
```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ Param, ParamMap }
import org.apache.spark.ml.util.{ DefaultParamsReadable, DefaultParamsWritable, Identifiable }
import org.apache.spark.sql.{ DataFrame, Dataset }
import org.apache.spark.sql.types.StructType

// Keep this a top-level class with a (uid: String) constructor:
// the reader re-creates it by reflection from the saved uid.
class ColRenameTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("ColRenameTransformer"))

  // Params must be declared before setDefault refers to them.
  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")
  setDefault(inputCol -> "input", outputCol -> "output")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol: String = $(outputCol)

  // Drops any existing output column, then renames the input column to it.
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.drop($(outputCol)).withColumnRenamed($(inputCol), $(outputCol))

  // Report the renamed schema so that Pipeline validation sees the new column name.
  override def transformSchema(schema: StructType): StructType = {
    val (inCol, outCol) = ($(inputCol), $(outputCol))
    require(schema.fieldNames.contains(inCol), s"Input column $inCol does not exist.")
    StructType(schema.fields.filterNot(_.name == outCol).map { f =>
      if (f.name == inCol) f.copy(name = outCol) else f
    })
  }

  override def copy(extra: ParamMap): ColRenameTransformer = defaultCopy(extra)
}

// The companion object provides `read`/`load`; PipelineModel.load looks it up by reflection.
object ColRenameTransformer extends DefaultParamsReadable[ColRenameTransformer] {
  override def load(path: String): ColRenameTransformer = super.load(path)
}
```
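DefaultParamsWritable stores only a small JSON metadata record: the class name, the uid and the Param values. That is enough for a stateless transformer like this one. The sketch below writes a single configured stage, prints that metadata and loads the stage back. It assumes ColRenameTransformer and its companion object are compiled into the same project. `StageMetadataDemo` and the column names are hypothetical. The `metadata` subdirectory reflects how DefaultParamsWriter lays out its files in the Spark versions I know of, so treat it as an implementation detail.

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession

// Assumes ColRenameTransformer and its companion object are on the classpath.
// StageMetadataDemo is a hypothetical name used only for this example.
object StageMetadataDemo {
  /** Saves one configured stage, reloads it, and returns (reloaded stage, metadata JSON). */
  def saveAndInspect(spark: SparkSession, dir: String): (ColRenameTransformer, String) = {
    val stage = new ColRenameTransformer()
      .setInputCol("raw_name")
      .setOutputCol("name")

    // `write` comes from DefaultParamsWritable; overwrite() replaces an existing directory.
    stage.write.overwrite().save(dir)

    // Assumption: DefaultParamsWriter stores the metadata as text files under <dir>/metadata.
    val json = spark.sparkContext.textFile(s"$dir/metadata").collect().mkString("\n")

    // `load` comes from DefaultParamsReadable on the companion object.
    (ColRenameTransformer.load(dir), json)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("stage-metadata").getOrCreate()
    try {
      val dir = Files.createTempDirectory("col-rename").resolve("stage").toString
      val (loaded, json) = saveAndInspect(spark, dir)
      println(json)
      println(s"reloaded outputCol = ${loaded.getOutputCol}")
    } finally spark.stop()
  }
}
```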
To use the custom transformer, configure it through its setters:
```scala
new ColRenameTransformer()
  .setInputCol(...)
  .setOutputCol(...)
```
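For a concrete picture of what the stage does on its own, here is a small local run. It assumes ColRenameTransformer is on the classpath. `ColRenameUsage` and the column names `raw_name` and `name` are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

// Assumes ColRenameTransformer is compiled into the same project.
// ColRenameUsage and the column names are hypothetical.
object ColRenameUsage {
  def renamedColumns(spark: SparkSession): Seq[String] = {
    import spark.implicits._
    val df = Seq((1, "a"), (2, "b")).toDF("id", "raw_name")
    val renamer = new ColRenameTransformer()
      .setInputCol("raw_name") // existing column to rename
      .setOutputCol("name")    // new name; an existing "name" column would be dropped first
    renamer.transform(df).columns.toSeq
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("col-rename-usage").getOrCreate()
    try println(renamedColumns(spark).mkString(", ")) // prints: id, name
    finally spark.stop()
  }
}
```

withColumnRenamed keeps the column's position, so the renamed column stays where `raw_name` was.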
Now ColRenameTransformer works with model.save(...) and PipelineModel.load(...), where model is a fitted PipelineModel.
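The sketch below runs the full round trip: it fits a pipeline that mixes a built-in stage with the custom one, saves the fitted model, loads it back and transforms the same data. It assumes ColRenameTransformer and its companion object are compiled into the same project. `PipelineRoundTrip`, `fitSaveLoad` and the column names are hypothetical. StringIndexer is used only as an example of a built-in stage.

```scala
import java.nio.file.Files
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.{DataFrame, SparkSession}

// Assumes ColRenameTransformer and its companion object are on the classpath.
// PipelineRoundTrip, fitSaveLoad and the column names are hypothetical.
object PipelineRoundTrip {
  def fitSaveLoad(spark: SparkSession, dir: String): DataFrame = {
    import spark.implicits._
    val df = Seq((1, "red"), (2, "blue"), (3, "red")).toDF("id", "color")

    val pipeline = new Pipeline().setStages(Array[PipelineStage](
      new StringIndexer().setInputCol("color").setOutputCol("colorIndex"),
      new ColRenameTransformer().setInputCol("colorIndex").setOutputCol("label")
    ))
    val model: PipelineModel = pipeline.fit(df)

    // Writing fails if any stage is not MLWritable; the custom stage now is.
    model.write.overwrite().save(dir)

    // Each stage is re-created from its saved metadata; the custom stage is
    // found through its companion object's `read`.
    val reloaded = PipelineModel.load(dir)
    reloaded.transform(df)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("pipeline-round-trip").getOrCreate()
    try {
      val dir = Files.createTempDirectory("pipeline-model").resolve("model").toString
      fitSaveLoad(spark, dir).show()
    } finally spark.stop()
  }
}
```

If the companion object is missing, writing still succeeds but loading fails. Saving and then immediately reloading, as above, catches that early.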