To transform a dataframe with a column containing a json string into a typed dataframe, we need to know the exact schema of our json string. Ideally, we would infer this schema from the json string column with the schema_of_json function, but this feature was scrapped because the added complexity was not considered worth it (see SPARK-24642). Here we present a method to infer a global schema from a column containing different json strings, using a user-defined aggregate function.

Code in this blog post was developed using Apache Spark 3.1.2.

Problem Statement

We have the following input dataframe, with a column named value containing a json string:

+--------------------------------+
|value                           |
+--------------------------------+
|{"a": 1, "b": "value1"}         |
|{"b": "value2", "c": [1, 2, 3]} |
|{"a": 3, "d": {"d1" : "value3"}}|
|{"d": {"d2" : 4}}               |
+--------------------------------+

We want to transform this json string into a structured column so that we can, for instance, easily select nested fields. However, we don’t know exactly which json schema to apply to this column. How can we retrieve this schema from the different values of our json string column?

To solve this, we use a user-defined aggregate function that aggregates the schemas of all json values in our dataframe. To create this user-defined aggregate function, we implement a JsonSchemaAggregator Aggregator, then transform it into a user-defined aggregate function using the udaf function.

Create a JsonSchemaAggregator Aggregator

We create an Aggregator by writing a class or object that extends the Aggregator[-IN, BUF, OUT] abstract class. We customize our JsonSchemaAggregator by defining all the types and implementing all the methods of the Aggregator abstract class.

Here are the types we need to define:

  • -IN: type of our input column(s)

  • BUF: type of our buffer (accumulator)

  • OUT: type of our output

And here are the methods we need to implement:

  • zero: how to initialize the buffer

  • merge: how to merge two buffers

  • reduce: how to merge a value coming from the input column(s) into the buffer

  • finish: how to convert the buffer into the output value once aggregation is over
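To see how these four methods fit together, here is a plain-Scala sketch of the contract that runs without Spark. MiniAggregator, LocalRunner and DistinctLengths are hypothetical names, not Spark classes, and the runner is a simplification: each partition is folded from zero with reduce, the partial buffers are combined with merge, and finish produces the result. Spark's real Aggregator additionally requires the bufferEncoder and outputEncoder methods.

```scala
// Hypothetical stand-in for the four methods of Spark's Aggregator (illustration only)
trait MiniAggregator[IN, BUF, OUT] {
  def zero: BUF
  def reduce(buffer: BUF, input: IN): BUF
  def merge(b1: BUF, b2: BUF): BUF
  def finish(reduction: BUF): OUT
}

// Simplified local simulation of how the methods are combined
object LocalRunner {
  def run[IN, BUF, OUT](agg: MiniAggregator[IN, BUF, OUT], partitions: Seq[Seq[IN]]): OUT = {
    // Each partition starts from zero and folds its rows in with reduce
    val partials = partitions.map(rows => rows.foldLeft(agg.zero)(agg.reduce))
    // Partial buffers are then combined with merge, and finish produces the output
    agg.finish(partials.foldLeft(agg.zero)(agg.merge))
  }
}

// Example aggregator: distinct string lengths, rendered as a sorted comma-separated string
object DistinctLengths extends MiniAggregator[String, Set[Int], String] {
  def zero: Set[Int] = Set.empty
  def reduce(buffer: Set[Int], input: String): Set[Int] = buffer + input.length
  def merge(b1: Set[Int], b2: Set[Int]): Set[Int] = b1 ++ b2
  def finish(reduction: Set[Int]): String = reduction.toList.sorted.mkString(",")
}
```

For instance, running DistinctLengths over the two partitions Seq("a", "bb") and Seq("ccc", "dd") returns "1,2,3". JsonSchemaAggregator follows the same pattern, with a Map[String, String] schema as buffer.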

Our JsonSchemaAggregator starts from an empty schema and, for each row of our dataframe, reads the json in the value column, extracts its schema, and enriches its buffer schema with the fields present in the row’s schema but not yet in the buffer schema. At the end, it returns the enriched schema.

-IN input column type

The -IN type is the type of the input column. As we read a column containing json as a String, the -IN input type is String.

BUF buffer type

The BUF type is the type of the buffer that contains the schema being enriched. The obvious choice for the buffer type would be StructType, as that is how schemas are defined in Spark. However, a buffer moves between Spark’s executors, so Spark must be able to encode it. As objects of type StructType cannot be encoded, we can’t use StructType as the BUF type.

We choose Map[String, String] as the buffer type, where keys are field paths and values are field types. A path is the complete path to the field, with a dot . separating nested objects and brackets [] appended to the name of an array field, the rest of the path then referring to the array's elements. For instance, the path a.b[].c.d[] denotes the elements of an array d; d is a field of an object c, c is a field of the objects contained in the array b, and b is a field of the object a.

For the types, we map the final (leaf) json types as follows; object and array types are deduced from the path:

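+---------+------------------------------+
|Json Type|Buffer Type                   |
+---------+------------------------------+
|Null     |STRING in array, else ignored |
|String   |STRING                        |
|Boolean  |BOOLEAN                       |
|Number   |BIGINT if integer, else DOUBLE|
+---------+------------------------------+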

To recap, if we take the following json:

{
  "id": 1,
  "name": "John Doe",
  "address": {
    "number": 2,
    "road": "avenue Foch",
    "city": {
      "name": "Paris",
      "postcode": 75016
    }
  },
  "is_active": true,
  "amount_spent": 1235.54,
  "created_at": "2021-10-23 10:23:45",
  "phone_numbers": [123456789, 987654321]
}

We represent its schema with the following map:

Map(
  "id" -> "BIGINT",
  "name" -> "STRING",
  "address.number" -> "BIGINT",
  "address.road" -> "STRING",
  "address.city.name" -> "STRING",
  "address.city.postcode" -> "BIGINT",
  "is_active" -> "BOOLEAN",
  "amount_spent" -> "DOUBLE",
  "created_at" -> "STRING",
  "phone_numbers[]" -> "BIGINT"
)
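The path and type conventions can be checked without Spark or Jackson. The sketch below assumes a small hypothetical json ADT (JValue and its cases are ours, not a library's) that has already been built from a parsed document, and flattens it into the Map[String, String] representation following the type mapping table above. As in the approach described here, only the first element of an array decides its element type.

```scala
// Hypothetical minimal json ADT, used only to illustrate the path/type conventions
sealed trait JValue
case object JNull extends JValue
final case class JString(value: String) extends JValue
final case class JBool(value: Boolean) extends JValue
final case class JLong(value: Long) extends JValue
final case class JDouble(value: Double) extends JValue
final case class JArray(items: List[JValue]) extends JValue
final case class JObject(fields: List[(String, JValue)]) extends JValue

object PathSchema {
  def flatten(json: JValue): Map[String, String] = walk("", json, Map.empty)

  private def walk(path: String, json: JValue, acc: Map[String, String]): Map[String, String] =
    json match {
      case JObject(fields) =>
        // Nested fields are separated from their parent by a dot
        val prefix = if (path.isEmpty) "" else path + "."
        fields.foldLeft(acc) { case (a, (name, value)) => walk(prefix + name, value, a) }
      // Empty array, or array whose first element is null: element type defaults to STRING
      case JArray(Nil) | JArray(JNull :: _) => acc + (path + "[]" -> "STRING")
      // Otherwise only the first element decides the element type
      case JArray(first :: _) => walk(path + "[]", first, acc)
      // Null outside an array is ignored
      case JNull => acc
      case JString(_) => acc + (path -> "STRING")
      case JBool(_) => acc + (path -> "BOOLEAN")
      case JLong(_) => acc + (path -> "BIGINT")
      case JDouble(_) => acc + (path -> "DOUBLE")
    }
}
```

For example, an object with an id number, an empty tags array and a null comment flattens to Map("id" -> "BIGINT", "tags[]" -> "STRING"), the null field leaving no entry.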

OUT output type

OUT is the type of our aggregated result. Again, we can’t use StructType as Spark can’t encode it, so we need a type that can easily be converted into a Spark schema. This leaves two choices: json or DDL string. We choose DDL string as it is more concise than json. A DDL string is an SQL representation of a schema that can be translated into a Spark schema using the DataType.fromDDL method.

zero method that initializes buffer

For the zero method, we initialize the buffer with an empty Map[String, String].

override def zero: Map[String, String] = Map.empty[String, String]

merge method that merges two buffers

For the merge method, we merge the two Maps, with one special case to handle. As we use STRING as the type of empty arrays, when we merge an array field and one of its two types is STRING, the merged type should be the other type. For instance, if we have the following two schema buffers:

val schema1 = Map("a.b[]" -> "STRING")
val schema2 = Map("a.b[]" -> "BIGINT")

We want the merged schema buffer to be equal to schema2. So, for each field key, we take the candidates coming from each schema. If the field is an array field and the first candidate is of type STRING, we keep the second candidate; otherwise we keep the first candidate. This gives us the following merge method:

override def merge(schema1: Map[String, String], schema2: Map[String, String]): Map[String, String] = {
  (schema1.toSeq ++ schema2.toSeq).groupBy(_._1).map(elem => {
    if (elem._1.endsWith("[]") && elem._2.head._2 == "STRING") {
      elem._2.last
    } else {
      elem._2.head
    }
  })
}
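The merge rule can be exercised in isolation. This standalone copy applies the same logic as the merge method above, rewritten with a pattern-matching lambda; SchemaMerge is a hypothetical wrapper name used only for this sketch.

```scala
object SchemaMerge {
  // Same rule as the merge method above: the candidates for each key are the entries
  // from schema1 then schema2; for array fields, a STRING first candidate gives way to the last one
  def merge(schema1: Map[String, String], schema2: Map[String, String]): Map[String, String] =
    (schema1.toSeq ++ schema2.toSeq).groupBy(_._1).map { case (key, candidates) =>
      if (key.endsWith("[]") && candidates.head._2 == "STRING") candidates.last
      else candidates.head
    }
}
```

Merging schema1 and schema2 from the example gives Map("a.b[]" -> "BIGINT") in either argument order. For a non-array field present in both maps, the first map's type is kept.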

reduce method that merges value from input column with buffer

For the reduce method, we transform the json string into a Map[String, String] schema, then use the merge method above to merge it with the buffer schema:

override def reduce(currentSchema: Map[String, String], json: String): Map[String, String] = {
  merge(SparkSchemaParser.fromJson(json), currentSchema)
}

To implement SparkSchemaParser's fromJson method, we use Jackson’s json parser, as it is already among Spark’s dependencies. We parse our json with Jackson, then recursively walk all the nodes of Jackson’s json tree, adding each field with its type to a Map[String, String]:

import com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode, ValueNode}
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import scala.collection.JavaConverters._

object SparkSchemaParser {

  private val mapper = {
    val inner = new ObjectMapper()
    inner.registerModule(DefaultScalaModule)
    inner
  }

  def fromJson(json: String): Map[String, String] = addKeys("", mapper.readTree(json), Map.empty[String, String])

  private def addKeys(currentPath: String, jsonNode: JsonNode, accumulator: Map[String, String]): Map[String, String] = {
    if (jsonNode.isObject()) {
      val objectNode = jsonNode.asInstanceOf[ObjectNode]
      val fields = objectNode.fields().asScala.toSeq
      val pathPrefix = if (currentPath.isEmpty()) "" else currentPath + "."
      // foldLeft rather than reduce, so that an empty object {} does not throw
      fields.foldLeft(accumulator)((acc, field) => addKeys(pathPrefix + field.getKey, field.getValue, acc))
    } else if (jsonNode.isArray()) {
      val arrayNode = jsonNode.asInstanceOf[ArrayNode]
      if (arrayNode.size() == 0 || arrayNode.get(0).isNull()) {
        // Empty array or array of nulls: element type defaults to STRING
        accumulator ++ Map(currentPath + "[]" -> "STRING")
      } else {
        // Only the first element is inspected to determine the element type
        addKeys(currentPath + "[]", arrayNode.get(0), accumulator)
      }
    } else if (jsonNode.isNull()) {
      // Tested before isValueNode, as Jackson also reports null as a value node
      accumulator
    } else if (jsonNode.isValueNode()) {
      val valueNode = jsonNode.asInstanceOf[ValueNode]
      accumulator ++ Map(currentPath -> getType(valueNode))
    } else {
      throw new IllegalArgumentException(s"unknown node type for node $jsonNode")
    }
  }

  private def getType(valueNode: ValueNode): String = {
    if (valueNode.isInt || valueNode.isLong) {
      "BIGINT"
    } else if (valueNode.isNumber) {
      "DOUBLE"
    } else if (valueNode.isBoolean) {
      "BOOLEAN"
    } else {
      "STRING"
    }
  }

}

finish method that transforms buffer to output

For the finish method, at the end of the aggregation, we transform our Map[String, String] into a DDL representation of our schema.

override def finish(reduction: Map[String, String]): String = toDDL(reduction.toList)

For the toDDL method, we process keys recursively, stripping one path level at each recursive call. For instance, take the following schema:

+------+-------+
|field |type   |
+------+-------+
|a     |STRING |
|b.A   |BIGINT |
|b.B   |STRING |
|b.C.X |STRING |
|c[]   |BOOLEAN|
+------+-------+

The first recursive call handles the a, b and c fields, the second call handles the A, B and C fields of b, and the third and last recursive call handles the X field of C.

Arrays are a special case that breaks down into two cases:

  • If the json document itself is an array, the top level should be of DDL type ARRAY<> instead of STRUCT<>.

  • For array fields, the DDL type ARRAY<> must wrap the element type. So a[] → STRING should be converted to `a`: ARRAY<STRING> instead of `a`: STRING.

We also need to handle the edge case where an empty array is typed as an array of strings. The merge method already handles arrays containing BIGINT, DOUBLE and BOOLEAN, but arrays containing STRUCT remain. For instance, our Map schema representation can contain entries like these:

Map(
  "a[]" -> "STRING",
  "a[].A" -> "BIGINT",
  "a[].B" -> "STRING"
)

In this case, the returned DDL string should be `a`: ARRAY<STRUCT<`A`: BIGINT, `B`: STRING>>. We handle this case by dropping the key a[] when other keys start with a[] in the Map schema.

So we have the following toDDL method:

private def toDDL(schema: List[(String, String)]): String = if (schema.forall(_._1.startsWith("[]"))) {
  // The json document itself is an array
  schema.groupBy(_._1.split("\\.").head).map {
    case ("[]", ("[]", datatype) :: Nil) => datatype
    case (_, list) => toDDL(dropPrefix(list).filter(_ != ("", "STRING")))
  }.mkString("ARRAY<", ", ", ">")
} else {
  schema.groupBy(_._1.split("\\.").head).map {
    // Leaf fields: only when the key has no nested segment, so that a struct
    // with a single field (such as b.C.X) is not collapsed into its leaf type
    case (name, (key, datatype) :: Nil) if key == name && name.endsWith("[]") => s"`${name.dropRight(2)}`: ARRAY<$datatype>"
    case (name, (key, datatype) :: Nil) if key == name => s"`$name`: $datatype"
    case (name, list) if name.endsWith("[]") => s"`${name.dropRight(2)}`: ARRAY<${toDDL(dropPrefix(list).filter(_ != ("", "STRING")))}>"
    case (name, list) => s"`$name`: ${toDDL(dropPrefix(list))}"
  }.mkString("STRUCT<", ", ", ">")
}

private def dropPrefix(list: List[(String, String)]): List[(String, String)] = {
  list.map(elem => (elem._1.split("\\.").tail.mkString("."), elem._2))
}
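The recursion can be checked without Spark. Below is a standalone plain-Scala sketch of the same toDDL logic; DdlRenderer and its helper names are hypothetical. It sorts groups by name so that the output is deterministic (the aggregator's field order follows groupBy's hash order), and it renders a single-entry group as a leaf only when its key has no further segment, so that a struct with one field, such as b.C in the table above, keeps its STRUCT<> wrapper.

```scala
object DdlRenderer {
  def toDDL(schema: Map[String, String]): String = render(schema.toList)

  // First path segment of a key, e.g. "b" for "b.C.X" and "a[]" for "a[].A"
  private def firstSegment(key: String): String = key.split("\\.").head

  // Removes the first path segment from every key
  private def dropPrefix(list: List[(String, String)]): List[(String, String)] =
    list.map { case (key, datatype) => (key.split("\\.").tail.mkString("."), datatype) }

  // Removes the ("", "STRING") entry left by an empty array once its prefix is dropped
  private def dropEmptyArrayMarker(list: List[(String, String)]): List[(String, String)] =
    list.filterNot { case (key, datatype) => key.isEmpty && datatype == "STRING" }

  private def render(schema: List[(String, String)]): String =
    if (schema.forall(_._1.startsWith("[]"))) {
      // The json document itself is an array
      schema.groupBy(e => firstSegment(e._1)).toList.sortBy(_._1).map {
        case ("[]", ("[]", datatype) :: Nil) => datatype
        case (_, list) => render(dropEmptyArrayMarker(dropPrefix(list)))
      }.mkString("ARRAY<", ", ", ">")
    } else {
      schema.groupBy(e => firstSegment(e._1)).toList.sortBy(_._1).map {
        // Leaf fields: the key is exactly the group name
        case (name, (key, datatype) :: Nil) if key == name && name.endsWith("[]") =>
          s"`${name.dropRight(2)}`: ARRAY<$datatype>"
        case (name, (key, datatype) :: Nil) if key == name =>
          s"`$name`: $datatype"
        // Arrays of structs
        case (name, list) if name.endsWith("[]") =>
          s"`${name.dropRight(2)}`: ARRAY<${render(dropEmptyArrayMarker(dropPrefix(list)))}>"
        // Structs
        case (name, list) =>
          s"`$name`: ${render(dropPrefix(list))}"
      }.mkString("STRUCT<", ", ", ">")
    }
}
```

On the five-field table above, this sketch returns STRUCT<`a`: STRING, `b`: STRUCT<`A`: BIGINT, `B`: STRING, `C`: STRUCT<`X`: STRING>>, `c`: ARRAY<BOOLEAN>>.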

Complete JsonSchemaAggregator code

Here is the complete JsonSchemaAggregator code. Besides the methods described in the previous sections, we added the bufferEncoder and outputEncoder methods, which provide encoders for the buffer and output objects. This aggregator relies on the SparkSchemaParser object defined in the previous section.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

object JsonSchemaAggregator extends Aggregator[String, Map[String, String], String] {

  override def zero: Map[String, String] = Map.empty[String, String]

  override def reduce(currentSchema: Map[String, String], json: String): Map[String, String] = {
    merge(SparkSchemaParser.fromJson(json), currentSchema)
  }

  override def merge(schema1: Map[String, String], schema2: Map[String, String]): Map[String, String] = {
    (schema1.toSeq ++ schema2.toSeq).groupBy(_._1).map(elem => {
      if (elem._1.endsWith("[]") && elem._2.head._2 == "STRING") {
        elem._2.last
      } else {
        elem._2.head
      }
    })
  }

  override def finish(reduction: Map[String, String]): String = toDDL(reduction.toList)

  private def toDDL(schema: List[(String, String)]): String = if (schema.forall(_._1.startsWith("[]"))) {
    // The json document itself is an array
    schema.groupBy(_._1.split("\\.").head).map {
      case ("[]", ("[]", datatype) :: Nil) => datatype
      case (_, list) => toDDL(dropPrefix(list).filter(_ != ("", "STRING")))
    }.mkString("ARRAY<", ", ", ">")
  } else {
    schema.groupBy(_._1.split("\\.").head).map {
      // Leaf fields: only when the key has no nested segment, so that a struct
      // with a single field is not collapsed into its leaf type
      case (name, (key, datatype) :: Nil) if key == name && name.endsWith("[]") => s"`${name.dropRight(2)}`: ARRAY<$datatype>"
      case (name, (key, datatype) :: Nil) if key == name => s"`$name`: $datatype"
      case (name, list) if name.endsWith("[]") => s"`${name.dropRight(2)}`: ARRAY<${toDDL(dropPrefix(list).filter(_ != ("", "STRING")))}>"
      case (name, list) => s"`$name`: ${toDDL(dropPrefix(list))}"
    }.mkString("STRUCT<", ", ", ">")
  }

  private def dropPrefix(list: List[(String, String)]): List[(String, String)] = {
    list.map(elem => (elem._1.split("\\.").tail.mkString("."), elem._2))
  }

  override def bufferEncoder: Encoder[Map[String, String]] = ExpressionEncoder[Map[String, String]]

  override def outputEncoder: Encoder[String] = Encoders.STRING

}

Transform JsonSchemaAggregator to user-defined aggregate function

We transform our JsonSchemaAggregator into a user-defined aggregate function using the udaf function, as follows:

import org.apache.spark.sql.functions.udaf

val json_schema = udaf(JsonSchemaAggregator)

Apply our json_schema user-defined aggregate function

We can then apply the json_schema function to any dataframe. So, to apply this function to the input dataframe defined in the Problem Statement section:

+--------------------------------+
|value                           |
+--------------------------------+
|{"a": 1, "b": "value1"}         |
|{"b": "value2", "c": [1, 2, 3]} |
|{"a": 3, "d": {"d1" : "value3"}}|
|{"d": {"d2" : 4}}               |
+--------------------------------+

We use the following code:

import org.apache.spark.sql.functions.col

val output = input.agg(json_schema(col("value")).alias("schema"))

This gives the following output dataframe:

+---------------------------------------------------------------------------------------------+
|schema                                                                                       |
+---------------------------------------------------------------------------------------------+
|STRUCT<`a`: BIGINT, `d`: STRUCT<`d2`: BIGINT, `d1`: STRING>, `b`: STRING, `c`: ARRAY<BIGINT>>|
+---------------------------------------------------------------------------------------------+

We can then collect it and retrieve the Spark schema using the DataType.fromDDL method:

import org.apache.spark.sql.types.DataType

val schema = DataType.fromDDL(output.collect().head.getString(0))

Voilà! We can now extract the schema from a dataframe of json string values.

Going further

The JsonSchemaAggregator code is a proof of concept and has several limitations:

  • The recursive functions we defined to parse schemas are not tail-recursive, so very deeply nested json may cause stack overflow errors.

  • In JsonSchemaAggregator's merge method, there may be a more efficient way to merge two Maps, for instance by using a mutable map.

  • When merging two Maps representing schemas, we don’t check whether two types conflict. For instance, given the two json documents {"a": 1} and {"a": "hello"}, our code should raise an error instead of silently keeping one of the types depending on row and partition order.

  • We could use more precise types instead of being limited to STRING, BOOLEAN, BIGINT and DOUBLE, for instance DECIMAL.
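As a sketch of the conflict check suggested in the third point, the hypothetical StrictSchemaMerge below raises an error when a field gets two different types, while still letting STRING give way for array fields, as in the original merge. It is deliberately strict: BIGINT and DOUBLE also conflict, where a production version might prefer to widen to DOUBLE.

```scala
object StrictSchemaMerge {
  // Hypothetical variant of merge that fails on conflicting types instead of
  // silently keeping one of the candidates
  def merge(schema1: Map[String, String], schema2: Map[String, String]): Map[String, String] =
    schema2.foldLeft(schema1) { case (acc, (key, type2)) =>
      acc.get(key) match {
        // Field only present in schema2
        case None => acc + (key -> type2)
        // Same type on both sides
        case Some(type1) if type1 == type2 => acc
        // Array field typed STRING (possibly from an empty array) gives way to the other type
        case Some("STRING") if key.endsWith("[]") => acc + (key -> type2)
        case Some(_) if key.endsWith("[]") && type2 == "STRING" => acc
        // Any other disagreement is reported instead of silently resolved
        case Some(type1) =>
          throw new IllegalArgumentException(s"conflicting types for field $key: $type1 and $type2")
      }
    }
}
```

With this variant, merging the schemas of {"a": 1} and {"a": "hello"} throws an IllegalArgumentException naming field a.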

Conclusion

We presented a JsonSchemaAggregator Aggregator implementation that can be used to extract a Spark schema from a dataframe containing json strings as values. However, it is a proof of concept and may need some improvements before being used in a production application.