python - example - Trazar el dendrograma con sklearn.AglomeraciĆ³nClustering
python dendrogram (4)
Estoy tratando de construir un dendrograma usando el atributo children_
proporcionado por AgglomerativeClustering
, pero hasta ahora no he tenido suerte. No puedo usar scipy.cluster
ya que la aglomeración en scipy
carece de algunas opciones que son importantes para mí (como la opción de especificar la cantidad de clústeres). Estaría muy agradecido por algún consejo por ahí.
import sklearn.cluster
clstr = cluster.AgglomerativeClustering(n_clusters=2)
clusterer.children_
Aquí hay una función simple para tomar un modelo de agrupamiento jerárquico de sklearn y trazarlo usando la función de dendrogram
scipy. Parece que las funciones de representación gráfica a menudo no se admiten directamente en sklearn. here puede encontrar una discusión interesante de la relacionada con la solicitud de extracción para este fragmento de código plot_dendrogram
.
Me gustaría aclarar que el caso de uso que describes (definir el número de clusters) está disponible en scipy: después de realizar el clúster jerárquico utilizando el linkage
de scipy puedes cortar la jerarquía a la cantidad de clústeres que quieras utilizando fcluster
con número de clusters especificado en el argumento t
y criterion=''maxclust''
argumento criterion=''maxclust''
.
Me encontré con el mismo problema hace algún tiempo. La forma en que logré trazar el maldito dendograma fue usando el paquete de software ete3 . Este paquete es capaz de trazar de manera flexible árboles con varias opciones. La única dificultad era convertir la salida children_
sklearn
en el formato Newick Tree que ete3
puede leer y comprender. Además, necesito calcular manualmente el lapso de la dendrita porque esa información no se proporcionó con children_
. Aquí hay un fragmento del código que utilicé. Calcula el árbol Newick y luego muestra la ete3
Árbol ete3. Para más detalles sobre cómo trazar, mira here
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import ete3
def build_Newick_tree(children,n_leaves,X,leaf_labels,spanner):
"""
build_Newick_tree(children,n_leaves,X,leaf_labels,spanner)
Get a string representation (Newick tree) from the sklearn
AgglomerativeClustering.fit output.
Input:
children: AgglomerativeClustering.children_
n_leaves: AgglomerativeClustering.n_leaves_
X: parameters supplied to AgglomerativeClustering.fit
leaf_labels: The label of each parameter array in X
spanner: Callable that computes the dendrite''s span
Output:
ntree: A str with the Newick tree representation
"""
return go_down_tree(children,n_leaves,X,leaf_labels,len(children)+n_leaves-1,spanner)[0]+'';''
def go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner):
"""
go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner)
Iterative function that traverses the subtree that descends from
nodename and returns the Newick representation of the subtree.
Input:
children: AgglomerativeClustering.children_
n_leaves: AgglomerativeClustering.n_leaves_
X: parameters supplied to AgglomerativeClustering.fit
leaf_labels: The label of each parameter array in X
nodename: An int that is the intermediate node name whos
children are located in children[nodename-n_leaves].
spanner: Callable that computes the dendrite''s span
Output:
ntree: A str with the Newick tree representation
"""
nodeindex = nodename-n_leaves
if nodename<n_leaves:
return leaf_labels[nodeindex],np.array([X[nodeindex]])
else:
node_children = children[nodeindex]
branch0,branch0samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[0])
branch1,branch1samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[1])
node = np.vstack((branch0samples,branch1samples))
branch0span = spanner(branch0samples)
branch1span = spanner(branch1samples)
nodespan = spanner(node)
branch0distance = nodespan-branch0span
branch1distance = nodespan-branch1span
nodename = ''({branch0}:{branch0distance},{branch1}:{branch1distance})''.format(branch0=branch0,branch0distance=branch0distance,branch1=branch1,branch1distance=branch1distance)
return nodename,node
def get_cluster_spanner(aggClusterer):
"""
spanner = get_cluster_spanner(aggClusterer)
Input:
aggClusterer: sklearn.cluster.AgglomerativeClustering instance
Get a callable that computes a given cluster''s span. To compute
a cluster''s span, call spanner(cluster)
The cluster must be a 2D numpy array, where the axis=0 holds
separate cluster members and the axis=1 holds the different
variables.
"""
if aggClusterer.linkage==''ward'':
if aggClusterer.affinity==''euclidean'':
spanner = lambda x:np.sum((x-aggClusterer.pooling_func(x,axis=0))**2)
elif aggClusterer.linkage==''complete'':
if aggClusterer.affinity==''euclidean'':
spanner = lambda x:np.max(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
elif aggClusterer.affinity==''l1'' or aggClusterer.affinity==''manhattan'':
spanner = lambda x:np.max(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
elif aggClusterer.affinity==''l2'':
spanner = lambda x:np.max(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
elif aggClusterer.affinity==''cosine'':
spanner = lambda x:np.max(np.sum((x[:,None,:]*x[None,:,:]))/(np.sqrt(np.sum(x[:,None,:]*x[:,None,:],axis=2,keepdims=True))*np.sqrt(np.sum(x[None,:,:]*x[None,:,:],axis=2,keepdims=True))))
else:
raise AttributeError(''Unknown affinity attribute value {0}.''.format(aggClusterer.affinity))
elif aggClusterer.linkage==''average'':
if aggClusterer.affinity==''euclidean'':
spanner = lambda x:np.mean(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
elif aggClusterer.affinity==''l1'' or aggClusterer.affinity==''manhattan'':
spanner = lambda x:np.mean(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
elif aggClusterer.affinity==''l2'':
spanner = lambda x:np.mean(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
elif aggClusterer.affinity==''cosine'':
spanner = lambda x:np.mean(np.sum((x[:,None,:]*x[None,:,:]))/(np.sqrt(np.sum(x[:,None,:]*x[:,None,:],axis=2,keepdims=True))*np.sqrt(np.sum(x[None,:,:]*x[None,:,:],axis=2,keepdims=True))))
else:
raise AttributeError(''Unknown affinity attribute value {0}.''.format(aggClusterer.affinity))
else:
raise AttributeError(''Unknown linkage attribute value {0}.''.format(aggClusterer.linkage))
return spanner
clusterer = AgglomerativeClustering(n_clusters=2,compute_full_tree=True) # You can set compute_full_tree to ''auto'', but I left it this way to get the entire tree plotted
clusterer.fit(X) # X for whatever you want to fit
spanner = get_cluster_spanner(clusterer)
newick_tree = build_Newick_tree(clusterer.children_,clusterer.n_leaves_,X,leaf_labels,spanner) # leaf_labels is a list of labels for each entry in X
tree = ete3.Tree(newick_tree)
tree.show()
Para aquellos dispuestos a salir de Python y usar la robusta biblioteca D3, no es muy difícil usar las d3.cluster()
(o, supongo, d3.tree()
) para lograr un resultado agradable y personalizable.
Vea el jsfiddle para una demostración.
La matriz children_
afortunadamente funciona como una matriz JS, y el único paso intermedio es usar d3.stratify()
para convertirla en una representación jerárquica. Específicamente, necesitamos que cada nodo tenga un id
Y un parentId
:
var N = 272; // Your n_samples/corpus size.
var root = d3.stratify()
.id((d,i) => i + N)
.parentId((d, i) => {
var parIndex = data.findIndex(e => e.includes(i + N));
if (parIndex < 0) {
return; // The root should have an undefined parentId.
}
return parIndex + N;
})(data); // Your children_
Usted termina con al menos O (n ^ 2) comportamiento aquí debido a la línea findIndex
, pero probablemente no importe hasta que sus n_muestras se vuelvan enormes, en cuyo caso, podría precomputar un índice más eficiente.
Más allá de eso, es más el uso plug and chug de d3.cluster()
. Vea el bloque canónico de mbostock o mi JSFiddle.
NB Para mi caso de uso, bastaba simplemente con mostrar nodos de hoja; es un poco más complicado visualizar las muestras / hojas, ya que es posible que no estén todas en la matriz children_
explícitamente.
Utilice la implementación scipy de agrupamiento aglomerativo en su lugar. Aquí hay un ejemplo.
from scipy.cluster.hierarchy import dendrogram, linkage
data = [[0., 0.], [0.1, -0.1], [1., 1.], [1.1, 1.1]]
Z = linkage(data)
dendrogram(Z)
Puede encontrar documentación para linkage
here y documentación para dendrogram
here .