scala apache-spark dataframe apache-spark-sql pivot

scala - ¿Cómo pivotar Spark DataFrame?



apache-spark apache-spark-sql (9)

Hay muchos ejemplos de operaciones dinámicas en el conjunto de datos / marco de datos, pero no pude encontrar muchos usando SQL. Aquí hay un ejemplo que funcionó para mí.

# pass an optional list of string to avoid computation of columns def pivot(df, group_by, key, aggFunction, levels=[]): if not levels: levels = [row[key] for row in df.filter(col(key).isNotNull()).groupBy(col(key)).agg(count(key)).select(key).collect()] return df.filter(col(key).isin(*levels) == True).groupBy(group_by).agg(map_from_entries(collect_list(struct(key, expr(aggFunction)))).alias("group_map")).select([group_by] + ["group_map." + l for l in levels]) # Usage pivot(df, "id", "key", "value") pivot(df, "id", "key", "array(value)")

Estoy empezando a usar Spark DataFrames y necesito poder pivotar los datos para crear múltiples columnas de 1 columna con múltiples filas. Hay una funcionalidad integrada para eso en Scalding y creo en Pandas en Python, pero no puedo encontrar nada para el nuevo Spark Dataframe.

Supongo que puedo escribir una función personalizada de algún tipo que haga esto, pero ni siquiera estoy seguro de cómo comenzar, especialmente porque soy un novato con Spark. Si alguien sabe cómo hacer esto con la funcionalidad integrada o sugerencias sobre cómo escribir algo en Scala, es muy apreciado.


Hay una solución simple y elegante.

import java.util.Date import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StructField ... ... def main(args: Array[String]): Unit = { val sc = new SparkContext(conf) val sqlContext = new org.apache.spark.sql.SQLContext(sc) val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true") .load("data.csv") dfdata1.show() val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value") dfOutput.show }


He resuelto un problema similar usando marcos de datos con los siguientes pasos:

Cree columnas para todos sus países, con ''valor'' como valor:

import org.apache.spark.sql.functions._ val countries = List("US", "UK", "Can") val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) => if(countryToCheck == countryInRow) value else 0 } val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) } val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")

Su marco de datos ''dfWithCountries'' se verá así:

+--+--+---+---+ |id|US| UK|Can| +--+--+---+---+ | 1|50| 0| 0| | 1| 0|100| 0| | 1| 0| 0|125| | 2|75| 0| 0| | 2| 0|150| 0| | 2| 0| 0|175| +--+--+---+---+

Ahora puede sumar todos los valores para el resultado deseado:

dfWithCountries.groupBy("id").sum(countries: _*).show

Resultado:

+--+-------+-------+--------+ |id|SUM(US)|SUM(UK)|SUM(Can)| +--+-------+-------+--------+ | 1| 50| 100| 125| | 2| 75| 150| 175| +--+-------+-------+--------+

Sin embargo, no es una solución muy elegante. Tuve que crear una cadena de funciones para agregar en todas las columnas. Además, si tengo muchos países, expandiré mi conjunto de datos temporales a un conjunto muy amplio con muchos ceros.


Inicialmente adopté la solución de Al M. Más tarde tomó el mismo pensamiento y reescribió esta función como una función de transposición.

Este método transpone cualquier fila de df a columnas de cualquier formato de datos con el uso de la columna de clave y valor

para entrada csv

scala> spark.sql("select * from k_tags limit 10").show() +---------------+-------------+------+ | imsi| name| value| +---------------+-------------+------+ |246021000000000| age| 37| |246021000000000| gender|Female| |246021000000000| arpu| 22| |246021000000000| DeviceType| Phone| |246021000000000|DataAllowance| 6GB| +---------------+-------------+------+ scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show() +---------------+-------------+----------+---+----+------+ | imsi|DataAllowance|DeviceType|age|arpu|gender| +---------------+-------------+----------+---+----+------+ |246021000000000| 6GB| Phone| 37| 22|Female| |246021000000001| 1GB| Phone| 72| 10| Male| +---------------+-------------+----------+---+----+------+

salida

id,tag,value 1,US,50a 1,UK,100 1,Can,125 2,US,75 2,UK,150 2,Can,175

método de transposición:

+--+---+---+---+ |id| UK| US|Can| +--+---+---+---+ | 2|150| 75|175| | 1|100|50a|125| +--+---+---+---+

fragmento principal

def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = { val distinctCols = df.select(key).distinct.map { r => r(0) }.collect().toList val rdd = df.map { row => (compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] }, scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any])) } val pairRdd = rdd.reduceByKey(_ ++ _) val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols)) hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols))) } private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = { val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) } val array = r._1 ++ cols Row(array: _*) } private def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = { val idSchema = idCols.map { idCol => srcSchema.apply(idCol) } val colSchema = srcSchema.apply(distinctCols._1) val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) } StructType(idSchema ++ colsSchema) }


La función de pivote de chispa incorporada es ineficiente. La siguiente implementación funciona en spark 2.4+: la idea es agregar un mapa y extraer los valores como columnas. La única limitación es que no maneja la función de agregado en las columnas pivotantes, solo las columnas.

En una tabla de 8M, esas funciones se aplican en 3 segundos , frente a 40 minutos en la versión de chispa incorporada:

