sql - read - ¿Cómo obtener otras columnas al usar Spark DataFrame groupby?
spark sql functions (7)
cuando uso DataFrame groupby así:
df.groupBy(df("age")).agg(Map("id"->"count"))
Solo obtendré un DataFrame con las columnas "age" y "count (id)", pero en df, hay muchas otras columnas como "name".
En general, quiero obtener el resultado como en MySQL,
"seleccione nombre, edad, recuento (id) del grupo df por edad"
¿Qué debo hacer cuando uso groupby en Spark?
Aquí un ejemplo que encontré en el taller de chispas
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import scala.collection.mutable
object TestJob3 {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
(1, "Moe", "Slap", 2.0, 18),
(2, "Larry", "Spank", 3.0, 15),
(3, "Curly", "Twist", 5.0, 15),
(4, "Laurel", "Whimper", 3.0, 9),
(5, "Hardy", "Laugh", 6.0, 18),
(6, "Charley", "Ignore", 5.0, 5)
).toDF("id", "name", "requisite", "money", "age")
rawDf.show(false)
rawDf.printSchema
val rawSchema = rawDf.schema
val fUdf = udf(reduceByMoney, rawSchema)
val nameUdf = udf(extractName, StringType)
val aggDf = rawDf
.groupBy("age")
.agg(
count(struct("*")).as("count"),
max(col("money")),
collect_list(struct("*")).as("horizontal")
)
.withColumn("short", fUdf($"horizontal"))
.withColumn("name", nameUdf($"short"))
.drop("horizontal")
aggDf.printSchema
aggDf.show(false)
}
def reduceByMoney= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val red = d.reduce((r1, r2) => {
val money1 = r1.getAs[Double]("money")
val money2 = r2.getAs[Double]("money")
val r3 = money1 match {
case a if a >= money2 =>
r1
case _ =>
r2
}
r3
})
red
}
def extractName = (x: Any) => {
val d = x.asInstanceOf[GenericRowWithSchema]
d.getAs[String]("name")
}
}
val maxPopulationDF = populationDF.agg(max(''population).as("populationmax"))
Para obtener otras columnas, hago una unión simple entre el DF original y el agregado
+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short |name |
+---+-----+----------+----------------------------+-------+
|5 |1 |5.0 |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2 |5.0 |[3, Curly, Twist, 5.0, 15] |Curly |
|9 |1 |3.0 |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2 |6.0 |[5, Hardy, Laugh, 6.0, 18] |Hardy |
+---+-----+----------+----------------------------+-------+
Debe recordar que las funciones agregadas reducen las filas y, por lo tanto, debe especificar cuál de los nombres de filas desea con una función reductora. Si desea retener todas las filas de un grupo (¡advertencia! Esto puede causar explosiones o particiones sesgadas), puede recopilarlas como una lista. Luego puede usar un UDF (función definida por el usuario) para reducirlos según sus criterios, en mi ejemplo dinero. Y luego expanda las columnas de la fila reducida única con otro UDF. A los fines de esta respuesta, supongo que desea conservar el nombre de la persona que tiene más dinero.
name age id
abc 24 1001
cde 24 1002
efg 22 1003
ghi 21 1004
ijk 20 1005
klm 19 1006
mno 18 1007
pqr 18 1008
rst 26 1009
tuv 27 1010
pqr 18 1012
rst 28 1013
tuv 29 1011
aquí está la salida
df.select("name","age","id").groupBy("name","age").count().show();
En pocas palabras, en general, debe unir los resultados agregados con la tabla original. Spark SQL sigue la misma convención anterior a SQL: 1999 que la mayoría de las principales bases de datos (PostgreSQL, Oracle, MS SQL Server) que no permite columnas adicionales en consultas de agregación.
Dado que para las agregaciones como los resultados de conteo no están bien definidos y el comportamiento tiende a variar en los sistemas que admiten este tipo de consultas, puede incluir columnas adicionales utilizando agregados arbitrarios como
first
o
last
.
En algunos casos, puede reemplazar
agg
usando
select
con funciones de ventana y posterior
where
pero dependiendo del contexto, puede ser bastante costoso.
Las funciones agregadas reducen los valores de filas para columnas específicas dentro del grupo. Si desea retener otros valores de fila, debe implementar una lógica de reducción que especifique una fila de la que proviene cada valor. Por ejemplo, mantenga todos los valores de la primera fila con el valor máximo de edad. Para este fin, puede usar un UDAF (función agregada definida por el usuario) para reducir las filas dentro del grupo.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
object AggregateKeepingRowJob {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
sc.setLogLevel("ERROR")
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
(1L, "Moe", "Slap", 2.0, 18),
(2L, "Larry", "Spank", 3.0, 15),
(3L, "Curly", "Twist", 5.0, 15),
(4L, "Laurel", "Whimper", 3.0, 15),
(5L, "Hardy", "Laugh", 6.0, 15),
(6L, "Charley", "Ignore", 5.0, 5)
).toDF("id", "name", "requisite", "money", "age")
rawDf.show(false)
rawDf.printSchema
val maxAgeUdaf = new KeepRowWithMaxAge
val aggDf = rawDf
.groupBy("age")
.agg(
count("id"),
max(col("money")),
maxAgeUdaf(
col("id"),
col("name"),
col("requisite"),
col("money"),
col("age")).as("KeepRowWithMaxAge")
)
aggDf.printSchema
aggDf.show(false)
}
}
El UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
StructType(
StructField("store", StringType) ::
StructField("prod", StringType) ::
StructField("amt", DoubleType) ::
StructField("units", IntegerType) :: Nil
)
// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
StructField("store", StringType) ::
StructField("prod", StringType) ::
StructField("amt", DoubleType) ::
StructField("units", IntegerType) :: Nil
)
// This is the output type of your aggregation function.
override def dataType: DataType =
StructType((Array(
StructField("store", StringType),
StructField("prod", StringType),
StructField("amt", DoubleType),
StructField("units", IntegerType)
)))
override def deterministic: Boolean = true
// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = ""
buffer(1) = ""
buffer(2) = 0.0
buffer(3) = 0
}
// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val amt = buffer.getAs[Double](2)
val candidateAmt = input.getAs[Double](2)
amt match {
case a if a < candidateAmt =>
buffer(0) = input.getAs[String](0)
buffer(1) = input.getAs[String](1)
buffer(2) = input.getAs[Double](2)
buffer(3) = input.getAs[Int](3)
case _ =>
}
}
// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer2.getAs[String](0)
buffer1(1) = buffer2.getAs[String](1)
buffer1(2) = buffer2.getAs[Double](2)
buffer1(3) = buffer2.getAs[Int](3)
}
// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
buffer
}
}
Puede ser esta solución será útil.
from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window
name_list = [(101, ''abc'', 24), (102, ''cde'', 24), (103, ''efg'', 22), (104, ''ghi'', 21),
(105, ''ijk'', 20), (106, ''klm'', 19), (107, ''mno'', 18), (108, ''pqr'', 18),
(109, ''rst'', 26), (110, ''tuv'', 27), (111, ''pqr'', 18), (112, ''rst'', 28), (113, ''tuv'', 29)]
age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, [''id'', ''name'', ''age''])
name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()
Salida:
+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26| 1|
|113| tuv| 29| 1|
|110| tuv| 27| 1|
|106| klm| 19| 1|
|103| efg| 22| 1|
|104| ghi| 21| 1|
|105| ijk| 20| 1|
|112| rst| 28| 1|
|101| abc| 24| 2|
|102| cde| 24| 2|
|107| mno| 18| 3|
|111| pqr| 18| 3|
|108| pqr| 18| 3|
+---+----+---+-----+
Puedes hacer así:
Data de muestra:
+----+---+-----+
|name|age|count|
+----+---+-----+
| efg| 22| 1|
| tuv| 29| 1|
| rst| 28| 1|
| klm| 19| 1|
| pqr| 18| 2|
| cde| 24| 1|
| tuv| 27| 1|
| ijk| 20| 1|
| abc| 24| 1|
| mno| 18| 1|
| ghi| 21| 1|
| rst| 26| 1|
+----+---+-----+
df.select("name","age","id").groupBy("name","age").count().show();
Salida:
+----+---+-----+ |name|age|count| +----+---+-----+ | efg| 22| 1| | tuv| 29| 1| | rst| 28| 1| | klm| 19| 1| | pqr| 18| 2| | cde| 24| 1| | tuv| 27| 1| | ijk| 20| 1| | abc| 24| 1| | mno| 18| 1| | ghi| 21| 1| | rst| 26| 1| +----+---+-----+
Una forma de obtener todas las columnas después de hacer un groupBy es usar la función de unión.
feature_group = [''name'', ''age'']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)
data_joined ahora tendrá todas las columnas, incluidos los valores de conteo.