matlab - pseudocodigo - vecino mas cercano c++
Encontrar K vecinos más cercanos y su implementación (1)
Estoy trabajando en la clasificación de datos simples usando KNN con distancia euclidiana. He visto un ejemplo de lo que me gustaría hacer que se hace con la función knnsearch
MATLAB como se muestra a continuación:
load fisheriris
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,''k'',10);
line(x(n,1),x(n,2),''color'',[.5 .5 .5],''marker'',''o'',''linestyle'',''none'',''markersize'',10)
El código anterior toma un nuevo punto, es decir [5 1.45]
y encuentra los 10 valores más cercanos al nuevo punto. ¿Alguien puede mostrarme un algoritmo MATLAB con una explicación detallada de lo que hace la función knnsearch
? Hay alguna otra manera de hacer esto?
con detalles de cálculo de distancia y clasificación.
La base del algoritmo K-Nearest Neighbourhood (KNN) es que tiene una matriz de datos que consta de N
filas y M
columnas donde N
es la cantidad de puntos de datos que tenemos, mientras que M
es la dimensionalidad de cada punto de datos. Por ejemplo, si colocamos coordenadas cartesianas dentro de una matriz de datos, esta suele ser una matriz N x 2
o N x 3
. Con esta matriz de datos, proporciona un punto de consulta y busca los puntos k
más cercanos dentro de esta matriz de datos que son los más cercanos a este punto de consulta.
Por lo general, usamos la distancia euclidiana entre la consulta y el resto de sus puntos en su matriz de datos para calcular nuestras distancias. Sin embargo, también se usan otras distancias como L1 o City-Block / Manhattan. Después de esta operación, tendrá N
distancias euclidianas o de Manhattan que simbolizan las distancias entre la consulta y cada punto correspondiente en el conjunto de datos. Una vez que encuentre estos, simplemente busque los k
puntos más cercanos a la consulta ordenando las distancias en orden ascendente y recuperando esos k
puntos que tienen la menor distancia entre su conjunto de datos y la consulta.
Suponiendo que su matriz de datos se almacenó en x
, y newpoint
es un punto de muestra donde tiene M
columnas (es decir, 1 x M
), este es el procedimiento general que seguiría en forma de puntos:
- Encuentre la distancia euclidiana o de Manhattan entre el punto
newpoint
y cada punto enx
. - Clasifique estas distancias en orden ascendente.
- Devuelve los
k
puntos de datos enx
que están más cerca del puntonewpoint
.
Hagamos cada paso lentamente.
Paso 1
Una forma en que alguien puede hacer esto es tal vez en un ciclo for
como sigue:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end
Si quisiera implementar la distancia de Manhattan, esto sería simplemente:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sum(abs(x(idx,:) - newpoint));
end
dists
sería un vector de elemento N
que contiene las distancias entre cada punto de datos en x
y newpoint
. Hacemos una resta elemento por elemento entre newpoint
y un punto de datos en x
, cuadra las diferencias, luego las newpoint
todas juntas. Esta suma tiene raíz cuadrada, que completa la distancia euclidiana. Para la distancia de Manhattan, debe realizar un elemento por sustracción de elementos, tomar los valores absolutos y luego sumar todos los componentes. Esta es probablemente la más simple de las implementaciones para entender, pero posiblemente podría ser la más ineficiente ... especialmente para conjuntos de datos de mayor tamaño y una mayor dimensionalidad de sus datos.
Otra posible solución sería replicar newpoint
y hacer que esta matriz tenga el mismo tamaño que x
, luego hacer una resta elemento por elemento de esta matriz, luego sumar todas las columnas de cada fila y hacer la raíz cuadrada. Por lo tanto, podemos hacer algo como esto:
N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
Para la distancia de Manhattan, harías:
N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
repmat
toma una matriz o vector y los repite una cierta cantidad de veces en una dirección dada. En nuestro caso, queremos tomar nuestro vector newpoint
y apilar N
veces uno encima del otro para crear una matriz N x M
, donde cada fila tiene M
elementos de longitud. Restamos estas dos matrices juntas, luego cuadramos cada componente. Una vez que hacemos esto, sum
todas las columnas de cada fila y finalmente tomamos la raíz cuadrada de todos los resultados. Para la distancia de Manhattan, hacemos la resta, tomamos el valor absoluto y luego sumamos.
Sin embargo, la forma más eficiente de hacer esto en mi opinión sería usar bsxfun
. Esto esencialmente hace la replicación de la que hablamos bajo el capó con una sola llamada de función. Por lo tanto, el código sería simplemente esto:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
Para mí esto se ve mucho más limpio y al grano. Para la distancia de Manhattan, harías:
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
Paso 2
Ahora que tenemos nuestras distancias, simplemente las clasificamos. Podemos usar el sort
para ordenar nuestras distancias:
[d,ind] = sort(dists);
d
contendría las distancias ordenadas en orden ascendente, mientras que ind
le ind
cada valor en la matriz no ordenada donde aparece en el resultado ordenado . Necesitamos usar ind
, extraer los primeros k
elementos de este vector, luego usar ind
para indexar en nuestra matriz de datos x
para devolver aquellos puntos que estuvieron más cerca de newpoint
.
Paso 3
El último paso es devolver esos k
puntos de datos que están más cerca de newpoint
. Podemos hacer esto de manera muy simple por:
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
ind_closest
debe contener los índices en la matriz de datos original x
que son los más cercanos a newpoint
. Específicamente, ind_closest
contiene las filas de las que debe ind_closest
muestras en x
para obtener los puntos más cercanos a newpoint
. x_closest
contendrá esos puntos de datos reales.
Para su placer de copiar y pegar, así es como se ve el código:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
Ejecutando tu ejemplo, veamos nuestro código en acción:
load fisheriris
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
Al inspeccionar ind_closest
y x_closest
, esto es lo que obtenemos:
>> ind_closest
ind_closest =
120
53
73
134
84
77
78
51
64
87
>> x_closest
x_closest =
5.0000 1.5000
4.9000 1.5000
4.9000 1.5000
5.1000 1.5000
5.1000 1.6000
4.8000 1.4000
5.0000 1.7000
4.7000 1.4000
4.7000 1.4000
4.7000 1.5000
Si ejecutó knnsearch
, verá que su variable n
coincide con ind_closest
. Sin embargo, la variable d
devuelve las distancias desde el punto newpoint
a cada punto x
, no los puntos de datos reales. Si desea las distancias reales, simplemente haga lo siguiente después del código que escribí:
dist_sorted = d(1:k);
Tenga en cuenta que la respuesta anterior utiliza solo un punto de consulta en un lote de N
ejemplos. Con mucha frecuencia, KNN se usa en múltiples ejemplos simultáneamente. Supongamos que tenemos Q
puntos de consulta que queremos probar en el KNN. Esto daría como resultado una matriz kx M x Q
donde para cada ejemplo o cada sector, devolvemos los k
puntos más cercanos con una dimensionalidad de M
Alternativamente, podemos devolver las ID de los k
puntos más cercanos, lo que da como resultado una matriz Q xk
. Vamos a calcular ambos.
Una forma ingenua de hacer esto sería aplicar el código anterior en un bucle y repetir cada ejemplo.
Algo así funcionaría donde bsxfun
una matriz Q xk
y aplicamos el bsxfun
basado en bsxfun
para establecer cada fila de la matriz de salida en los k
puntos más cercanos en el conjunto de datos, donde usaremos el conjunto de datos Fisher Iris tal como lo teníamos antes. También mantendremos la misma dimensionalidad que hicimos en el ejemplo anterior y usaré cuatro ejemplos, entonces Q = 4
y M = 2
:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);
%// Loop through each point and do logic as seen above:
for ii = 1 : Q
%// Get the point
newpoint = newpoints(ii, :);
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
%// New - Output the IDs of the match as well as the points themselves
ind_closest(ii, :) = ind(1 : k).'';
x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end
Aunque esto es muy bueno, podemos hacerlo aún mejor. Hay una manera de calcular eficientemente la distancia euclidiana al cuadrado entre dos conjuntos de vectores. Lo dejaré como ejercicio si quieres hacer esto con el Manhattan. Consultar este blog , dado que A
es una matriz Q1 x M
donde cada fila es un punto de dimensionalidad M
con puntos Q1
y B
es una matriz Q2 x M
donde cada fila es también un punto de dimensionalidad M
con puntos Q2
, podemos eficientemente calcular una matriz de distancia D(i, j)
donde el elemento en la fila i
y la columna j
denota la distancia entre la fila i
de A
y la fila j
de B
usando la siguiente formulación de matriz:
nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.'') - 2*A*B.''; %// Compute distance matrix
D = D.^(0.5); %// Compute square root to complete calculation
Por lo tanto, si permitimos que A
sea una matriz de puntos de consulta y B
el conjunto de datos que consiste en sus datos originales, podemos determinar los k
puntos más cercanos ordenando cada fila individualmente y determinando las k
ubicaciones de cada fila que fueron las más pequeñas. También podemos usar esto para recuperar los puntos reales.
Por lo tanto:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);
nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.'') - 2*newpoints*x.''; %// Compute distance matrix
D = D.^(0.5); %// Compute square root to complete calculation
%// Sort the distances
[d, ind] = sort(D, 2);
%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);
%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).'', M, k, []), [2 1 3]);
Vemos que usamos la lógica para calcular que la matriz de distancia es la misma, pero algunas variables han cambiado para adecuarse al ejemplo. También ordenamos cada fila de forma independiente utilizando las dos versiones de entrada de sort
modo que ind
contenga los ID por fila d
contendrá las distancias correspondientes. Luego determinamos qué índices son los más cercanos a cada punto de consulta simplemente truncando esta matriz a k
columnas. Luego usamos permute
y permute
para determinar cuáles son los puntos más cercanos asociados. Primero usamos todos los índices más cercanos y creamos una matriz de puntos que apila todas las ID una encima de la otra para obtener una matriz Q * kx M
. El uso de reshape
y permute
nos permite crear nuestra matriz 3D para que se convierta en una matriz kx M x Q
como la que hemos especificado. Si quisieras obtener las distancias reales, podemos indexar en d
y tomar lo que necesitamos. Para hacer esto, necesitarás usar sub2ind
para obtener los índices lineales para que podamos indexar en d
de una vez. Los valores de ind_closest
ya nos dan a qué columnas debemos acceder. Las filas a las que debemos acceder son simplemente 1, k
veces, 2, k
veces, etc. hasta Q
k
es la cantidad de puntos que queríamos devolver:
row_indices = repmat((1:Q).'', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);
Cuando ejecutamos el código anterior para los puntos de consulta anteriores, estos son los índices, puntos y distancias que obtenemos:
>> ind_closest
ind_closest =
120 134 53 73 84 77 78 51 64 87
123 119 118 106 132 108 131 136 126 110
107 62 86 122 71 127 139 115 60 52
99 65 58 94 60 61 80 44 54 72
>> x_closest
x_closest(:,:,1) =
5.0000 1.5000
6.7000 2.0000
4.5000 1.7000
3.0000 1.1000
5.1000 1.5000
6.9000 2.3000
4.2000 1.5000
3.6000 1.3000
4.9000 1.5000
6.7000 2.2000
x_closest(:,:,2) =
4.5000 1.6000
3.3000 1.0000
4.9000 1.5000
6.6000 2.1000
4.9000 2.0000
3.3000 1.0000
5.1000 1.6000
6.4000 2.0000
4.8000 1.8000
3.9000 1.4000
x_closest(:,:,3) =
4.8000 1.4000
6.3000 1.8000
4.8000 1.8000
3.5000 1.0000
5.0000 1.7000
6.1000 1.9000
4.8000 1.8000
3.5000 1.0000
4.7000 1.4000
6.1000 2.3000
x_closest(:,:,4) =
5.1000 2.4000
1.6000 0.6000
4.7000 1.4000
6.0000 1.8000
3.9000 1.4000
4.0000 1.3000
4.7000 1.5000
6.1000 2.5000
4.5000 1.5000
4.0000 1.3000
>> dist_sorted
dist_sorted =
0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041
0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296
0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180
2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732
Para comparar esto con knnsearch
, en su lugar, debe especificar una matriz de puntos para el segundo parámetro donde cada fila es un punto de consulta y verá que los índices y las distancias ordenadas coinciden entre esta implementación y knnsearch
.
Espero que esto te ayude. ¡Buena suerte!