machine learning - ¿Qué es una explicación intuitiva de la técnica de maximización de expectativas?
machine-learning cluster-analysis (8)
Maximización de expectativas si es un tipo de método probabilístico para clasificar datos. Por favor corrígeme si estoy equivocado si no es un clasificador.
¿Cuál es una explicación intuitiva de esta técnica EM? ¿Qué es la expectativa aquí y qué se está maximizando?
Aquí hay una receta sencilla para entender el algoritmo de maximización de expectativas:
1- Lea este documento tutorial EM de Do y Batzoglou.
2- Puede tener signos de interrogación en su cabeza, eche un vistazo a las explicaciones en esta page intercambio de fichas matemáticas.
3- Observa este código que escribí en Python que explica el ejemplo en el documento tutorial EM del ítem 1:
Advertencia: el código puede ser desordenado / subóptimo, ya que no soy un desarrollador de Python. Pero hace el trabajo.
import numpy as np
import math
#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ####
def get_mn_log_likelihood(obs,probs):
""" Return the (log)likelihood of obs, given the probs"""
# Multinomial Distribution Log PMF
# ln (pdf) = multinomial coeff * product of probabilities
# ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]
multinomial_coeff_denom= 0
prod_probs = 0
for x in range(0,len(obs)): # loop through state counts in each observation
multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
prod_probs = prod_probs + obs[x]*math.log(probs[x])
multinomial_coeff = math.log(math.factorial(sum(obs))) - multinomial_coeff_denom
likelihood = multinomial_coeff + prod_probs
return likelihood
# 1st: Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd: Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd: Coin A, {HTHHHHHTHH}, 8H,2T
# 4th: Coin B, {HTHTTTHHTT}, 4H,6T
# 5th: Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)
# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50
# E-M begins!
delta = 0.001
j = 0 # iteration counter
improvement = float(''inf'')
while (improvement>delta):
expectation_A = np.zeros((5,2), dtype=float)
expectation_B = np.zeros((5,2), dtype=float)
for i in range(0,len(experiments)):
e = experiments[i] # i''th experiment
ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B
weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A
weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B
expectation_A[i] = np.dot(weightA, e)
expectation_B[i] = np.dot(weightB, e)
pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A));
pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B));
improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
j = j+1
EM es un algoritmo para maximizar una función de verosimilitud cuando algunas de las variables en su modelo no son observadas (es decir, cuando tiene variables latentes).
Es muy probable que pregunte si solo estamos tratando de maximizar una función, ¿por qué no utilizamos la maquinaria existente para maximizar una función? Bueno, si intenta maximizar esto tomando derivados y poniéndolos a cero, encontrará que en muchos casos las condiciones de primer orden no tienen una solución. Hay un problema de huevo y gallina porque para resolver los parámetros de su modelo necesita saber la distribución de sus datos no observados; pero la distribución de sus datos no observados es una función de los parámetros de su modelo.
EM intenta sortear esto adivinando iterativamente una distribución para los datos no observados, luego estima los parámetros del modelo maximizando algo que es un límite inferior en la función de verosimilitud real, y repitiendo hasta la convergencia:
El algoritmo EM
Comience con adivinar los valores de los parámetros de su modelo
E-step: para cada punto de datos que tenga valores perdidos, use su ecuación modelo para resolver la distribución de los datos faltantes dado su conjetura actual de los parámetros del modelo y los datos observados (tenga en cuenta que está resolviendo una distribución por cada faltante) valor, no para el valor esperado). Ahora que tenemos una distribución para cada valor perdido, podemos calcular la expectativa de la función de verosimilitud con respecto a las variables no observadas. Si nuestra estimación del parámetro del modelo fue correcta, esta probabilidad esperada será la probabilidad real de nuestros datos observados; si los parámetros no son correctos, será solo un límite inferior.
M-step: ahora que tenemos una función de verosimilitud esperada sin variables no observadas, maximice la función como lo haría en el caso completamente observado, para obtener una nueva estimación de los parámetros de su modelo.
Repita hasta la convergencia.
EM se utiliza para maximizar la probabilidad de un modelo Q con variables latentes Z.
Es una optimización iterativa.
theta <- initial guess for hidden parameters
while not converged:
#e-step
Q(theta''|theta) = E[log L(theta|Z)]
#m-step
theta <- argmax_theta'' Q(theta''|theta)
e-step: dada la estimación actual de Z calcule la función de loglikelihood esperada
m-step: encuentra theta que maximiza esta Q
GMM Ejemplo:
Paso electrónico: estimación de las asignaciones de etiquetas para cada punto de datos dada la estimación actual del parámetro gmm
m-step: maximiza una nueva theta dada la nueva asignación de etiquetas
K-means también es un algoritmo EM y hay muchas animaciones explicativas en K-means.
La respuesta aceptada hace referencia al Chuong EM Paper , que hace un trabajo decente explicando EM. También hay un video de youtube que explica el documento con más detalle.
Para recapitular, aquí está el escenario:
1st: {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd: {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd: {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th: {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th: {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails
Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.
We don''t know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.
En el caso de la primera pregunta de la prueba, intuitivamente pensamos que B la generó ya que la proporción de cabezas coincide con el sesgo de B muy bien ... pero ese valor fue solo una suposición, por lo que no podemos estar seguros.
Con eso en mente, me gusta pensar en la solución EM así:
- Cada prueba de volteos consigue ''votar'' sobre qué moneda le gusta más
- Esto se basa en qué tan bien encaja cada moneda en su distribución
- O, desde el punto de vista de la moneda, existe una gran expectativa de ver esta prueba en relación con la otra moneda (en función de las probabilidades logarítmicas ).
- Dependiendo de cuánto le gusta a cada prueba cada moneda, puede actualizar la conjetura del parámetro de esa moneda (sesgo).
- Mientras más le guste una moneda a una prueba, más actualizará el sesgo de la moneda para que refleje la suya propia.
- Básicamente, los sesgos de la moneda se actualizan al combinar estas actualizaciones ponderadas en todos los ensayos, un proceso llamado ( maximación ), que se refiere a tratar de obtener las mejores suposiciones para cada sesgo de una moneda dada una serie de ensayos.
Esto puede ser una simplificación excesiva (o incluso fundamentalmente errónea en algunos niveles), ¡pero espero que esto ayude a un nivel intuitivo!
Si otras respuestas son buenas, intentaré proporcionar otra perspectiva y abordaré la parte intuitiva de la pregunta.
El algoritmo EM (Expectation-Maximization) es una variante de una clase de algoritmos iterativos que utilizan la duality
Extracto (el énfasis es mío):
En matemáticas, una dualidad, en términos generales, traduce conceptos, teoremas o estructuras matemáticas en otros conceptos, teoremas o estructuras, de manera uno a uno, a menudo (pero no siempre) por medio de una operación de involución: si el dual de A es B, entonces el dual de B es A. Tales involuciones a veces tienen puntos fijos , de modo que el dual de A es A mismo
Por lo general, una doble B de un objeto A está relacionada con A de alguna manera que conserva cierta simetría o compatibilidad . Por ejemplo AB = const
Los ejemplos de algoritmos iterativos que emplean dualidad (en el sentido anterior) son:
- Algoritmo euclidiano para Greatest Common Divisor y sus variantes
- Algoritmo y variantes de Gram-Schmidt Vector Basis
- Media aritmética - Desigualdad media geométrica y sus variantes
- Algoritmo de maximización de expectativas y sus variantes (ver también aquí para obtener una vista de información geométrica )
- (.. otros algoritmos similares ..)
De manera similar, el algoritmo EM también se puede ver como dos pasos de maximización dual :
.. [EM] se considera que maximiza una función conjunta de los parámetros y de la distribución sobre las variables no observadas. El paso E maximiza esta función con respecto a la distribución sobre las variables no observadas; el M-paso con respecto a los parámetros ..
En un algoritmo iterativo que usa la dualidad existe la suposición explícita (o implícita) de un punto de convergencia de equilibrio (o fijo) (para EM esto se prueba usando la desigualdad de Jensen)
Entonces el bosquejo de tales algoritmos es:
- Paso E-like: Encuentra la mejor solución x con respecto a dado y se mantiene constante.
- Paso M (dual): encuentre la mejor solución y con respecto a x (como se calculó en el paso anterior) que se mantiene constante.
- Criterio de terminación / paso de convergencia: repita los pasos 1 y 2 con los valores actualizados de x , y hasta la convergencia (o se alcanza el número especificado de iteraciones)
Tenga en cuenta que cuando dicho algoritmo converge a un óptimo (global), ha encontrado una configuración que es mejor en ambos sentidos (es decir, tanto en el dominio / parámetros x como en el dominio / parámetros). Sin embargo, el algoritmo solo puede encontrar un óptimo local y no el óptimo global .
Diría que esta es la descripción intuitiva del esquema del algoritmo
Para los argumentos estadísticos y las aplicaciones, otras respuestas han dado buenas explicaciones (consulte también las referencias en esta respuesta)
Supongamos que tenemos algunos datos muestreados de dos grupos diferentes, rojo y azul:
Aquí, podemos ver qué punto de datos pertenece al grupo rojo o azul. Esto hace que sea más fácil encontrar los parámetros que caracterizan a cada grupo. Por ejemplo, la media del grupo rojo es alrededor de 3, la media del grupo azul es alrededor de 7 (y podríamos encontrar el medio exacto si quisiéramos).
Esto es, en términos generales, conocido como estimación de máxima verosimilitud . Dados algunos datos, calculamos el valor de un parámetro (o parámetros) que mejor explica esos datos.
Ahora imagina que no podemos ver qué valor se muestreó de qué grupo. Todo se ve morado para nosotros:
Aquí tenemos el conocimiento de que hay dos grupos de valores, pero no sabemos a qué grupo pertenece ningún valor en particular.
¿Todavía podemos estimar los medios para el grupo rojo y azul que mejor se ajustan a esta información?
Sí, ¡a menudo podemos! La maximización de expectativas nos brinda una manera de hacerlo. La idea muy general detrás del algoritmo es esta:
- Comience con una estimación inicial de lo que podría ser cada parámetro.
- Calcule la probabilidad de que cada parámetro produzca el punto de datos ( expectativa ).
- Calcule los pesos para cada punto de datos en función de la probabilidad de que sea producido por un parámetro.
- Combine estos pesos junto con los datos para calcular una mejor estimación de los parámetros ( maximización ).
- Repita los pasos 2 a 4 hasta que la estimación del parámetro converja (el proceso deja de producir una estimación diferente).
Estos pasos necesitan una explicación más detallada, por lo que analizaré el problema de la media roja / azul que se describió anteriormente.
Ejemplo: estimación de la media y la desviación estándar
Usaré Python en este ejemplo, pero el código debería ser bastante fácil de entender si no está familiarizado con este idioma.
Supongamos que tenemos dos grupos, rojo y azul, con los valores distribuidos como en la imagen de arriba. Específicamente, cada grupo contiene un valor extraído de una distribución normal con los siguientes parámetros:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
Aquí hay una imagen de estos grupos rojos y azules nuevamente (para evitar tener que desplazarse hacia arriba):
Cuando podemos ver el color de cada punto (es decir, a qué grupo pertenece), es muy fácil estimar la media y la desviación estándar para cada grupo. Pasamos los valores rojo y azul a las funciones integradas en NumPy. Por ejemplo:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Pero, ¿y si no podemos ver los colores de los puntos? Es decir, en lugar de rojo o azul, cada punto ha sido de color morado.
Para tratar de recuperar los parámetros de media y desviación estándar para los grupos rojo y azul, podemos usar Maximización de expectativa.
Nuestro primer paso ( paso 1 anterior) es adivinar los valores de los parámetros para la media y la desviación estándar de cada grupo. No tenemos que adivinar inteligentemente; podemos elegir cualquier número que deseemos:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Estas estimaciones de parámetros producen curvas de campana que se ven así:
Estas son malas estimaciones. Ambos medios (las líneas de puntos verticales) se ven muy lejos de cualquier tipo de "medio" para grupos sensibles de puntos, por ejemplo. Queremos mejorar estas estimaciones.
El siguiente paso ( paso 2 ) es calcular la probabilidad de que cada punto de datos aparezca bajo las conjeturas de parámetros actuales:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Aquí, simplemente hemos puesto cada punto de datos en la función de densidad de probabilidad para una distribución normal usando nuestras conjeturas actuales en la media y la desviación estándar para rojo y azul. Esto nos dice, por ejemplo, que con nuestras conjeturas actuales, el punto de datos en 1.761 es mucho más probable que sea rojo (0.189) que azul (0.00003).
Para cada punto de datos, podemos convertir estos dos valores de verosimilitud en ponderaciones ( paso 3 ) de modo que sumen a 1 de la siguiente manera:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Con nuestras estimaciones actuales y nuestras ponderaciones recientemente calculadas, ahora podemos calcular nuevas estimaciones para la media y la desviación estándar de los grupos rojo y azul ( paso 4 ).
Calculamos dos veces la media y la desviación estándar usando todos los puntos de datos, pero con las diferentes ponderaciones: una para los pesos rojos y una para los pesos azules.
La clave de la intuición es que cuanto mayor es el peso de un color en un punto de datos, más influye el punto de datos en las próximas estimaciones de los parámetros de ese color. Esto tiene el efecto de "tirar" los parámetros en la dirección correcta.
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Tenemos nuevas estimaciones para los parámetros. Para mejorarlos de nuevo, podemos volver al paso 2 y repetir el proceso. Hacemos esto hasta que las estimaciones convergen o después de que se haya realizado un número de iteraciones ( paso 5 ).
Para nuestros datos, las primeras cinco iteraciones de este proceso se ven así (las iteraciones recientes tienen una apariencia más sólida):
Vemos que los medios ya convergen en algunos valores, y las formas de las curvas (gobernadas por la desviación estándar) también se vuelven más estables.
Si continuamos durante 20 iteraciones, terminamos con lo siguiente:
El proceso EM ha convergido a los siguientes valores, que se acercan mucho a los valores reales (donde podemos ver los colores, sin variables ocultas):
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
Técnicamente, el término "EM" está poco especificado, pero supongo que se refiere a la técnica de análisis de conglomerados de modelado de mezclas de Gauss, que es una instancia del principio EM general.
En realidad, el análisis de clúster EM no es un clasificador . Sé que algunas personas consideran que la agrupación es una "clasificación no supervisada", pero en realidad el análisis de conglomerados es algo bastante diferente.
La diferencia clave y la gran confusión que las personas siempre tienen con el análisis de clusters es que: en el análisis de clústeres, no existe una "solución correcta" . Es un método de descubrimiento de conocimiento, ¡en realidad está destinado a encontrar algo nuevo ! Esto hace que la evaluación sea muy difícil. A menudo se evalúa usando una clasificación conocida como referencia, pero eso no siempre es apropiado: la clasificación que tiene puede o no reflejar lo que está en los datos.
Déjame darte un ejemplo: tienes un gran conjunto de datos de clientes, incluidos datos de género. Un método que divide este conjunto de datos en "masculino" y "femenino" es óptimo cuando lo compara con las clases existentes. En una forma de pensar de "predicción", esto es bueno, ya que para los nuevos usuarios ahora puedes predecir su género. En una forma de pensar "descubrimiento de conocimiento" esto es realmente malo, porque quería descubrir alguna nueva estructura en los datos. Un método que, por ejemplo, dividiría los datos en personas mayores y niños, sin embargo, puntuaría peor que con la clase de hombres y mujeres. Sin embargo, ese sería un excelente resultado de agrupamiento (si no se dio la edad).
Ahora volvamos a EM. Básicamente, asume que sus datos están compuestos de múltiples distribuciones normales multivariantes (tenga en cuenta que esta es una suposición muy fuerte, en particular cuando se fija el número de clústeres). Luego trata de encontrar un modelo óptimo local para esto al mejorar de forma alternativa el modelo y la asignación de objetos al modelo .
Para obtener los mejores resultados en un contexto de clasificación, elija el número de clústeres más grandes que el número de clases, o incluso aplique el agrupamiento a clases individuales solamente (para averiguar si hay alguna estructura dentro de la clase).
Digamos que quieres entrenar a un clasificador para distinguir "autos", "bicicletas" y "camiones". Hay poco uso en suponer que los datos consisten en exactamente 3 distribuciones normales. Sin embargo, puede suponer que hay más de un tipo de automóviles (y camiones y bicicletas). Entonces, en lugar de entrenar un clasificador para estas tres clases, agrupe autos, camiones y bicicletas en 10 clústers cada uno (o quizás 10 automóviles, 3 camiones y 3 bicicletas, lo que sea), luego entrene a un clasificador para distinguir estas 30 clases, y luego fusionar el resultado de la clase con las clases originales. También puede descubrir que hay un clúster que es particularmente difícil de clasificar, por ejemplo Trikes. Son algo de autos, y algo de bicicletas. O camiones de reparto, que son más como autos de gran tamaño que camiones.
Usando el mismo artículo de Do y Batzoglou citado en la respuesta de Zhubarb, implementé EM para ese problema en Java . Los comentarios a su respuesta muestran que el algoritmo se queda atascado en un óptimo local, lo que también ocurre con mi implementación si los parámetros thetaA y thetaB son los mismos.
A continuación se muestra la salida estándar de mi código, que muestra la convergencia de los parámetros.
thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960
Debajo está mi implementación de Java de EM para resolver el problema en (Do y Batzoglou, 2008). La parte central de la implementación es el ciclo para ejecutar EM hasta que los parámetros convergen.
private Parameters _parameters;
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
A continuación está el código completo.
import java.util.*;
/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
double _thetaA = 0.0; // Probability of heads for coin A.
double _thetaB = 0.0; // Probability of heads for coin B.
double _delta = 0.00001;
public Parameters(double thetaA, double thetaB)
{
_thetaA = thetaA;
_thetaB = thetaB;
}
/*************************************************************************
Returns true if this parameter is close enough to another parameter
(typically the estimated parameter coming from the maximization step).
*************************************************************************/
public boolean converged(Parameters other)
{
if (Math.abs(_thetaA - other._thetaA) < _delta &&
Math.abs(_thetaB - other._thetaB) < _delta)
{
return true;
}
return false;
}
public double getThetaA()
{
return _thetaA;
}
public double getThetaB()
{
return _thetaB;
}
public String toString()
{
return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
}
}
/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
double _numHeads = 0;
double _numTails = 0;
public Observation(String s)
{
for (int i = 0; i < s.length(); i++)
{
char c = s.charAt(i);
if (c == ''H'')
{
_numHeads++;
}
else if (c == ''T'')
{
_numTails++;
}
else
{
throw new RuntimeException("Unknown character: " + c);
}
}
}
public Observation(double numHeads, double numTails)
{
_numHeads = numHeads;
_numTails = numTails;
}
public double getNumHeads()
{
return _numHeads;
}
public double getNumTails()
{
return _numTails;
}
public String toString()
{
return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
}
}
/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
// Current estimated parameters.
private Parameters _parameters;
// Observations from the trials. These observations are set once.
private final List<Observation> _observations;
// Estimated observations per coin. These observations are the output
// of the expectation step.
private List<Observation> _expectedObservationsForCoinA;
private List<Observation> _expectedObservationsForCoinB;
private static java.io.PrintStream o = System.out;
/*************************************************************************
Principal constructor.
@param observations The observations from the trial.
@param parameters The initial guessed parameters.
*************************************************************************/
public EM(List<Observation> observations, Parameters parameters)
{
_observations = observations;
_parameters = parameters;
}
/*************************************************************************
Run EM until parameters converge.
*************************************************************************/
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
o.printf("%s/n", estimatedParameters);
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
/*************************************************************************
Given the observations and current estimated parameters, compute new
estimated completions (distribution over the classes) and observations.
*************************************************************************/
private void expectation()
{
_expectedObservationsForCoinA = new ArrayList<Observation>();
_expectedObservationsForCoinB = new ArrayList<Observation>();
for (Observation observation : _observations)
{
int numHeads = (int)observation.getNumHeads();
int numTails = (int)observation.getNumTails();
double probabilityOfObservationForCoinA=
binomialProbability(10, numHeads, _parameters.getThetaA());
double probabilityOfObservationForCoinB=
binomialProbability(10, numHeads, _parameters.getThetaB());
double normalizer = probabilityOfObservationForCoinA +
probabilityOfObservationForCoinB;
// Compute the completions for coin A and B (i.e. the probability
// distribution of the two classes, summed to 1.0).
double completionCoinA = probabilityOfObservationForCoinA /
normalizer;
double completionCoinB = probabilityOfObservationForCoinB /
normalizer;
// Compute new expected observations for the two coins.
Observation expectedObservationForCoinA =
new Observation(numHeads * completionCoinA,
numTails * completionCoinA);
Observation expectedObservationForCoinB =
new Observation(numHeads * completionCoinB,
numTails * completionCoinB);
_expectedObservationsForCoinA.add(expectedObservationForCoinA);
_expectedObservationsForCoinB.add(expectedObservationForCoinB);
}
}
/*************************************************************************
Given new estimated observations, compute new estimated parameters.
*************************************************************************/
private Parameters maximization()
{
double sumCoinAHeads = 0.0;
double sumCoinATails = 0.0;
double sumCoinBHeads = 0.0;
double sumCoinBTails = 0.0;
for (Observation observation : _expectedObservationsForCoinA)
{
sumCoinAHeads += observation.getNumHeads();
sumCoinATails += observation.getNumTails();
}
for (Observation observation : _expectedObservationsForCoinB)
{
sumCoinBHeads += observation.getNumHeads();
sumCoinBTails += observation.getNumTails();
}
return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));
//o.printf("parameters: %s/n", _parameters);
}
/*************************************************************************
Since the coin-toss experiment posed in this article is a Bernoulli trial,
use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
*************************************************************************/
private static double binomialProbability(int n, int k, double p)
{
double q = 1.0 - p;
return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
}
private static long nChooseK(int n, int k)
{
long numerator = 1;
for (int i = 0; i < k; i++)
{
numerator = numerator * n;
n--;
}
long denominator = factorial(k);
return (long)(numerator / denominator);
}
private static long factorial(int n)
{
long result = 1;
for (; n >0; n--)
{
result = result * n;
}
return result;
}
/*************************************************************************
Entry point into the program.
*************************************************************************/
public static void main(String argv[])
{
// Create the observations and initial parameter guess
// from the (Do and Batzoglou, 2008) article.
List<Observation> observations = new ArrayList<Observation>();
observations.add(new Observation("HTTTHHTHTH"));
observations.add(new Observation("HHHHTHHHHH"));
observations.add(new Observation("HTHHHHHTHH"));
observations.add(new Observation("HTHTTTHHTT"));
observations.add(new Observation("THHHTHHHTH"));
Parameters initialParameters = new Parameters(0.6, 0.5);
EM em = new EM(observations, initialParameters);
Parameters finalParameters = em.run();
o.printf("Final result:/n%s/n", finalParameters);
}
}