para - programa que calcule el cuadrado y el cubo de un numero en c++
Cálculo rápido de cuadrados bignum (2)
Para acelerar mis divisiones bignum, necesito acelerar la operación y = x ^ 2 para los elementos grandes que se representan como matrices dinámicas de DWORD sin signo. Para ser claro:
DWORD x[n+1] = { LSW, ......, MSW };
- donde n + 1 es el número de DWORDs usadas
- entonces el valor del número x = x [0] + x [1] << 32 + ... x [N] << 32 * (n)
La pregunta es: ¿cómo calculo y = x ^ 2 lo más rápido posible sin pérdida de precisión? - Utilizando C ++ y con aritmética de enteros (32 bits con Carry) a disposición.
Mi enfoque actual es aplicar la multiplicación, y = x * x, y evitar multiplicaciones múltiples.
Por ejemplo:
x = x[0] + x[1]<<32 + ... x[n]<<32*(n)
Para simplificar, déjame reescribirlo:
x = x0+ x1 + x2 + ... + xn
donde index representa la dirección dentro de la matriz, entonces:
y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)
y0 = x0*x0
y1 = x1*x0 + x0*x1
y2 = x2*x0 + x1*x1 + x0*x2
y3 = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2)
y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1)
y(2n-1) = xn(n )*x(n )
Después de una mirada más cercana, está claro que casi todas las xi xj aparecen dos veces (no la primera y la última) lo que significa que las multiplicaciones de N N pueden ser reemplazadas por (N + 1) * (N / 2) multiplicaciones. PS 32bit * 32bit = 64 bits, por lo que el resultado de cada operación mul + add se maneja como 64 + 1 bit.
¿Hay una mejor manera de calcular esto rápido? Todo lo que encontré durante las búsquedas fueron algoritmos de sqrts, no sqr ...
Rápido sqr
!!! Tenga en cuenta que todos los números en mi código son MSW primero, ... no como en la prueba anterior (hay LSW primero por simplicidad de ecuaciones, de lo contrario sería un lío de índice).
Implementación fsqr funcional actual
void arbnum::sqr(const arbnum &x)
{
// O((N+1)*N/2)
arbnum c;
DWORD h, l;
int N, nx, nc, i, i0, i1, k;
c._alloc(x.siz + x.siz + 1);
nx = x.siz - 1;
nc = c.siz - 1;
N = nx + nx;
for (i=0; i<=nc; i++)
c.dat[i]=0;
for (i=1; i<N; i++)
for (i0=0; (i0<=nx) && (i0<=i); i0++)
{
i1 = i - i0;
if (i0 >= i1)
break;
if (i1 > nx)
continue;
h = x.dat[nx-i0];
if (!h)
continue;
l = x.dat[nx-i1];
if (!l)
continue;
alu.mul(h, l, h, l);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k], l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k],h);
k--;
for (; (alu.cy) && (k>=0); k--)
alu.inc(c.dat[k]);
}
c.shl(1);
for (i = 0; i <= N; i += 2)
{
i0 = i>>1;
h = x.dat[nx-i0];
if (!h)
continue;
alu.mul(h, l, h, h);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k],l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k], h);
k--;
for (; (alu.cy) && (k >= 0); k--)
alu.inc(c.dat[k]);
}
c.bits = c.siz<<5;
c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
c.sig = sig;
*this = c;
}
Uso de la multiplicación de Karatsuba
(gracias a Calpis)
Implementé la multiplicación de Karatsuba, pero los resultados son mucho más lentos incluso que con la simple multiplicación O (N ^ 2), probablemente debido a esa recursión horrible que no veo ninguna forma de evitar. Es una compensación debe ser en números realmente grandes (más grandes que cientos de dígitos) ... pero incluso entonces hay una gran cantidad de transferencias de memoria. ¿Hay alguna manera de evitar las llamadas de recursión (variante no recursiva, ... Casi todos los algoritmos recursivos se pueden hacer de esa manera). Aún así, intentaré modificar las cosas y ver qué pasa (evite las normalizaciones, etc., también podría ser un error tonto en el código). De todos modos, después de resolver Karatsuba para el caso x * x, no hay mucha ganancia de rendimiento.
Multiplicación optimizada de Karatsuba
Prueba de rendimiento para y = x ^ 2 en bucle 1000x veces, 0.9 <x <1 ~ 32 * 98 bits:
x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication
x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]
x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]
Después de las optimizaciones para Karatsuba, el código es masivamente más rápido que antes. Aún así, para números más pequeños, es un poco menos de la mitad de la velocidad de mi multiplicación de O (N ^ 2). Para números más grandes, es más rápido con la proporción dada por las complejidades de las multiplicaciones de Booth. El umbral para la multiplicación es de alrededor de 32 * 98 bits y para sqr alrededor de 32 * 389 bits, por lo que si la suma de los bits de entrada cruza este umbral, la multiplicación de Karatsuba se utilizará para acelerar la multiplicación y también para sqr.
Por cierto, optimizaciones incluidas:
- Minimice la destrucción de montón por argumento de recursión demasiado grande
- En su lugar, se utiliza la evitación de cualquier aritmética de bignum (+, -) ALU de 32 bits con acarreo.
- Ignorando 0 * y o x * 0 o 0 * 0 casos
- Reformateo de los tamaños de los números de entrada x, y a la potencia de dos para evitar la reasignación
- Implementar la multiplicación de módulo para z1 = (x0 + x1) * (y0 + y1) para minimizar la recursión
Modificación de la multiplicación de Schönhage-Strassen a la implementación de sqr
He probado el uso de transformaciones FTT y NTT para acelerar el cálculo de sqr. Los resultados son estos:
FTT
- Perder precisión y, por lo tanto, necesita números complejos de alta precisión
- Esto realmente ralentiza considerablemente las cosas por lo que no hay aceleración.
- El resultado no es preciso (puede redondearse incorrectamente)
- FTT es inutilizable
NTT
- NTT es campo finito DFT y por lo tanto no se produce pérdida de precisión.
- Necesidad de aritmética modular en enteros sin signo: modpow, modmul, modadd y modsub
- Yo uso DWORD (números enteros sin signo de 32 bits).
- NTT input / otput vector size es limitado debido a problemas de desbordamiento !!! Para aritmética modular de 32 bits, N está limitado a (2 ^ 32) / (máximo (entrada []) ^ 2), por lo que bigint debe dividirse en fragmentos más pequeños (yo uso BYTES, por lo que el tamaño máximo de bigint procesado es (2) ^ 32) / ((2 ^ 8) ^ 2) = 2 ^ 16 bytes = 2 ^ 14 DWORDs = 16384 DWORDs).
- sqr usa solo 1xNTT + 1xINTT en lugar de 2xNTT + 1xINTT para la multiplicación.
- El uso de NTT es demasiado lento y el tamaño del número de umbral es demasiado grande para el uso práctico en mi implementación (para mul y también para sqr), es posible incluso por encima del límite de desbordamiento así que se deben usar aritméticas modulares de 64 bits que pueden ralentizar abajo aún más.
- NTT es para mis propósitos también inutilizable
Algunas medidas:
a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul
Mi implementación:
void arbnum::sqr_NTT(const arbnum &x)
{
// O(N*log(N)*(log(log(N)))) - 1x NTT
// Schönhage-Strassen sqr
// To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
int i, j, k, n;
int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
i = x.siz;
for (n = 1; n < i; n<<=1)
;
if (n + n > 0x3000) {
_error(_arbnum_error_TooBigNumber);
zero();
return;
}
n <<= 3;
DWORD *xx, *yy, q, qq;
xx = new DWORD[n+n];
#ifdef _mmap_h
if (xx)
mmap_new(xx, (n+n) << 2);
#endif
if (xx==NULL) {
_error(_arbnum_error_NotEnoughMemory);
zero();
return;
}
yy = xx + n;
// Zero padding (and split DWORDs to BYTEs)
for (i--, k=0; i >= 0; i--)
{
q = x.dat[i];
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++;
}
for (;k<n;k++)
xx[k] = 0;
//NTT
fourier_NTT ntt;
ntt.NTT(yy,xx,n); // init NTT for n
// Convolution
for (i=0; i<n; i++)
yy[i] = modmul(yy[i], yy[i], ntt.p);
//INTT
ntt.INTT(xx, yy);
//suma
q=0;
for (i = 0, j = 0; i<n; i++) {
qq = xx[i];
q += qq&0xFF;
yy[n-i-1] = q&0xFF;
q>>=8;
qq>>=8;
q+=qq;
}
// Merge WORDs to DWORDs and copy them to result
_alloc(n>>2);
for (i = 0, j = 0; i<siz; i++)
{
q =(yy[j]<<24)&0xFF000000; j++;
q |=(yy[j]<<16)&0x00FF0000; j++;
q |=(yy[j]<< 8)&0x0000FF00; j++;
q |=(yy[j] )&0x000000FF; j++;
dat[i] = q;
}
#ifdef _mmap_h
if (xx)
mmap_del(xx);
#endif
delete xx;
bits = siz<<5;
sig = s;
exp = exp0 + (siz<<5) - 1;
// _normalize();
}
Conclusión
Para números más pequeños, es la mejor opción mi enfoque rápido sqr, y después del umbral la multiplicación de karatsuba es mejor. Pero todavía creo que debería haber algo trivial que hemos pasado por alto. ¿Alguien tiene otras ideas?
Optimización NTT
Después de optimizaciones intensamente intensas (en su mayoría NTT): pregunta sobre desbordamiento de pila Aritmética modular y optimizaciones NTT (campo finito DFT) .
Algunos valores han cambiado:
a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul
Así que ahora la multiplicación de NTT es finalmente más rápida que Karatsuba después de un umbral de aproximadamente 1500 * 32 bits.
Algunas medidas y errores detectados
a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[ 58.656 ms ] fast sqr
sqr2[ 13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[ 28.916 ms ] Karatsuba mul Error
mul3[ 19.470 ms ] NTT mul
Descubrí que mi Karatsuba (más / menos) fluye el LSB de cada segmento DWORD de bignum. Cuando investigué, actualizaré el código ...
Además, después de nuevas optimizaciones NTT los umbrales cambiaron, por lo que para NTT sqr es 310 * 32 bits = 9920 bits de operando , y para NTT mul es 1396 * 32 bits = 44672 bits de resultado (suma de bits de operandos).
Código de Karatsuba reparado gracias a @greybeard
//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
// Recursion for Karatsuba
// z[2n] = x[n]*y[n];
// n=2^m
int i;
for (i=0; i<n; i++)
if (x[i]) {
i=-1;
break;
} // x==0 ?
if (i < 0)
for (i = 0; i<n; i++)
if (y[i]) {
i = -1;
break;
} // y==0 ?
if (i >= 0) {
for (i = 0; i < n + n; i++)
z[i]=0;
return;
} // 0.? = 0
if (n == 1) {
alu.mul(z[0], z[1], x[0], y[0]);
return;
}
if (n< 1)
return;
int n2 = n>>1;
_mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0
_mul_karatsuba(z , x , y , n2); // z2 = x1.y1
DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
BYTE cx,cy;
if (q == NULL) {
_error(_arbnum_error_NotEnoughMemory);
return;
}
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
qq = q;
q0 = x + n2;
q1 = x;
i = n2 - 1;
_add;
cx = alu.cy; // =x0+x1
qq = q + n2;
q0 = y + n2;
q1 = y;
i = n2 - 1;
_add;
cy = alu.cy; // =y0+y1
_mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1)
if (cx) {
qq = q + n;
q0 = qq;
q1 = q + n2;
i = n2 - 1;
_add;
cx = alu.cy;
}// += cx*(y0 + y1) << n2
if (cy) {
qq = q + n;
q0 = qq;
q1 = q;
i = n2 -1;
_add;
cy = alu.cy;
}// +=cy*(x0+x1)<<n2
qq = q + n; q0 = qq; q1 = z + n; i = n - 1; _sub; // -=z0
qq = q + n; q0 = qq; q1 = z; i = n - 1; _sub; // -=z2
qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1=(x0+x1)(y0+y1)-z0-z2
DWORD ccc=0;
if (alu.cy)
ccc++; // Handle carry from last operation
if (cx || cy)
ccc++; // Handle carry from before last operation
if (ccc)
{
i = n2 - 1;
alu.add(z[i], z[i], ccc);
for (i--; i>=0; i--)
if (alu.cy)
alu.inc(z[i]);
else
break;
}
delete[] q;
#undef _add
#undef _sub
}
//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
// O(3*(N)^log2(3)) ~ O(3*(N^1.585))
// Karatsuba multiplication
//
int s = x.sig*y.sig;
arbnum a, b;
a = x;
b = y;
a.sig = +1;
b.sig = +1;
int i, n;
for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
;
a._realloc(n);
b._realloc(n);
_alloc(n + n);
for (i=0; i < siz; i++)
dat[i]=0;
_mul_karatsuba(dat, a.dat, b.dat, n);
bits = siz << 5;
sig = s;
exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
// _normalize();
}
//---------------------------------------------------------------------------
Mi representación numérica arbnum
:
// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
-
dat[siz]
es la mantisa. LSDW significa DWORD menos significativo. -
exp
es el exponente de MSB dedat[0]
¡El primer bit distinto de cero está presente en la mantisa!
// |-----|---------------------------|---------------|------| // | sig | MSB mantisa LSB | exponent | bits | // |-----|---------------------------|---------------|------| // | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero // | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero // |-----|---------------------------|---------------|------| // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number // |-----|---------------------------|---------------|------| // | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity // | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity // |-----|---------------------------|---------------|------|
Si entiendo tu algoritmo correctamente, parece O(n^2)
donde n
es el número de dígitos.
¿Has mirado Algoritmo Karatsuba ? Acelera la multiplicación usando el enfoque de dividir y conquistar. Puede valer la pena echarle un vistazo.
Si está buscando escribir un nuevo exponente mejor, es posible que deba escribirlo en ensamblaje. Este es el código de golang.
https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s