style div r recursion data.table shift

div - tabla de datos | Actualización recursiva más rápida por filas dentro del grupo



tags$div shiny (1)

Gran pregunta

A partir de una nueva sesión R, que muestra los datos de demostración con 5 millones de filas, aquí está su función de la pregunta y el tiempo en mi computadora portátil. Con algunos comentarios en línea.

require(data.table) # v1.10.0 n_smpl = 1e6 ni = 5 id = rep(1:n_smpl, each = ni) smpl = data.table(id) smpl[, time := 1:.N, by = id] a_init = 1; b_init = 1 smpl[, '':='' (a = a_init, b = b_init)] smpl[, xb := (1:.N)*id, by = id] myfun = function (xb, a, b) { z = NULL # initializes a new length-0 variable for (t in 1:length(xb)) { if (t >= 2) { a[t] = b[t-1] + xb[t] } # if() on every iteration. t==1 could be done before loop z[t] = rnorm(1, mean = a[t]) # z vector is grown by 1 item, each time b[t] = a[t] + z[t] # assigns to all of b vector when only really b[t-1] is # needed on the next iteration } return(z) } set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][]) user system elapsed 19.216 0.004 19.212 smpl id time a b xb z 1: 1 1 1 1 1 3.735462e-01 2: 1 2 1 1 2 3.557190e+00 3: 1 3 1 1 3 9.095107e+00 4: 1 4 1 1 4 2.462112e+01 5: 1 5 1 1 5 5.297647e+01 --- 4999996: 1000000 1 1 1 1000000 1.618913e+00 4999997: 1000000 2 1 1 2000000 2.000000e+06 4999998: 1000000 3 1 1 3000000 7.000003e+06 4999999: 1000000 4 1 1 4000000 1.800001e+07 5000000: 1000000 5 1 1 5000000 4.100001e+07

Así que 19.2s es el momento de vencer. En todos estos tiempos, he ejecutado el comando 3 veces localmente para asegurarme de que sea un tiempo estable. La variación de tiempo es insignificante en esta tarea, así que solo informaré un momento para que la respuesta sea más rápida de leer.

Abordando los comentarios en línea arriba en myfun() :

myfun2 = function (xb, a, b) { z = numeric(length(xb)) # allocate size up front rather than growing z[1] = rnorm(1, mean=a[1]) prevb = a[1]+z[1] t = 2L while(t<=length(xb)) { at = prevb + xb[t] z[t] = rnorm(1, mean=at) prevb = at + z[t] t = t+1L } return(z) } set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][]) user system elapsed 13.212 0.036 13.245 smpl[,identical(z,z2)] [1] TRUE

Eso fue bastante bueno (19.2s hasta 13.2s) pero sigue siendo un bucle for en el nivel R. A primera vista, no se puede vectorizar porque la llamada rnorm() depende del valor anterior. De hecho, probablemente se puede vectorizar utilizando la propiedad que m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd) y llamar rnorm(n=5e6) vectorizado rnorm(n=5e6) una vez que 5e6 veces Pero probablemente habría un cumsum() involucrado para tratar con los grupos. Así que no vayamos allí, ya que eso probablemente dificultaría la lectura del código y sería específico para este problema preciso.

Así que probemos Rcpp, que se parece mucho al estilo que escribiste y es más aplicable:

require(Rcpp) # v0.12.8 cppFunction( ''NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) { NumericVector z = NumericVector(xb.length()); z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1); double prevb = a[0]+z[0]; int t = 1; while (t<xb.length()) { double at = prevb + xb[t]; z[t] = R::rnorm(at, 1); prevb = at + z[t]; t++; } return z; }'') set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][]) user system elapsed 1.800 0.020 1.819 smpl[,identical(z,z3)] [1] TRUE

Mucho mejor: 19.2s hasta 1.8s . Pero cada llamada a la función llama a la primera línea ( NumericVector() ) que asigna un nuevo vector siempre que el número de filas en el grupo. Luego se completa y se devuelve, que se copia en la columna final en el lugar correcto para ese grupo (por := ), solo para ser liberado. Esa asignación y manejo de todos esos 1 millón de vectores temporales pequeños (uno para cada grupo) es todo un poco complicado.

¿Por qué no hacemos toda la columna de una sola vez? Ya lo has escrito en un estilo de bucle for y no hay nada de malo en eso. Modifiquemos la función C para que también acepte la columna id y agregue if para cuando llegue a un nuevo grupo.

cppFunction( ''NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) { // ** id must be pre-grouped, such as via setkey(DT,id) ** NumericVector z = NumericVector(id.length()); int previd = id[0]-1; // initialize to anything different than id[0] for (int i=0; i<id.length(); i++) { double prevb; if (id[i]!=previd) { // first row of new group z[i] = R::rnorm(a[i], 1); prevb = a[i]+z[i]; previd = id[i]; } else { // 2nd row of group onwards double at = prevb + xb[i]; z[i] = R::rnorm(at, 1); prevb = at + z[i]; } } return z; }'') system.time(setkey(smpl,id)) # ensure grouped by id user system elapsed 0.028 0.004 0.033 set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][]) user system elapsed 0.232 0.004 0.237 smpl[,identical(z,z4)] [1] TRUE

Eso es mejor: 19.2s hasta 0.27s .

Tengo que hacer la siguiente operación recursiva fila por fila para obtener z :

myfun = function (xb, a, b) { z = NULL for (t in 1:length(xb)) { if (t >= 2) { a[t] = b[t-1] + xb[t] } z[t] = rnorm(1, mean = a[t]) b[t] = a[t] + z[t] } return(z) } set.seed(1) n_smpl = 1e6 ni = 5 id = rep(1:n_smpl, each = ni) smpl = data.table(id) smpl[, time := 1:.N, by = id] a_init = 1; b_init = 1 smpl[, '':='' (a = a_init, b = b_init)] smpl[, xb := (1:.N)*id, by = id] smpl[, z := myfun(xb, a, b), by = id]

Me gustaría obtener un resultado como este:

id time a b xb z 1: 1 1 1 1 1 0.3735462 2: 1 2 1 1 2 2.7470924 3: 1 3 1 1 3 8.4941848 4: 1 4 1 1 4 20.9883695 5: 1 5 1 1 5 46.9767390 --- 496: 100 1 1 1 100 0.3735462 497: 100 2 1 1 200 200.7470924 498: 100 3 1 1 300 701.4941848 499: 100 4 1 1 400 1802.9883695 500: 100 5 1 1 500 4105.9767390

Esto funciona pero lleva tiempo:

system.time(smpl[, z := myfun(xb, a, b), by = id]) user system elapsed 33.646 0.994 34.473

Necesito hacerlo más rápido, dado el tamaño de mis datos reales (más de 2 millones de observaciones). Supongo que do.call(myfun, .SD), .SDcols = c(''xb'', ''a'', ''b'') con by = .(id, time) sería mucho más rápido, evitando el bucle for dentro de myfun . Sin embargo, no estaba seguro de cómo puedo actualizar b y su retraso (probablemente usando shift ) cuando ejecuto esta operación fila por fila en data.table . ¿Alguna sugerencia?