optimization clojure type-hinting numerical-computing

optimization - Rápida aritmética de números complejos en Clojure



type-hinting numerical-computing (2)

Las razones probables de tu bajo rendimiento son:

  • Los vectores Clojure son estructuras de datos intrínsecamente más pesadas que las matrices Java dobles []. Así que tienes un poco de sobrecarga adicional en la creación y lectura de vectores.
  • Estás jugando dobles de boxeo como argumentos para tus funciones y también cuando se ponen en vectores. El boxeo / unboxing es relativamente caro en este tipo de código numérico de bajo nivel.
  • Las sugerencias de tipo ( ^double ) no te ayudan: aunque puedes tener sugerencias de tipo primitivo sobre las funciones Clojure normales, no funcionarán en vectores.

Vea esta publicación en el blog sobre aceleración de la aritmética primitiva para más detalles.

Si realmente quieres números complejos rápidos en Clojure, probablemente necesitarás implementarlos usando deftype , algo así como:

(deftype Complex [^double real ^double imag])

Y luego define todas tus funciones complejas usando este tipo. Esto le permitirá utilizar la aritmética primitiva en todas partes, y debería ser más o menos equivalente al rendimiento del código Java bien escrito.

Estaba implementando una aritmética básica de números complejos en Clojure, y noté que era aproximadamente 10 veces más lenta que el código Java aproximadamente equivalente, incluso con sugerencias de tipo.

Comparar:

(defn plus [[^double x1 ^double y1] [^double x2 ^double y2]] [(+ x1 x2) (+ y1 y2)]) (defn times [[^double x1 ^double y1] [^double x2 ^double y2]] [(- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2))]) (time (dorun (repeatedly 100000 #(plus [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))

salida:

"Elapsed time: 69.429796 msecs" "Elapsed time: 72.232479 msecs"

con:

public static void main( String[] args ) { double[] z1 = new double[] { 1, 0 }; double[] z2 = new double[] { 0, 1 }; double[] z3 = null; long l_StartTimeMillis = System.currentTimeMillis(); for ( int i = 0; i < 100000; i++ ) { z3 = plus( z1, z2 ); // assign result to dummy var to stop compiler from optimising the loop away } long l_EndTimeMillis = System.currentTimeMillis(); long l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis; System.out.format( "Time taken: %d millis/n", l_TimeTakenMillis ); l_StartTimeMillis = System.currentTimeMillis(); for ( int i = 0; i < 100000; i++ ) { z3 = times( z1, z2 ); } l_EndTimeMillis = System.currentTimeMillis(); l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis; System.out.format( "Time taken: %d millis/n", l_TimeTakenMillis ); doNothing( z3 ); } private static void doNothing( double[] z ) { } public static double[] plus (double[] z1, double[] z2) { return new double[] { z1[0] + z2[0], z1[1] + z2[1] }; } public static double[] times (double[] z1, double[] z2) { return new double[] { z1[0]*z2[0] - z1[1]*z2[1], z1[0]*z2[1] + z1[1]*z2[0] }; }

salida:

Time taken: 6 millis Time taken: 6 millis

De hecho, los consejos de tipo no parecen marcar la diferencia: si los elimino obtengo aproximadamente el mismo resultado. Lo que es realmente extraño es que si ejecuto el script Clojure sin un REPL, obtengo resultados más lentos:

"Elapsed time: 137.337782 msecs" "Elapsed time: 214.213993 msecs"

Entonces mis preguntas son: ¿cómo puedo acercarme al rendimiento del código Java? ¿Y por qué en la Tierra las expresiones tardan más en evaluarse cuando se ejecuta clojure sin un REPL?

ACTUALIZACIÓN ==============

Genial, usando deftype con sugerencias de tipo en deftype y en defn s, y usar dotimes lugar de repeatedly da un rendimiento tan bueno o mejor que la versión de Java. Gracias a los dos.

(deftype complex [^double real ^double imag]) (defn plus [^complex z1 ^complex z2] (let [x1 (double (.real z1)) y1 (double (.imag z1)) x2 (double (.real z2)) y2 (double (.imag z2))] (complex. (+ x1 x2) (+ y1 y2)))) (defn times [^complex z1 ^complex z2] (let [x1 (double (.real z1)) y1 (double (.imag z1)) x2 (double (.real z2)) y2 (double (.imag z2))] (complex. (- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2))))) (println "Warm up") (time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1))))) (println "Try with dorun") (time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1))))) (time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1))))) (println "Try with dotimes") (time (dotimes [_ 100000] (plus (complex. 1 0) (complex. 0 1)))) (time (dotimes [_ 100000] (times (complex. 1 0) (complex. 0 1))))

Salida:

Warm up "Elapsed time: 92.805664 msecs" "Elapsed time: 164.929421 msecs" "Elapsed time: 23.799012 msecs" "Elapsed time: 32.841624 msecs" "Elapsed time: 20.886101 msecs" "Elapsed time: 18.872783 msecs" Try with dorun "Elapsed time: 19.238403 msecs" "Elapsed time: 17.856938 msecs" Try with dotimes "Elapsed time: 5.165658 msecs" "Elapsed time: 5.209027 msecs"


  • No sé mucho sobre las pruebas comparativas, pero parece que necesitas calentar jvm cuando comiences la prueba. Entonces cuando lo haces en REPL ya está calentado. Cuando se ejecuta como script todavía no es.

  • En java, ejecuta todos los bucles dentro de 1 método. No se llama ningún otro método excepto el plus y los times . En clojure crea una función anónima y llama repetidamente para llamarlo. Lleva algo de tiempo. Puedes reemplazarlo con dotimes .

Mi intento:

(println "Warm up") (time (dorun (repeatedly 100000 #(plus [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(plus [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(plus [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) (println "Try with dorun") (time (dorun (repeatedly 100000 #(plus [1 0] [0 1])))) (time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) (println "Try with dotimes") (time (dotimes [_ 100000] (plus [1 0] [0 1]))) (time (dotimes [_ 100000] (times [1 0] [0 1])))

Resultados:

Warm up "Elapsed time: 367.569195 msecs" "Elapsed time: 493.547628 msecs" "Elapsed time: 116.832979 msecs" "Elapsed time: 46.862176 msecs" "Elapsed time: 27.805174 msecs" "Elapsed time: 28.584179 msecs" Try with dorun "Elapsed time: 26.540489 msecs" "Elapsed time: 27.64626 msecs" Try with dotimes "Elapsed time: 7.3792 msecs" "Elapsed time: 5.940705 msecs"