In Spark 3, given the following data in a dataframe:
| person | relation | other |
|---|---|---|
| Roger | likes | Brian |
| Alice | hates | Robert |
| Alice | loves | Roger |
| Roger | likes | Alice |
| Robert | likes | Roger |
Return, for each person, the relations they have with other persons, as in the following dataframe:
| person | other_relations |
|---|---|
| Roger | [Brian → likes, Alice → likes] |
| Alice | [Robert → hates, Roger → loves] |
| Robert | [Roger → likes] |
To do so, we will:

- for each line of the input dataframe, create a new column `other_relation` containing a `struct` with columns `other` and `relation`
- then, for each `person`, collect all the `other_relation` values into an `other_relations` map
Translated into code, this gives the following snippet:
```scala
import org.apache.spark.sql.functions.{col, collect_set, map_from_entries, struct}

inputDataframe
  // build a struct (other, relation) for each row
  .withColumn("other_relation", struct(col("other"), col("relation")))
  .groupBy("person")
  // gather the structs and turn them into a map keyed by `other`
  .agg(map_from_entries(collect_set("other_relation")).as("other_relations"))
```
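To see the shape of the transformation without a Spark session, here is a minimal sketch of the same logic on plain Scala collections, assuming the input rows are `(person, relation, other)` tuples: `groupBy` plays the role of `groupBy("person")`, and building the inner `Map` mirrors `collect_set` followed by `map_from_entries`.

```scala
// Sample rows mirroring the input dataframe: (person, relation, other).
val input = Seq(
  ("Roger", "likes", "Brian"),
  ("Alice", "hates", "Robert"),
  ("Alice", "loves", "Roger"),
  ("Roger", "likes", "Alice"),
  ("Robert", "likes", "Roger")
)

// Group rows by person, then turn each group's (other, relation)
// pairs into a map, as map_from_entries(collect_set(...)) does.
val otherRelations: Map[String, Map[String, String]] =
  input
    .groupBy { case (person, _, _) => person }
    .map { case (person, rows) =>
      person -> rows.map { case (_, relation, other) => other -> relation }.toMap
    }
```

Note that a plain `Map` deduplicates keys, so if a person has two different relations toward the same `other`, only one survives; the Spark version has the same caveat, since map keys must be unique.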