// pass an optional list of string to avoid computation of columns def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = { val levels = if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList else _levels df .filter(key.isInCollection(levels)) .groupBy(groupBy) .agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map")) .select(groupBy.toString, levels.map(f => "group_map." + f): _*) } // Usage: pivot(df, col("id"), col("key"), "value") pivot(df, col("id"), col("key"), "array(value)")

// pass an optional list of string to avoid computation of columns def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = { val levels = if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList else _levels df .filter(key.isInCollection(levels)) .groupBy(groupBy) .agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map")) .select(groupBy.toString, levels.map(f => "group_map." + f): _*) } // Usage: pivot(df, col("id"), col("key"), "value") pivot(df, col("id"), col("key"), "array(value)")


Lo superé escribiendo un bucle for para crear dinámicamente una consulta SQL. Digamos que tengo:

id tag value 1 US 50 1 UK 100 1 Can 125 2 US 75 2 UK 150 2 Can 175

y yo quiero:

id US UK Can 1 50 100 125 2 75 150 175

Puedo crear una lista con el valor que quiero pivotar y luego crear una cadena que contenga la consulta SQL que necesito.

val countries = List("US", "UK", "Can") val numCountries = countries.length - 1 var query = "select *, " for (i <- 0 to numCountries-1) { query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", " } query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable" myDataFrame.registerTempTable("myTable") val myDF1 = sqlContext.sql(query)

Puedo crear una consulta similar para luego hacer la agregación. No es una solución muy elegante, pero funciona y es flexible para cualquier lista de valores, que también se puede pasar como argumento cuando se llama a su código.



Spark ha estado proporcionando mejoras a Pivoting the Spark DataFrame. Se ha agregado una función pivote a la API Spark DataFrame a la versión Spark 1.6 y tiene un problema de rendimiento que se ha corregido en Spark 2.0

sin embargo, si está utilizando una versión inferior; tenga en cuenta que pivote es una operación muy costosa, por lo tanto, se recomienda proporcionar datos de columna (si se conoce) como argumento para funcionar como se muestra a continuación.

val countries = Seq("USA","China","Canada","Mexico") val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount") pivotDF.show()

Esto se ha explicado detalladamente en Pivoting and Unpivoting Spark DataFrame

Feliz aprendizaje !!


Como lo menciona David Anderson, Spark proporciona la función pivot desde la versión 1.6. La sintaxis general tiene el siguiente aspecto:

df .groupBy(grouping_columns) .pivot(pivot_column, [values]) .agg(aggregate_expressions)

Ejemplos de uso con el formato nycflights13 y csv :

Python :

from pyspark.sql.functions import avg flights = (sqlContext .read .format("csv") .options(inferSchema="true", header="true") .load("flights.csv") .na.drop()) flights.registerTempTable("flights") sqlContext.cacheTable("flights") gexprs = ("origin", "dest", "carrier") aggexpr = avg("arr_delay") flights.count() ## 336776 %timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count() ## 10 loops, best of 3: 1.03 s per loop

Scala :

val flights = sqlContext .read .format("csv") .options(Map("inferSchema" -> "true", "header" -> "true")) .load("flights.csv") flights .groupBy($"origin", $"dest", $"carrier") .pivot("hour") .agg(avg($"arr_delay"))

Java :

import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.*; Dataset<Row> df = spark.read().format("csv") .option("inferSchema", "true") .option("header", "true") .load("flights.csv"); df.groupBy(col("origin"), col("dest"), col("carrier")) .pivot("hour") .agg(avg(col("arr_delay")));

R / SparkR :

library(magrittr) flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE) flights %>% groupBy("origin", "dest", "carrier") %>% pivot("hour") %>% agg(avg(column("arr_delay")))

R / sparklyr

library(dplyr) flights <- spark_read_csv(sc, "flights", "flights.csv") avg.arr.delay <- function(gdf) { expr <- invoke_static( sc, "org.apache.spark.sql.functions", "avg", "arr_delay" ) gdf %>% invoke("agg", expr, list()) } flights %>% sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)

SQL :

Tenga en cuenta que la palabra clave PIVOT en Spark SQL es compatible a partir de la versión 2.4.

CREATE TEMPORARY VIEW flights USING csv OPTIONS (header ''true'', path ''flights.csv'', inferSchema ''true'') ; SELECT * FROM ( SELECT origin, dest, carrier, arr_delay, hour FROM flights ) PIVOT ( avg(arr_delay) FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23) );

Datos de ejemplo :

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour" 2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00 2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00 2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00 2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00 2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00 2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00 2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00 2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00

Consideraciones de rendimiento :

En general, pivotar es una operación costosa.

  • si puede, intente proporcionar una lista de values , ya que esto evita un golpe adicional para calcular los únicos:

    vs = list(range(25)) %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count() ## 10 loops, best of 3: 392 ms per loop

  • en algunos casos demostró ser beneficioso (probablemente ya no valga la pena el esfuerzo en 2.0 o posterior ) para repartition y / o agregar previamente los datos

  • solo para remodelar, puede usar first : la columna Pivot String en Pyspark Dataframe

Preguntas relacionadas :

  • ¿Cómo derretir Spark DataFrame?
  • Desvincular en spark-sql / pyspark
  • Transponer columna a fila con Spark