div - tabla de datos | Actualización recursiva más rápida por filas dentro del grupo
tags$div shiny (1)
Gran pregunta
A partir de una nueva sesión R, que muestra los datos de demostración con 5 millones de filas, aquí está su función de la pregunta y el tiempo en mi computadora portátil. Con algunos comentarios en línea.
require(data.table) # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, '':='' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
myfun = function (xb, a, b) {
z = NULL
# initializes a new length-0 variable
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
# if() on every iteration. t==1 could be done before loop
z[t] = rnorm(1, mean = a[t])
# z vector is grown by 1 item, each time
b[t] = a[t] + z[t]
# assigns to all of b vector when only really b[t-1] is
# needed on the next iteration
}
return(z)
}
set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
user system elapsed
19.216 0.004 19.212
smpl
id time a b xb z
1: 1 1 1 1 1 3.735462e-01
2: 1 2 1 1 2 3.557190e+00
3: 1 3 1 1 3 9.095107e+00
4: 1 4 1 1 4 2.462112e+01
5: 1 5 1 1 5 5.297647e+01
---
4999996: 1000000 1 1 1 1000000 1.618913e+00
4999997: 1000000 2 1 1 2000000 2.000000e+06
4999998: 1000000 3 1 1 3000000 7.000003e+06
4999999: 1000000 4 1 1 4000000 1.800001e+07
5000000: 1000000 5 1 1 5000000 4.100001e+07
Así que 19.2s es el momento de vencer. En todos estos tiempos, he ejecutado el comando 3 veces localmente para asegurarme de que sea un tiempo estable. La variación de tiempo es insignificante en esta tarea, así que solo informaré un momento para que la respuesta sea más rápida de leer.
Abordando los comentarios en línea arriba en myfun()
:
myfun2 = function (xb, a, b) {
z = numeric(length(xb))
# allocate size up front rather than growing
z[1] = rnorm(1, mean=a[1])
prevb = a[1]+z[1]
t = 2L
while(t<=length(xb)) {
at = prevb + xb[t]
z[t] = rnorm(1, mean=at)
prevb = at + z[t]
t = t+1L
}
return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
user system elapsed
13.212 0.036 13.245
smpl[,identical(z,z2)]
[1] TRUE
Eso fue bastante bueno (19.2s hasta 13.2s) pero sigue siendo un bucle for
en el nivel R. A primera vista, no se puede vectorizar porque la llamada rnorm()
depende del valor anterior. De hecho, probablemente se puede vectorizar utilizando la propiedad que m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd)
y llamar rnorm(n=5e6)
vectorizado rnorm(n=5e6)
una vez que 5e6 veces Pero probablemente habría un cumsum()
involucrado para tratar con los grupos. Así que no vayamos allí, ya que eso probablemente dificultaría la lectura del código y sería específico para este problema preciso.
Así que probemos Rcpp, que se parece mucho al estilo que escribiste y es más aplicable:
require(Rcpp) # v0.12.8
cppFunction(
''NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
NumericVector z = NumericVector(xb.length());
z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
double prevb = a[0]+z[0];
int t = 1;
while (t<xb.length()) {
double at = prevb + xb[t];
z[t] = R::rnorm(at, 1);
prevb = at + z[t];
t++;
}
return z;
}'')
set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
user system elapsed
1.800 0.020 1.819
smpl[,identical(z,z3)]
[1] TRUE
Mucho mejor: 19.2s hasta 1.8s . Pero cada llamada a la función llama a la primera línea ( NumericVector()
) que asigna un nuevo vector siempre que el número de filas en el grupo. Luego se completa y se devuelve, que se copia en la columna final en el lugar correcto para ese grupo (por :=
), solo para ser liberado. Esa asignación y manejo de todos esos 1 millón de vectores temporales pequeños (uno para cada grupo) es todo un poco complicado.
¿Por qué no hacemos toda la columna de una sola vez? Ya lo has escrito en un estilo de bucle for y no hay nada de malo en eso. Modifiquemos la función C para que también acepte la columna id
y agregue if
para cuando llegue a un nuevo grupo.
cppFunction(
''NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
// ** id must be pre-grouped, such as via setkey(DT,id) **
NumericVector z = NumericVector(id.length());
int previd = id[0]-1; // initialize to anything different than id[0]
for (int i=0; i<id.length(); i++) {
double prevb;
if (id[i]!=previd) {
// first row of new group
z[i] = R::rnorm(a[i], 1);
prevb = a[i]+z[i];
previd = id[i];
} else {
// 2nd row of group onwards
double at = prevb + xb[i];
z[i] = R::rnorm(at, 1);
prevb = at + z[i];
}
}
return z;
}'')
system.time(setkey(smpl,id)) # ensure grouped by id
user system elapsed
0.028 0.004 0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
user system elapsed
0.232 0.004 0.237
smpl[,identical(z,z4)]
[1] TRUE
Eso es mejor: 19.2s hasta 0.27s .
Tengo que hacer la siguiente operación recursiva fila por fila para obtener z
:
myfun = function (xb, a, b) {
z = NULL
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
z[t] = rnorm(1, mean = a[t])
b[t] = a[t] + z[t]
}
return(z)
}
set.seed(1)
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, '':='' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
smpl[, z := myfun(xb, a, b), by = id]
Me gustaría obtener un resultado como este:
id time a b xb z
1: 1 1 1 1 1 0.3735462
2: 1 2 1 1 2 2.7470924
3: 1 3 1 1 3 8.4941848
4: 1 4 1 1 4 20.9883695
5: 1 5 1 1 5 46.9767390
---
496: 100 1 1 1 100 0.3735462
497: 100 2 1 1 200 200.7470924
498: 100 3 1 1 300 701.4941848
499: 100 4 1 1 400 1802.9883695
500: 100 5 1 1 500 4105.9767390
Esto funciona pero lleva tiempo:
system.time(smpl[, z := myfun(xb, a, b), by = id])
user system elapsed
33.646 0.994 34.473
Necesito hacerlo más rápido, dado el tamaño de mis datos reales (más de 2 millones de observaciones). Supongo que do.call(myfun, .SD), .SDcols = c(''xb'', ''a'', ''b'')
con by = .(id, time)
sería mucho más rápido, evitando el bucle for dentro de myfun
. Sin embargo, no estaba seguro de cómo puedo actualizar b
y su retraso (probablemente usando shift
) cuando ejecuto esta operación fila por fila en data.table
. ¿Alguna sugerencia?