treeclassifier sklearn scikit min_samples_leaf from feature_importances_ example decisiontreeclassifier decision classifier python machine-learning scikit-learn decision-tree random-forest

python - scikit - sklearn min_samples_leaf



¿Cómo extraer las reglas de decisión de scikit-learn decision-tree? (12)

¿Puedo extraer las reglas de decisión subyacentes (o "rutas de decisión") de un árbol entrenado en un árbol de decisión como una lista de texto?

Algo como:

if A>0.4 then if B<0.2 then if C>0.8 then class=''X''

Gracias por tu ayuda.


Al parecer, hace mucho tiempo, alguien ya decidió intentar agregar la siguiente función a las funciones de exportación de árbol del scikit oficial (que básicamente solo admite export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) : """Export a decision tree in dict format.

Aquí está su compromiso completo:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

No estoy seguro de qué pasó con este comentario. Pero también podrías intentar usar esa función.

Creo que esto garantiza una seria solicitud de documentación a las buenas personas de scikit-learn para documentar apropiadamente el sklearn.tree.Tree API que es la estructura de árbol subyacente que DecisionTreeClassifier expone como su atributo tree_ .


Aquí hay una función, las reglas de impresión de un árbol de decisión scikit-learn en python 3 y con compensaciones para bloques condicionales para hacer que la estructura sea más legible:

def print_decision_tree(tree, feature_names=None, offset_unit='' ''): ''''''Plots textual representation of rules of a decision tree tree: scikit-learn representation of tree feature_names: list of feature names. They are set to f1,f2,f3,... if not specified offset_unit: a string of offset of the conditional block'''''' left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold value = tree.tree_.value if feature_names is None: features = [''f%d''%i for i in tree.tree_.feature] else: features = [feature_names[i] for i in tree.tree_.feature] def recurse(left, right, threshold, features, node, depth=0): offset = offset_unit*depth if (threshold[node] != -2): print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {") if left[node] != -1: recurse (left, right, threshold, features,left[node],depth+1) print(offset+"} else {") if right[node] != -1: recurse (left, right, threshold, features,right[node],depth+1) print(offset+"}") else: print(offset+"return " + str(value[node])) recurse(left, right, threshold, features, 0,0)


Códigos a continuación es mi enfoque bajo anaconda python 2.7 más un nombre de paquete "pydot-ng" para hacer un archivo PDF con reglas de decisión. Espero que sea útil.

from sklearn import tree clf = tree.DecisionTreeClassifier(max_leaf_nodes=n) clf_ = clf.fit(X, data_y) feature_names = X.columns class_name = clf_.classes_.astype(int).astype(str) def output_pdf(clf_, name): from sklearn import tree from sklearn.externals.six import StringIO import pydot_ng as pydot dot_data = StringIO() tree.export_graphviz(clf_, out_file=dot_data, feature_names=feature_names, class_names=class_name, filled=True, rounded=True, special_characters=True, node_ids=1,) graph = pydot.graph_from_dot_data(dot_data.getvalue()) graph.write_pdf("%s.pdf"%name) output_pdf(clf_, name=''filename%s''%n)

un gráfico de árbol muestra aquí


Creé mi propia función para extraer las reglas de los árboles de decisión creados por sklearn:

import pandas as pd import numpy as np from sklearn.tree import DecisionTreeClassifier # dummy data: df = pd.DataFrame({''col1'':[0,1,2,3],''col2'':[3,4,5,6],''dv'':[0,1,0,1]}) # create decision tree dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1) dt.fit(df.ix[:,:2], df.dv)

Esta función comienza primero con los nodos (identificados por -1 en las matrices secundarias) y luego encuentra recursivamente los padres. Yo llamo a esto un ''linaje'' de un nodo. En el camino, agarro los valores que necesito para crear la lógica SAS if / then / else:

def get_lineage(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] # get ids of child nodes idx = np.argwhere(left == -1)[:,0] def recurse(left, right, child, lineage=None): if lineage is None: lineage = [child] if child in left: parent = np.where(left == child)[0].item() split = ''l'' else: parent = np.where(right == child)[0].item() split = ''r'' lineage.append((parent, split, threshold[parent], features[parent])) if parent == 0: lineage.reverse() return lineage else: return recurse(left, right, parent, lineage) for child in idx: for node in recurse(left, right, child): print node

Los conjuntos de tuplas a continuación contienen todo lo que necesito para crear sentencias SAS if / then / else. No me gusta usar do blocks en SAS, por eso creo una lógica que describe la ruta completa de un nodo. El entero único después de las tuplas es la ID del nodo terminal en una ruta. Todas las tuplas anteriores se combinan para crear ese nodo.

In [1]: get_lineage(dt, df.columns) (0, ''l'', 0.5, ''col1'') 1 (0, ''r'', 0.5, ''col1'') (2, ''l'', 4.5, ''col2'') 3 (0, ''r'', 0.5, ''col1'') (2, ''r'', 4.5, ''col2'') (4, ''l'', 2.5, ''col1'') 5 (0, ''r'', 0.5, ''col1'') (2, ''r'', 4.5, ''col2'') (4, ''r'', 2.5, ''col1'') 6


Creo que esta respuesta es más correcta que las otras respuestas aquí:

from sklearn.tree import _tree def tree_to_code(tree, feature_names): tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] print "def tree({}):".format(", ".join(feature_names)) def recurse(node, depth): indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] print "{}if {} <= {}:".format(indent, name, threshold) recurse(tree_.children_left[node], depth + 1) print "{}else: # if {} > {}".format(indent, name, threshold) recurse(tree_.children_right[node], depth + 1) else: print "{}return {}".format(indent, tree_.value[node]) recurse(0, 1)

Esto imprime una función de Python válida. Aquí hay un ejemplo de salida para un árbol que está tratando de devolver su entrada, un número entre 0 y 10.

def tree(f0): if f0 <= 6.0: if f0 <= 1.5: return [[ 0.]] else: # if f0 > 1.5 if f0 <= 4.5: if f0 <= 3.5: return [[ 3.]] else: # if f0 > 3.5 return [[ 4.]] else: # if f0 > 4.5 return [[ 5.]] else: # if f0 > 6.0 if f0 <= 8.5: if f0 <= 7.5: return [[ 7.]] else: # if f0 > 7.5 return [[ 8.]] else: # if f0 > 8.5 return [[ 9.]]

Aquí hay algunos obstáculos que veo en otras respuestas:

  1. Usar tree_.threshold == -2 para decidir si un nodo es una hoja no es una buena idea. ¿Qué pasa si se trata de un nodo de decisión real con un umbral de -2? En su lugar, deberías mirar tree.feature o tree.children_* .
  2. La línea features = [feature_names[i] for i in tree_.feature] bloquea con mi versión de sklearn, porque algunos valores de tree.tree_.feature son -2 (específicamente para los nodos hoja).
  3. No es necesario tener múltiples sentencias if en la función recursiva, solo una está bien.

Esto se basa en la respuesta de @paulkernfeld. Si tiene un marco de datos X con sus características y un marco de datos objetivo con sus resonses y usted desea obtener una idea de qué valor y terminó en qué nodo (y también para trazarlo en consecuencia) puede hacer lo siguiente:

def tree_to_code(tree, feature_names): codelines = [] codelines.append(''def get_cat(X_tmp):/n'') codelines.append('' catout = []/n'') codelines.append('' for codelines in range(0,X_tmp.shape[0]):/n'') codelines.append('' Xin = X_tmp.iloc[codelines]/n'') tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] #print "def tree({}):".format(", ".join(feature_names)) def recurse(node, depth): indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] codelines.append (''{}if Xin["{}"] <= {}:/n''.format(indent, name, threshold)) recurse(tree_.children_left[node], depth + 1) codelines.append( ''{}else: # if Xin["{}"] > {}/n''.format(indent, name, threshold)) recurse(tree_.children_right[node], depth + 1) else: codelines.append( ''{}mycat = {}/n''.format(indent, node)) recurse(0, 1) codelines.append('' catout.append(mycat)/n'') codelines.append('' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])/n'') codelines.append(''node_ids = get_cat(X)/n'') return codelines mycode = tree_to_code(clf,X.columns.values) # now execute the function and obtain the dataframe with all nodes exec(''''.join(mycode)) node_ids = [int(x[0]) for x in node_ids.values] node_ids2 = pd.DataFrame(node_ids) print(''make plot'') import matplotlib.cm as cm colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids))))) #plt.figure(figsize=cm2inch(24, 21)) for i in list(set(node_ids)): plt.plot(y[node_ids2.values==i],''o'',color=colors[i], label=str(i)) mytitle = [''y colored by node''] plt.title(mytitle ,fontsize=14) plt.xlabel(''my xlabel'') plt.ylabel(tagname) plt.xticks(rotation=70) plt.legend(loc=''upper center'', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9) plt.tight_layout() plt.show() plt.close

No es la versión más elegante pero cumple su función ...


Hay un nuevo método DecisionTreeClassifier , decision_path , en la versión 0.18.0 . Los desarrolladores proporcionan una walkthrough exhaustiva (bien documentada).

La primera sección de código en el tutorial que imprime la estructura del árbol parece estar bien. Sin embargo, modifiqué el código en la segunda sección para interrogar una muestra. Mis cambios denotados con # <--

sample_id = 0 node_index = node_indicator.indices[node_indicator.indptr[sample_id]: node_indicator.indptr[sample_id + 1]] print(''Rules used to predict sample %s: '' % sample_id) for node_id in node_index: if leave_id[sample_id] == node_id: # <-- changed != to == #continue # <-- comment out print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <-- else: # < -- added else to iterate through decision nodes if (X_test[sample_id, feature[node_id]] <= threshold[node_id]): threshold_sign = "<=" else: threshold_sign = ">" print("decision id node %s : (X[%s, %s] (= %s) %s %s)" % (node_id, sample_id, feature[node_id], X_test[sample_id, feature[node_id]], # <-- changed i to sample_id threshold_sign, threshold[node_id])) Rules used to predict sample 0: decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921) decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927) leaf node 4 reached, no decision here

Cambie el sample_id para ver las rutas de decisión para otras muestras. No he preguntado a los desarrolladores sobre estos cambios, solo me pareció más intuitivo al trabajar con el ejemplo.


He estado pasando por esto, pero necesitaba que las reglas se escriban en este formato

if A>0.4 then if B<0.2 then if C>0.8 then class=''X''

Así que adapté la respuesta de @paulkernfeld (gracias) que puede personalizar a su necesidad

def tree_to_code(tree, feature_names, Y): tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] pathto=dict() global k k = 0 def recurse(node, depth, parent): global k indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] s= "{} <= {} ".format( name, threshold, node ) if node == 0: pathto[node]=s else: pathto[node]=pathto[parent]+'' & '' +s recurse(tree_.children_left[node], depth + 1, node) s="{} > {}".format( name, threshold) if node == 0: pathto[node]=s else: pathto[node]=pathto[parent]+'' & '' +s recurse(tree_.children_right[node], depth + 1, node) else: k=k+1 print(k,'')'',pathto[parent], tree_.value[node]) recurse(0, 1, 0)


Se modificó el código de Zelazny7 para buscar SQL del árbol de decisión.

# SQL from decision tree def get_lineage(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] le=''<='' g =''>'' # get ids of child nodes idx = np.argwhere(left == -1)[:,0] def recurse(left, right, child, lineage=None): if lineage is None: lineage = [child] if child in left: parent = np.where(left == child)[0].item() split = ''l'' else: parent = np.where(right == child)[0].item() split = ''r'' lineage.append((parent, split, threshold[parent], features[parent])) if parent == 0: lineage.reverse() return lineage else: return recurse(left, right, parent, lineage) print ''case '' for j,child in enumerate(idx): clause='' when '' for node in recurse(left, right, child): if len(str(node))<3: continue i=node if i[1]==''l'': sign=le else: sign=g clause=clause+i[3]+sign+str(i[2])+'' and '' clause=clause[:-4]+'' then ''+str(j) print clause print ''else 99 end as clusters''


Solo porque todos fueron tan útiles, solo añadiré una modificación a las hermosas soluciones de Zelazny7 y Daniele. Este es para Python 2.7, con pestañas para hacerlo más legible:

def get_code(tree, feature_names, tabdepth=0): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] value = tree.tree_.value def recurse(left, right, threshold, features, node, tabdepth=0): if (threshold[node] != -2): print ''/t'' * tabdepth, print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {" if left[node] != -1: recurse (left, right, threshold, features,left[node], tabdepth+1) print ''/t'' * tabdepth, print "} else {" if right[node] != -1: recurse (left, right, threshold, features,right[node], tabdepth+1) print ''/t'' * tabdepth, print "}" else: print ''/t'' * tabdepth, print "return " + str(value[node]) recurse(left, right, threshold, features, 0)


Zelazny7 el código enviado por Zelazny7 para imprimir algunos pseudocódigos:

def get_code(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] value = tree.tree_.value def recurse(left, right, threshold, features, node): if (threshold[node] != -2): print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {" if left[node] != -1: recurse (left, right, threshold, features,left[node]) print "} else {" if right[node] != -1: recurse (left, right, threshold, features,right[node]) print "}" else: print "return " + str(value[node]) recurse(left, right, threshold, features, 0)

si llama a get_code(dt, df.columns) en el mismo ejemplo, obtendrá:

if ( col1 <= 0.5 ) { return [[ 1. 0.]] } else { if ( col2 <= 4.5 ) { return [[ 0. 1.]] } else { if ( col1 <= 2.5 ) { return [[ 1. 0.]] } else { return [[ 0. 1.]] } } }


from StringIO import StringIO out = StringIO() out = tree.export_graphviz(clf, out_file=out) print out.getvalue()

Puedes ver un árbol digrafo. Luego, clf.tree_.feature y clf.tree_.value son una matriz de nodos que divide la característica y una matriz de valores de nodos, respectivamente. Puede consultar más detalles de esta fuente github .