apache-spark - org - spark 2.2 mllib
Tratar con conjuntos de datos desequilibrados en Spark MLlib (3)
Estoy trabajando en un problema de clasificación binario en particular con un conjunto de datos altamente desequilibrado, y me preguntaba si alguien ha intentado implementar técnicas específicas para tratar con conjuntos de datos no equilibrados (como SMOTE ) en problemas de clasificación utilizando MLlib de Spark.
Estoy usando la implementación de Random Forest de MLLib y ya probé el enfoque más simple de submuestrear al azar a la clase más grande, pero no funcionó tan bien como esperaba.
Agradecería cualquier comentario con respecto a su experiencia con problemas similares.
Gracias,
Peso de la clase con Spark ML
A partir de este momento, la ponderación de clase para el algoritmo de bosque aleatorio aún está en desarrollo (ver here )
Pero si está dispuesto a probar otros clasificadores, esta funcionalidad ya se ha agregado a la Regresión logística .
Considere un caso en el que tengamos un 80% de positivos (etiqueta == 1) en el conjunto de datos, por lo que teóricamente queremos "sub-muestrear" la clase positiva. La función objetivo de pérdida logística debe tratar la clase negativa (etiqueta == 0) con mayor peso.
Aquí hay un ejemplo en Scala de generar este peso, agregamos una nueva columna al marco de datos para cada registro en el conjunto de datos:
def balanceDataset(dataset: DataFrame): DataFrame = {
// Re-balancing (weighting) of records to be used in the logistic loss objective function
val numNegatives = dataset.filter(dataset("label") === 0).count
val datasetSize = dataset.count
val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize
val calculateWeights = udf { d: Double =>
if (d == 0.0) {
1 * balancingRatio
}
else {
(1 * (1.0 - balancingRatio))
}
}
val weightedDataset = dataset.withColumn("classWeightCol", calculateWeights(dataset("label")))
weightedDataset
}
Entonces, creamos una clase como sigue:
new LogisticRegression().setWeightCol("classWeightCol").setLabelCol("label").setFeaturesCol("features")
Para más detalles, vea aquí: https://issues.apache.org/jira/browse/SPARK-9610
- Poder de predicción
Debe comprobar un problema diferente: si sus características tienen un "poder predictivo" para la etiqueta que está tratando de predecir. En un caso en el que después de un muestreo insuficiente todavía tiene poca precisión, tal vez eso no tenga nada que ver con el hecho de que su conjunto de datos está desequilibrado por naturaleza.
Haría un análisis exploratorio de los datos : si el clasificador no funciona mejor que una elección aleatoria, existe el riesgo de que simplemente no haya conexión entre las características y la clase.
- Realizar análisis de correlación para cada característica con la etiqueta.
- La generación de histogramas específicos de clase para entidades (es decir, el trazado de histogramas de los datos para cada clase, para una entidad dada en el mismo eje) también puede ser una buena manera de mostrar si una entidad discrimina bien entre las dos clases.
Ajuste excesivo: un error bajo en su conjunto de entrenamiento y un error alto en su conjunto de prueba pueden ser una indicación de que se adapta excesivamente a un conjunto de características demasiado flexible.
Desviación de desviación: compruebe si su clasificador tiene un problema de alta desviación o de alta desviación.
- Error de entrenamiento vs. error de validación: grafique el error de validación y el error del conjunto de entrenamiento, como una función de los ejemplos de entrenamiento (haga un aprendizaje incremental)
- Si las líneas parecen converger al mismo valor y se cierran al final, entonces su clasificador tiene un alto sesgo. En tal caso, agregar más datos no ayudará. Cambie el clasificador por uno que tenga mayor variación, o simplemente reduzca el parámetro de regularización de su actual.
- Si, por otro lado, las líneas están bastante alejadas y tiene un error de conjunto de entrenamiento bajo pero un error de validación alto, entonces su clasificador tiene una variación demasiado alta. En este caso, obtener más datos es muy probable que ayude. Si después de obtener más datos, la variación seguirá siendo demasiado alta, puede aumentar el parámetro de regularización.
@dbakr ¿Recibió una respuesta para su predicción sesgada en su conjunto de datos desequilibrado?
Aunque no estoy seguro de que fuera su plan original, tenga en cuenta que si primero toma una muestra de la mayoría de la clase de su conjunto de datos en una relación r , luego, para obtener predicciones no balanceadas para la regresión logística de Spark, puede: - usar el RawPrediction proporcionado por la función transform()
y ajuste la intercepción con log(r)
, o puede entrenar su regresión con ponderaciones usando .setWeightCol("classWeightCol")
(consulte el artículo citado here para averiguar el valor que debe establecerse en el pesos).
Usé la solución de @Serendipity, pero podemos optimizar la función balanceDataset para evitar el uso de un udf. También agregué la posibilidad de cambiar la columna de etiqueta que se está utilizando. Esta es la versión de la función que terminé con:
def balanceDataset(dataset: DataFrame, label: String = "label"): DataFrame = {
// Re-balancing (weighting) of records to be used in the logistic loss objective function
val (datasetSize, positives) = dataset.select(count("*"), sum(dataset(label))).as[(Long, Double)].collect.head
val balancingRatio = positives / datasetSize
val weightedDataset = {
dataset.withColumn("classWeightCol", when(dataset(label) === 0.0, balancingRatio).otherwise(1.0 - balancingRatio))
}
weightedDataset
}
Creamos el clasificador como él declaró con:
new LogisticRegression().setWeightCol("classWeightCol").setLabelCol("label").setFeaturesCol("features")