In Spark 3, given the following data in a dataframe:
person | relation | other |
---|---|---|
Roger | likes | Brian |
Alice | hates | Robert |
Alice | loves | Roger |
Roger | likes | Alice |
Robert | likes | Roger |
Return, for each person, the relations they have with other persons, as in the following dataframe:
person | other_relations |
---|---|
Roger | [Brian → likes, Alice → likes] |
Alice | [Robert → hates, Roger → loves] |
Robert | [Roger → likes] |
To do so, we will:
- for each line of the input dataframe, create a new column `other_relation` containing a `struct` with columns `other` and `relation`
- then, for each `person`, collect all the `other_relation` structs into an `other_relations` map
Translated into code, this gives the following snippet:
```scala
import org.apache.spark.sql.functions.{col, collect_set, map_from_entries, struct}

val outputDataframe = inputDataframe
  // the struct's first field (other) becomes the map key,
  // its second field (relation) becomes the map value
  .withColumn("other_relation", struct(col("other"), col("relation")))
  .groupBy("person")
  // collect_set deduplicates entries before map_from_entries builds the map
  .agg(map_from_entries(collect_set("other_relation")).as("other_relations"))
```
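To see the same grouping logic without a Spark session, here is a minimal sketch of the equivalent transformation on plain Scala collections (the `Row` case class, sample data, and `otherRelations` name are illustrative, not part of the Spark API):

```scala
// Plain-Scala analogue of the aggregation above: group rows by person,
// then build a Map(other -> relation) per person, mirroring
// map_from_entries(collect_set(struct(other, relation))).
case class Row(person: String, relation: String, other: String)

val input = List(
  Row("Roger", "likes", "Brian"),
  Row("Alice", "hates", "Robert"),
  Row("Alice", "loves", "Roger"),
  Row("Roger", "likes", "Alice"),
  Row("Robert", "likes", "Roger")
)

val otherRelations: Map[String, Map[String, String]] =
  input
    .groupBy(_.person)                         // like groupBy("person")
    .map { case (person, rows) =>
      // distinct mirrors collect_set's deduplication of entries
      person -> rows.map(r => r.other -> r.relation).distinct.toMap
    }
```

Note that `collect_set` (and `distinct` here) drops duplicate (other, relation) pairs; if a later relation for the same `other` should overwrite an earlier one instead, `collect_list` would be the Spark-side choice.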