python - scikit - sklearn min_samples_leaf
¿Cómo extraer las reglas de decisión de scikit-learn decision-tree? (12)
¿Puedo extraer las reglas de decisión subyacentes (o "rutas de decisión") de un árbol entrenado en un árbol de decisión como una lista de texto?
Algo como:
if A>0.4 then if B<0.2 then if C>0.8 then class=''X''
Gracias por tu ayuda.
Al parecer, hace mucho tiempo, alguien ya decidió intentar agregar la siguiente función a las funciones de exportación de árbol del scikit oficial (que básicamente solo admite export_graphviz)
def export_dict(tree, feature_names=None, max_depth=None) :
"""Export a decision tree in dict format.
Aquí está su compromiso completo:
No estoy seguro de qué pasó con este comentario. Pero también podrías intentar usar esa función.
Creo que esto garantiza una seria solicitud de documentación a las buenas personas de scikit-learn para documentar apropiadamente el sklearn.tree.Tree
API que es la estructura de árbol subyacente que DecisionTreeClassifier
expone como su atributo tree_
.
Aquí hay una función, las reglas de impresión de un árbol de decisión scikit-learn en python 3 y con compensaciones para bloques condicionales para hacer que la estructura sea más legible:
def print_decision_tree(tree, feature_names=None, offset_unit='' ''):
''''''Plots textual representation of rules of a decision tree
tree: scikit-learn representation of tree
feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
offset_unit: a string of offset of the conditional block''''''
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = [''f%d''%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
print(offset+"return " + str(value[node]))
recurse(left, right, threshold, features, 0,0)
Códigos a continuación es mi enfoque bajo anaconda python 2.7 más un nombre de paquete "pydot-ng" para hacer un archivo PDF con reglas de decisión. Espero que sea útil.
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)
feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)
def output_pdf(clf_, name):
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot_ng as pydot
dot_data = StringIO()
tree.export_graphviz(clf_, out_file=dot_data,
feature_names=feature_names,
class_names=class_name,
filled=True, rounded=True,
special_characters=True,
node_ids=1,)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("%s.pdf"%name)
output_pdf(clf_, name=''filename%s''%n)
Creé mi propia función para extraer las reglas de los árboles de decisión creados por sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({''col1'':[0,1,2,3],''col2'':[3,4,5,6],''dv'':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
Esta función comienza primero con los nodos (identificados por -1 en las matrices secundarias) y luego encuentra recursivamente los padres. Yo llamo a esto un ''linaje'' de un nodo. En el camino, agarro los valores que necesito para crear la lógica SAS if / then / else:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = ''l''
else:
parent = np.where(right == child)[0].item()
split = ''r''
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print node
Los conjuntos de tuplas a continuación contienen todo lo que necesito para crear sentencias SAS if / then / else. No me gusta usar do
blocks en SAS, por eso creo una lógica que describe la ruta completa de un nodo. El entero único después de las tuplas es la ID del nodo terminal en una ruta. Todas las tuplas anteriores se combinan para crear ese nodo.
In [1]: get_lineage(dt, df.columns)
(0, ''l'', 0.5, ''col1'')
1
(0, ''r'', 0.5, ''col1'')
(2, ''l'', 4.5, ''col2'')
3
(0, ''r'', 0.5, ''col1'')
(2, ''r'', 4.5, ''col2'')
(4, ''l'', 2.5, ''col1'')
5
(0, ''r'', 0.5, ''col1'')
(2, ''r'', 4.5, ''col2'')
(4, ''r'', 2.5, ''col1'')
6
Creo que esta respuesta es más correcta que las otras respuestas aquí:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
Esto imprime una función de Python válida. Aquí hay un ejemplo de salida para un árbol que está tratando de devolver su entrada, un número entre 0 y 10.
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
Aquí hay algunos obstáculos que veo en otras respuestas:
- Usar
tree_.threshold == -2
para decidir si un nodo es una hoja no es una buena idea. ¿Qué pasa si se trata de un nodo de decisión real con un umbral de -2? En su lugar, deberías mirartree.feature
otree.children_*
. - La línea
features = [feature_names[i] for i in tree_.feature]
bloquea con mi versión de sklearn, porque algunos valores detree.tree_.feature
son -2 (específicamente para los nodos hoja). - No es necesario tener múltiples sentencias if en la función recursiva, solo una está bien.
Esto se basa en la respuesta de @paulkernfeld. Si tiene un marco de datos X con sus características y un marco de datos objetivo con sus resonses y usted desea obtener una idea de qué valor y terminó en qué nodo (y también para trazarlo en consecuencia) puede hacer lo siguiente:
def tree_to_code(tree, feature_names):
codelines = []
codelines.append(''def get_cat(X_tmp):/n'')
codelines.append('' catout = []/n'')
codelines.append('' for codelines in range(0,X_tmp.shape[0]):/n'')
codelines.append('' Xin = X_tmp.iloc[codelines]/n'')
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
#print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
codelines.append (''{}if Xin["{}"] <= {}:/n''.format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
codelines.append( ''{}else: # if Xin["{}"] > {}/n''.format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
codelines.append( ''{}mycat = {}/n''.format(indent, node))
recurse(0, 1)
codelines.append('' catout.append(mycat)/n'')
codelines.append('' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])/n'')
codelines.append(''node_ids = get_cat(X)/n'')
return codelines
mycode = tree_to_code(clf,X.columns.values)
# now execute the function and obtain the dataframe with all nodes
exec(''''.join(mycode))
node_ids = [int(x[0]) for x in node_ids.values]
node_ids2 = pd.DataFrame(node_ids)
print(''make plot'')
import matplotlib.cm as cm
colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
#plt.figure(figsize=cm2inch(24, 21))
for i in list(set(node_ids)):
plt.plot(y[node_ids2.values==i],''o'',color=colors[i], label=str(i))
mytitle = [''y colored by node'']
plt.title(mytitle ,fontsize=14)
plt.xlabel(''my xlabel'')
plt.ylabel(tagname)
plt.xticks(rotation=70)
plt.legend(loc=''upper center'', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
plt.tight_layout()
plt.show()
plt.close
No es la versión más elegante pero cumple su función ...
Hay un nuevo método DecisionTreeClassifier
, decision_path
, en la versión 0.18.0 . Los desarrolladores proporcionan una walkthrough exhaustiva (bien documentada).
La primera sección de código en el tutorial que imprime la estructura del árbol parece estar bien. Sin embargo, modifiqué el código en la segunda sección para interrogar una muestra. Mis cambios denotados con # <--
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]
print(''Rules used to predict sample %s: '' % sample_id)
for node_id in node_index:
if leave_id[sample_id] == node_id: # <-- changed != to ==
#continue # <-- comment out
print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--
else: # < -- added else to iterate through decision nodes
if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"
print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
% (node_id,
sample_id,
feature[node_id],
X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
threshold_sign,
threshold[node_id]))
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
Cambie el sample_id
para ver las rutas de decisión para otras muestras. No he preguntado a los desarrolladores sobre estos cambios, solo me pareció más intuitivo al trabajar con el ejemplo.
He estado pasando por esto, pero necesitaba que las reglas se escriban en este formato
if A>0.4 then if B<0.2 then if C>0.8 then class=''X''
Así que adapté la respuesta de @paulkernfeld (gracias) que puede personalizar a su necesidad
def tree_to_code(tree, feature_names, Y):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
pathto=dict()
global k
k = 0
def recurse(node, depth, parent):
global k
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
s= "{} <= {} ".format( name, threshold, node )
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+'' & '' +s
recurse(tree_.children_left[node], depth + 1, node)
s="{} > {}".format( name, threshold)
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+'' & '' +s
recurse(tree_.children_right[node], depth + 1, node)
else:
k=k+1
print(k,'')'',pathto[parent], tree_.value[node])
recurse(0, 1, 0)
Se modificó el código de Zelazny7 para buscar SQL del árbol de decisión.
# SQL from decision tree
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
le=''<=''
g =''>''
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = ''l''
else:
parent = np.where(right == child)[0].item()
split = ''r''
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
print ''case ''
for j,child in enumerate(idx):
clause='' when ''
for node in recurse(left, right, child):
if len(str(node))<3:
continue
i=node
if i[1]==''l'': sign=le
else: sign=g
clause=clause+i[3]+sign+str(i[2])+'' and ''
clause=clause[:-4]+'' then ''+str(j)
print clause
print ''else 99 end as clusters''
Solo porque todos fueron tan útiles, solo añadiré una modificación a las hermosas soluciones de Zelazny7 y Daniele. Este es para Python 2.7, con pestañas para hacerlo más legible:
def get_code(tree, feature_names, tabdepth=0):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node, tabdepth=0):
if (threshold[node] != -2):
print ''/t'' * tabdepth,
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node], tabdepth+1)
print ''/t'' * tabdepth,
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node], tabdepth+1)
print ''/t'' * tabdepth,
print "}"
else:
print ''/t'' * tabdepth,
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
Zelazny7 el código enviado por Zelazny7 para imprimir algunos pseudocódigos:
def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
si llama a get_code(dt, df.columns)
en el mismo ejemplo, obtendrá:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()
Puedes ver un árbol digrafo. Luego, clf.tree_.feature
y clf.tree_.value
son una matriz de nodos que divide la característica y una matriz de valores de nodos, respectivamente. Puede consultar más detalles de esta fuente github .