Complex Model Simulation

Package Loading

iBART relies on the Java-based bartMachine package, so before loading it we should allocate enough memory for the Java virtual machine. Here we allocate 10GB.

set.seed(123)
# Allocate 10GB of memory for Java. Must be called before library(iBART)
options(java.parameters = "-Xmx10g") 
library(iBART)

Complex Model

In this vignette, we will run iBART on the complex model described in Section 3.4 of the paper, i.e. the data-generating model is \[y = 15\{\exp(x_1)-\exp(x_2)\}^2 + 20\sin(\pi x_3x_4) + \varepsilon, \qquad\varepsilon\sim \mathcal{N}_n(0, \sigma^2 I).\] The primary features are \(X = (x_1,...,x_p)\), where \(x_1,...,x_p \overset{\text{iid}}\sim\text{Unif}_n(-1,1)\). We will use the following setting: \(n = 250\), \(p = 10\), and \(\sigma = 0.5\). The goal in OIS (operator-induced structural variable selection) is to identify the 2 true descriptors \(f_1(X) = \{\exp(x_1)-\exp(x_2)\}^2\) and \(f_2(X) = \sin(\pi x_3x_4)\) using only \((y, X)\) as input. Let’s generate the primary features \(X\) and the response variable \(y\).

#### Simulation Parameters ####
n <- 250 # Change n to 100 here to reproduce result in Supplementary Materials A.2.3
p <- 10  # Number of primary features

#### Generate Data ####
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
colnames(X) <- paste("x.", seq(from = 1, to = p, by = 1), sep = "")
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) + rnorm(n, mean = 0, sd = 0.5)
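To make the data-generating model concrete, here is a minimal sketch of a hypothetical helper, `true_signal()` (not part of iBART), that evaluates the noise-free mean \(15 f_1(X) + 20 f_2(X)\). It assumes, as above, that the columns of `X` are the primary features \(x_1,\ldots,x_p\) in order.

```r
# Hypothetical helper (not part of iBART): noise-free mean of the complex model,
# 15 * f1(X) + 20 * f2(X), assuming the columns of X are x.1, ..., x.p in order
true_signal <- function(X) {
  f1 <- (exp(X[, 1]) - exp(X[, 2]))^2   # f1(X) = {exp(x1) - exp(x2)}^2
  f2 <- sin(pi * X[, 3] * X[, 4])       # f2(X) = sin(pi * x3 * x4)
  15 * f1 + 20 * f2
}
```

Run after the data-generation code above, `sd(y - true_signal(X))` recovers the simulated noise level and should be close to \(\sigma = 0.5\).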

Note that the input data to iBART are only \(y\) and \(X = (x_1,...,x_{10})\), and iBART needs to

  1. Generate the correct descriptors \(f_1(X)\) and \(f_2(X)\)
  2. Select the correct descriptors

At Iteration 0 (base iteration), iBART determines which of the primary features, \((x_1,\ldots,x_{10})\), are useful and applies operators only to the useful ones. If successful, iBART should keep \((x_1,\ldots,x_4)\) in the loop and discard \((x_5,\ldots,x_{10})\). Let’s run iBART for 1 iteration (Iteration 0 + Iteration 1) and examine its outputs.

#### iBART ####
iBART_sim <- iBART(X = X, y = y,
                   head = colnames(X),
                   num_burn_in = 5000,                   # lower value for faster run
                   num_iterations_after_burn_in = 1000,  # lower value for faster run
                   num_permute_samples = 20,             # lower value for faster run
                   opt = c("unary"), # only apply unary operators after base iteration
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)
#> Start iBART descriptor generation and selection... 
#> Iteration 1 
#> iBART descriptor selection... 
#> avg..........null....................
#> Constructing descriptors using unary operators... 
#> BART iteration done! 
#> LASSO descriptor selection... 
#> L-zero regression... 
#> Total time: 20.0548620223999 secs

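iBART() returns a list object that contains many useful outputs; see ?iBART::iBART for the full list of return values. Here we focus on 2 of them:

  1. descriptor_names: the names of the descriptors selected at the last iteration
  2. iBART_model: the model fitted on the input descriptors at the last iteration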

We can use the iBART model the same way we would use a glmnet model. For instance, we can print out the coefficients using coef().

# iBART selected descriptors
iBART_sim$descriptor_names
#>  [1] "x.1^2"       "x.2^2"       "x.4^2"       "exp(x.1)"    "exp(x.2)"   
#>  [6] "exp(x.3)"    "exp(x.4)"    "sin(pi*x.1)" "sin(pi*x.2)" "sin(pi*x.3)"
#> [11] "sin(pi*x.4)" "cos(pi*x.3)" "x.1^(-1)"    "x.2^(-1)"    "x.3^(-1)"   
#> [16] "x.4^(-1)"    "abs(x.2)"

# iBART model
coef(iBART_sim$iBART_model, s = "lambda.min")
#> 29 x 1 sparse Matrix of class "dgCMatrix"
#>                       1
#> (Intercept) -8.98377149
#> x.1          .         
#> x.2          .         
#> x.3          .         
#> x.4          .         
#> x.1^2       18.48979389
#> x.2^2        8.44944372
#> x.3^2        .         
#> x.4^2       -2.91764266
#> exp(x.1)     5.51582287
#> exp(x.2)     9.74121269
#> exp(x.3)    -0.49444789
#> exp(x.4)    -5.47704490
#> sin(pi*x.1) -2.22769247
#> sin(pi*x.2) -7.06001108
#> sin(pi*x.3) -1.39995518
#> sin(pi*x.4)  5.10808704
#> cos(pi*x.1)  .         
#> cos(pi*x.2)  .         
#> cos(pi*x.3)  0.23919051
#> cos(pi*x.4)  .         
#> x.1^(-1)    -0.05683143
#> x.2^(-1)     0.04017341
#> x.3^(-1)    -0.01126814
#> x.4^(-1)    -0.03267275
#> abs(x.1)     .         
#> abs(x.2)     2.96458665
#> abs(x.3)     .         
#> abs(x.4)     .

iBART_sim$descriptor_names contains the names of the selected descriptors at the last iteration (Iteration 1), and coef(iBART_sim$iBART_model, s = "lambda.min") shows the input descriptors at the last iteration (Iteration 1) together with their coefficients. Notice that, after the intercept, the first 4 descriptors in coef(iBART_sim$iBART_model, s = "lambda.min") are \(x_1,\ldots,x_4\). This indicates that iBART discarded \(x_5,\ldots,x_{10}\) and kept \(x_1,\ldots,x_4\) in the loop at Iteration 0.
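Because the model is inspected like a glmnet fit, the selected descriptors can also be read off the coefficient matrix programmatically. The helper `nonzero_coefs()` below is hypothetical (not part of iBART); it assumes only that coef() returns a one-column matrix, dense or sparse, with the descriptor names as row names, as in the output above.

```r
# Hypothetical helper (not part of iBART): keep the rows of a one-column
# coefficient matrix whose coefficient is non-zero, dropping the intercept.
nonzero_coefs <- function(cf) {
  m <- as.matrix(cf)                                    # densify (also works for a dgCMatrix)
  m <- m[rownames(m) != "(Intercept)", , drop = FALSE]  # drop the intercept row
  m[m[, 1] != 0, , drop = FALSE]                        # keep non-zero coefficients only
}
```

For the one-iteration run above, `rownames(nonzero_coefs(coef(iBART_sim$iBART_model, s = "lambda.min")))` lists the same 17 descriptors as iBART_sim$descriptor_names.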

At Iteration 1, iBART applied unary operators to \(x_1,\ldots,x_4\), yielding \[x_i, x_i^2, \exp(x_i), \sin(\pi x_i), \cos(\pi x_i), x_i^{-1}, |x_i|, \qquad\text{for } i = 1,2,3,4.\] Among the 17 selected descriptors are \(\exp(x_1)\) and \(\exp(x_2)\), the two intermediate descriptors needed to generate \(f_1(X) = \{\exp(x_1)-\exp(x_2)\}^2\). This is very promising. Note that \(\sqrt{x_i}\) and \(\log(x_i)\) are absent here because \(\sqrt{\cdot}\) and \(\log(\cdot)\) are only defined for positive inputs, whereas the \(x_i\)’s take values in \((-1, 1)\). We can override this by setting apply_pos_opt_on_neg_x = TRUE, which effectively generates \(\sqrt{|x_i|}\) and \(\log(|x_i|)\).
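The unary expansion described above can be sketched in a few lines of base R. The helper `apply_unary()` is hypothetical: it only mimics the descriptor names printed by coef() and is not iBART's internal implementation. The `pos_on_neg` branch and its `sqrt(|x.1|)` / `log(|x.1|)` labels are an assumption based on the description of apply_pos_opt_on_neg_x = TRUE. Applied to the 4 retained features, the default call yields 4 × 7 = 28 descriptors, matching the 29 rows (28 descriptors plus the intercept) of the coefficient matrix above.

```r
# Hypothetical sketch (not iBART's implementation): apply unary operators to
# every column of X, grouping columns by operator and then by feature,
# mirroring the order of the coef() output above.
apply_unary <- function(X, pos_on_neg = FALSE) {
  nm <- colnames(X)
  ops <- list(
    list(f = function(x) x,           lab = function(s) s),
    list(f = function(x) x^2,         lab = function(s) paste0(s, "^2")),
    list(f = exp,                     lab = function(s) paste0("exp(", s, ")")),
    list(f = function(x) sin(pi * x), lab = function(s) paste0("sin(pi*", s, ")")),
    list(f = function(x) cos(pi * x), lab = function(s) paste0("cos(pi*", s, ")")),
    list(f = function(x) x^(-1),      lab = function(s) paste0(s, "^(-1)")),
    list(f = abs,                     lab = function(s) paste0("abs(", s, ")"))
  )
  if (pos_on_neg) {
    # Assumption: mimic apply_pos_opt_on_neg_x = TRUE by applying sqrt/log to |x|
    ops <- c(ops, list(
      list(f = function(x) sqrt(abs(x)), lab = function(s) paste0("sqrt(|", s, "|)")),
      list(f = function(x) log(abs(x)),  lab = function(s) paste0("log(|", s, "|)"))
    ))
  }
  blocks <- lapply(ops, function(op) {
    out <- op$f(X)                # element-wise, so the matrix shape is kept
    colnames(out) <- op$lab(nm)   # label columns in the style of the coef() output
    out
  })
  do.call(cbind, blocks)
}
```

For example, `colnames(apply_unary(X[, 1:4]))` returns the 28 descriptor names in the same order as the coef() output of the one-iteration run.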

To save time, we cached the result of a complete run in data("iBART_sim", package = "iBART"), which can be replicated by using the following code.

iBART_sim <- iBART(X = X, y = y,
                   head = colnames(X),
                   opt = c("unary", "binary", "unary"), 
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)

Let’s load the full result and see how iBART did.

data("iBART_sim", package = "iBART")          # load full result

iBART_sim$descriptor_names                    # iBART selected descriptors
#> [1] "(exp(x.1)-exp(x.2))^2" "sin(pi*(x.3*x.4))"
coef(iBART_sim$iBART_model, s = "lambda.min") # iBART model
#> 146 x 1 sparse Matrix of class "dgCMatrix"
#>                                      1
#> (Intercept)                  0.1928037
#> x.1                          .        
#> x.2                          .        
#> x.3                          .        
#> x.4                          .        
#> exp(x.1)                     .        
#> exp(x.2)                     .        
#> exp(x.3)                     .        
#> exp(x.4)                     .        
#> (x.2+exp(x.2))               .        
#> (x.1-exp(x.2))               .        
#> (x.2-exp(x.1))               .        
#> (exp(x.1)-exp(x.2))          .        
#> (x.1*x.2)                    .        
#> (x.2*exp(x.1))               .        
#> (x.3*x.4)                    .        
#> (exp(x.1)*exp(x.2))          .        
#> (x.2/exp(x.1))               .        
#> (exp(x.2)/exp(x.1))          .        
#> |x.1-x.2|                    .        
#> |x.3-x.4|                    .        
#> |exp(x.1)-exp(x.2)|          .        
#> exp(x.1)^0.5                 .        
#> exp(x.2)^0.5                 .        
#> exp(x.3)^0.5                 .        
#> exp(x.4)^0.5                 .        
#> (exp(x.1)*exp(x.2))^0.5      .        
#> (exp(x.2)/exp(x.1))^0.5      .        
#> |x.1-x.2|^0.5                .        
#> |x.3-x.4|^0.5                .        
#> |exp(x.1)-exp(x.2)|^0.5      .        
#> x.1^2                        .        
#> x.2^2                        .        
#> x.3^2                        .        
#> x.4^2                        .        
#> exp(x.1)^2                   .        
#> exp(x.2)^2                   .        
#> exp(x.3)^2                   .        
#> exp(x.4)^2                   .        
#> (x.2+exp(x.2))^2             .        
#> (x.1-exp(x.2))^2             .        
#> (x.2-exp(x.1))^2             .        
#> (exp(x.1)-exp(x.2))^2       14.7643022
#> (x.1*x.2)^2                  .        
#> (x.2*exp(x.1))^2             .        
#> (x.3*x.4)^2                  .        
#> (exp(x.1)*exp(x.2))^2        .        
#> (x.2/exp(x.1))^2             .        
#> (exp(x.2)/exp(x.1))^2        .        
#> |x.1-x.2|^2                  .        
#> |x.3-x.4|^2                  .        
#> log((exp(x.1)*exp(x.2)))     .        
#> log((exp(x.2)/exp(x.1)))     .        
#> log(|x.1-x.2|)               .        
#> log(|x.3-x.4|)               .        
#> log(|exp(x.1)-exp(x.2)|)     .        
#> exp(exp(x.1))                .        
#> exp(exp(x.2))                .        
#> exp(exp(x.3))                .        
#> exp(exp(x.4))                .        
#> exp((x.2+exp(x.2)))          .        
#> exp((x.1-exp(x.2)))          .        
#> exp((x.2-exp(x.1)))          .        
#> exp((exp(x.1)-exp(x.2)))     .        
#> exp((x.1*x.2))               .        
#> exp((x.2*exp(x.1)))          .        
#> exp((x.3*x.4))               .        
#> exp((exp(x.1)*exp(x.2)))     .        
#> exp((x.2/exp(x.1)))          .        
#> exp((exp(x.2)/exp(x.1)))     .        
#> exp(|x.1-x.2|)               .        
#> exp(|x.3-x.4|)               .        
#> exp(|exp(x.1)-exp(x.2)|)     .        
#> sin(pi*x.1)                  .        
#> sin(pi*x.2)                  .        
#> sin(pi*x.3)                  .        
#> sin(pi*x.4)                  .        
#> sin(pi*exp(x.1))             .        
#> sin(pi*exp(x.2))             .        
#> sin(pi*exp(x.3))             .        
#> sin(pi*exp(x.4))             .        
#> sin(pi*(x.2+exp(x.2)))       .        
#> sin(pi*(x.1-exp(x.2)))       .        
#> sin(pi*(x.2-exp(x.1)))       .        
#> sin(pi*(exp(x.1)-exp(x.2)))  .        
#> sin(pi*(x.1*x.2))            .        
#> sin(pi*(x.2*exp(x.1)))       .        
#> sin(pi*(x.3*x.4))           19.5876303
#> sin(pi*(exp(x.1)*exp(x.2)))  .        
#> sin(pi*(x.2/exp(x.1)))       .        
#> sin(pi*(exp(x.2)/exp(x.1)))  .        
#> sin(pi*|x.1-x.2|)            .        
#> sin(pi*|x.3-x.4|)            .        
#> sin(pi*|exp(x.1)-exp(x.2)|)  .        
#> cos(pi*x.1)                  .        
#> cos(pi*x.2)                  .        
#> cos(pi*x.3)                  .        
#> cos(pi*x.4)                  .        
#> cos(pi*exp(x.1))             .        
#> cos(pi*exp(x.2))             .        
#> cos(pi*exp(x.3))             .        
#> cos(pi*exp(x.4))             .        
#> cos(pi*(x.2+exp(x.2)))       .        
#> cos(pi*(x.1-exp(x.2)))       .        
#> cos(pi*(x.2-exp(x.1)))       .        
#> cos(pi*(exp(x.1)-exp(x.2)))  .        
#> cos(pi*(x.1*x.2))            .        
#> cos(pi*(x.2*exp(x.1)))       .        
#> cos(pi*(x.3*x.4))            .        
#> cos(pi*(exp(x.1)*exp(x.2)))  .        
#> cos(pi*(x.2/exp(x.1)))       .        
#> cos(pi*(exp(x.2)/exp(x.1)))  .        
#> cos(pi*|x.1-x.2|)            .        
#> cos(pi*|x.3-x.4|)            .        
#> x.1^(-1)                     .        
#> x.2^(-1)                     .        
#> x.3^(-1)                     .        
#> x.4^(-1)                     .        
#> exp(x.1)^(-1)                .        
#> exp(x.2)^(-1)                .        
#> exp(x.3)^(-1)                .        
#> exp(x.4)^(-1)                .        
#> (x.2+exp(x.2))^(-1)          .        
#> (x.1-exp(x.2))^(-1)          .        
#> (x.2-exp(x.1))^(-1)          .        
#> (exp(x.1)-exp(x.2))^(-1)     .        
#> (x.1*x.2)^(-1)               .        
#> (x.2*exp(x.1))^(-1)          .        
#> (x.3*x.4)^(-1)               .        
#> (exp(x.1)*exp(x.2))^(-1)     .        
#> (x.2/exp(x.1))^(-1)          .        
#> (exp(x.2)/exp(x.1))^(-1)     .        
#> |x.1-x.2|^(-1)               .        
#> |x.3-x.4|^(-1)               .        
#> |exp(x.1)-exp(x.2)|^(-1)     .        
#> abs(x.1)                     .        
#> abs(x.2)                     .        
#> abs(x.3)                     .        
#> abs(x.4)                     .        
#> abs((x.2+exp(x.2)))          .        
#> abs((x.1-exp(x.2)))          .        
#> abs((x.2-exp(x.1)))          .        
#> abs((x.1*x.2))               .        
#> abs((x.2*exp(x.1)))          .        
#> abs((x.3*x.4))               .        
#> abs((x.2/exp(x.1)))          .

Here iBART generated 145 descriptors in the last iteration, and it correctly identified the true descriptors \(f_1(X)\) and \(f_2(X)\) without selecting any false positives. This is especially reassuring because some of these descriptors are highly correlated with \(f_1(X)\) or \(f_2(X)\). For instance, \(\tilde{f}_1(X) = |\exp(x_1) - \exp(x_2)|\) in the descriptor space is highly correlated with \(f_1(X)\).

f1_true <- (exp(X[,1]) - exp(X[,2]))^2
f1_cor <- abs(exp(X[,1]) - exp(X[,2]))
cor(f1_true, f1_cor)
#> [1] 0.9517217
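The coefficients iBART reports for the two selected descriptors (14.7643022 and 19.5876303) are close to the true values 15 and 20. As an independent check, here is a minimal sketch that refits ordinary least squares on the true descriptors with base R's lm(); the helper name `refit_true_descriptors()` is hypothetical, and this refit is not part of iBART's procedure.

```r
# Hypothetical helper (not part of iBART): OLS refit of y on the two true
# descriptors f1(X) and f2(X); returns the intercept and both slopes.
refit_true_descriptors <- function(X, y) {
  f1 <- (exp(X[, 1]) - exp(X[, 2]))^2   # f1(X) = {exp(x1) - exp(x2)}^2
  f2 <- sin(pi * X[, 3] * X[, 4])       # f2(X) = sin(pi * x3 * x4)
  coef(lm(y ~ f1 + f2))
}
```

Called as `refit_true_descriptors(X, y)` on the simulated data, the slopes should land near 15 and 20, since the noise level \(\sigma = 0.5\) is small relative to the signal.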

iBART() also returns other useful outputs, such as iBART_sim$iBART_gen_size and iBART_sim$iBART_sel_size. They store the dimension of the newly generated and selected descriptor spaces, respectively, for each iteration. Let’s plot them and see how iBART uses nonparametric variable selection for dimension reduction. In each iteration, we keep the dimension of the intermediate descriptor space under \(\mathcal{O}(p^2)\), leading to a progressive dimension reduction.

library(ggplot2)
df_dim <- data.frame(dim = c(iBART_sim$iBART_sel_size, iBART_sim$iBART_gen_size),
                     iter = rep(0:3, 2),
                     type = rep(c("Selected", "Generated"), each = 4))
ggplot(df_dim, aes(x = iter, y = dim, colour = type, group = type)) +
  theme(text = element_text(size = 15), legend.title = element_blank()) +
  geom_line(linewidth = 1) +   # `size` for lines is deprecated since ggplot2 3.4.0
  geom_point(size = 3, shape = 21, fill = "white") +
  geom_text(data = df_dim, aes(label = dim, y = dim + 10, group = type),
            position = position_dodge(0), size = 5, colour = "blue") +
  labs(x = "Iteration", y = "Number of descriptors") +
  scale_x_continuous(breaks = c(0, 1, 2, 3))

R Session Info

sessionInfo()
#> R version 4.0.5 (2021-03-31)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 22621)
#> 
#> Matrix products: default
#> 
#> locale:
#> [1] LC_COLLATE=C                          
#> [2] LC_CTYPE=English_United States.1252   
#> [3] LC_MONETARY=English_United States.1252
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.1252    
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ggpubr_0.6.0  ggplot2_3.4.4 iBART_1.0.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] shape_1.4.6         tidyselect_1.2.0    xfun_0.40          
#>  [4] bslib_0.5.1         purrr_1.0.2         splines_4.0.5      
#>  [7] rJava_1.0-4         lattice_0.20-44     carData_3.0-5      
#> [10] colorspace_2.0-3    vctrs_0.6.4         generics_0.1.3     
#> [13] htmltools_0.5.7     yaml_2.3.5          utf8_1.2.2         
#> [16] survival_3.2-11     rlang_1.1.2         jquerylib_0.1.4    
#> [19] pillar_1.9.0        glue_1.6.2          withr_2.5.2        
#> [22] foreach_1.5.1       lifecycle_1.0.4     munsell_0.5.0      
#> [25] ggsignif_0.6.4      gtable_0.3.4        codetools_0.2-18   
#> [28] evaluate_0.23       labeling_0.4.3      knitr_1.44         
#> [31] fastmap_1.1.1       parallel_4.0.5      fansi_1.0.3        
#> [34] itertools_0.1-3     broom_1.0.5         bartMachine_1.2.6  
#> [37] backports_1.4.1     scales_1.2.1        cachem_1.0.6       
#> [40] jsonlite_1.8.7      abind_1.4-5         farver_2.1.0       
#> [43] gridExtra_2.3       digest_0.6.33       rstatix_0.7.2      
#> [46] dplyr_1.1.3         cowplot_1.1.1       grid_4.0.5         
#> [49] cli_3.6.1           tools_4.0.5         magrittr_2.0.3     
#> [52] sass_0.4.1          missForest_1.4      glmnet_4.1-1       
#> [55] tibble_3.2.1        randomForest_4.6-14 car_3.1-2          
#> [58] tidyr_1.3.0         crayon_1.5.2        pkgconfig_2.0.3    
#> [61] Matrix_1.6-2        bartMachineJARs_1.1 rmarkdown_2.25     
#> [64] rstudioapi_0.15.0   iterators_1.0.13    R6_2.5.1           
#> [67] compiler_4.0.5