reducebykey - Cómo obtener elemento por índice en Spark RDD(Java)
spark group by (3)
Sé el método rdd.first () que me da el primer elemento en un RDD.
También está el método rdd.take (num) que me da los primeros elementos "num".
Pero, ¿no existe la posibilidad de obtener un elemento por índice?
Gracias.
Intenté esta clase para buscar un elemento por índice. Primero, cuando construye new IndexedFetcher(rdd, itemClass)
, cuenta la cantidad de elementos en cada partición del RDD. Luego, cuando llama a indexedFetcher.get(n)
, ejecuta un trabajo solo en la partición que contiene ese índice.
Tenga en cuenta que necesitaba compilar esto usando Java 1.7 en lugar de 1.8; a partir de Spark 1.1.0, el paquete org.objectweb.asm dentro de com.esotericsoftware.reflectasm aún no puede leer las clases Java 1.8 (arroja IllegalStateException cuando intentas ejecutar una función Java 1.8).
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassTag;
public static class IndexedFetcher<E> implements Serializable {
private static final long serialVersionUID = 1L;
public final RDD<E> rdd;
public Integer[] elementsPerPartitions;
private Class<?> clazz;
public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
this.rdd = rdd;
this.clazz = clazz;
SparkContext context = this.rdd.context();
ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
}
public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
private static final long serialVersionUID = 1L;
@Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
count++;
iterator.next();
}
return count;
}
}
static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
return function;
}
public E get(long index) {
long remaining = index;
long totalCount = 0;
for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
if (remaining < elementsPerPartitions[partition]) {
return getWithinPartition(partition, remaining);
}
remaining -= elementsPerPartitions[partition];
totalCount += elementsPerPartitions[partition];
}
throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
}
public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
private static final long serialVersionUID = 1L;
private final long indexWithinPartition;
public FetchWithinPartitionFunction(long indexWithinPartition) {
this.indexWithinPartition = indexWithinPartition;
}
@Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
E element = iterator.next();
if (count == indexWithinPartition)
return element;
count++;
}
throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
}
}
public E getWithinPartition(int partition, long indexWithinPartition) {
System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
SparkContext context = rdd.context();
scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
return result[0];
}
}
Me quedé atrapado en esto por un tiempo también, así que expandir la respuesta de Maasg pero respondiendo para buscar un rango de valores por índice para Java (necesitarás definir las 4 variables en la parte superior):
DataFrame df;
SQLContext sqlContext;
Long start;
Long end;
JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end);
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());
Recuerde que cuando ejecuta este código, su clúster deberá tener Java 8 (ya que se está utilizando una expresión lambda).
Además, ¡zipWithIndex es probablemente costoso!
Esto debería ser posible indexando primero el RDD. La transformación zipWithIndex
proporciona una indexación estable, numerando cada elemento en su orden original.
Dado: rdd = (a,b,c)
val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))
Para buscar un elemento por índice, este formulario no es útil. Primero necesitamos usar el índice como clave:
val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c))
Ahora, es posible usar la acción de lookup
en PairRDD para encontrar un elemento por clave:
val b = indexKey.lookup(1) // Array(b)
Si esperas utilizar la lookup
menudo en el mismo RDD, te recomendaría indexKey
en caché el indexKey
RDD para mejorar el rendimiento.
Cómo hacer esto usando la API de Java es un ejercicio que queda para el lector.