In Spark 3, given the following data in a dataframe:

person  relation  other
Roger   likes     Brian
Alice   hates     Robert
Alice   loves     Roger
Roger   likes     Alice
Robert  likes     Roger

For each person, return the relations that person has with other people, as in the following dataframe:

person  other_relations
Roger   [Brian → likes, Alice → likes]
Alice   [Robert → hates, Roger → loves]
Robert  [Roger → likes]

To do so, we will:

  • for each row of the input dataframe, create a new column other_relation containing a struct with the fields other and relation, in that order

  • then, for each person, collect all the other_relation structs into an other_relations map, where other is the key and relation is the value

In code, this gives the following snippet:

import org.apache.spark.sql.functions.{col, collect_set, map_from_entries, struct}

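inputDataframe
  // Pair each target with its relation; the first struct field becomes the map key
  .withColumn("other_relation", struct(col("other"), col("relation")))
  .groupBy("person")
  // Gather the distinct (other, relation) pairs per person and build a map from them
  .agg(map_from_entries(collect_set("other_relation")).as("other_relations"))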
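
To try the snippet end to end, here is a minimal sketch that builds the example input and runs the same aggregation. The local SparkSession, the application name and the relationsByPerson value are assumptions made for illustration only; in spark-shell, spark already exists and the builder line can be dropped.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_set, map_from_entries, struct}

// Assumption: a throwaway local session, only for experimenting with the example data
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("relations-example") // hypothetical name
  .getOrCreate()
import spark.implicits._

// The input rows from the question
val inputDataframe = Seq(
  ("Roger", "likes", "Brian"),
  ("Alice", "hates", "Robert"),
  ("Alice", "loves", "Roger"),
  ("Roger", "likes", "Alice"),
  ("Robert", "likes", "Roger")
).toDF("person", "relation", "other")

// Same aggregation as above: one map of other -> relation per person
val result = inputDataframe
  .withColumn("other_relation", struct(col("other"), col("relation")))
  .groupBy("person")
  .agg(map_from_entries(collect_set("other_relation")).as("other_relations"))

// Bring the result to the driver as person -> (other -> relation); small data only
val relationsByPerson: Map[String, Map[String, String]] = result
  .as[(String, Map[String, String])]
  .collect()
  .toMap
```

One caveat to keep in mind: a map cannot hold the same key twice. If a person can have several different relations with the same other person, map_from_entries fails at runtime in Spark 3 with the default setting of spark.sql.mapKeyDedupPolicy (EXCEPTION); setting it to LAST_WIN keeps only one of the entries. If all relations must be kept, stopping at collect_set and keeping the array of structs is probably the safer choice.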