dirkschumacher / armacmp

🚀 Automatically compile linear algebra R code to C++ with Armadillo

Home Page: https://dirkschumacher.github.io/armacmp/

More type deductions for matrix multiplication

dirkschumacher opened this issue

mat %*% colvec = colvec
rowvec %*% mat = rowvec
mat %*% mat = mat
T %*% <scalar> = T
<scalar> %*% T = T
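The shape rules above can be written down as a small lookup. The following is a minimal sketch, not armacmp code: `matmul_result_shape()` is a hypothetical helper, and it assumes operand shapes have already been reduced to the plain strings "mat", "colvec", "rowvec" or "scalar". Anything not covered by the listed rules falls back to a general matrix, which mirrors the current `"arma::mat"` default.

```r
# Hypothetical helper (not part of armacmp) encoding the shape rules
# proposed in this issue. Shapes are plain strings:
# "mat", "colvec", "rowvec" or "scalar".
matmul_result_shape <- function(lhs, rhs) {
  if (lhs == "scalar") return(rhs)                      # <scalar> %*% T = T
  if (rhs == "scalar") return(lhs)                      # T %*% <scalar> = T
  if (lhs == "mat" && rhs == "colvec") return("colvec") # mat %*% colvec = colvec
  if (lhs == "rowvec" && rhs == "mat") return("rowvec") # rowvec %*% mat = rowvec
  # mat %*% mat = mat; every other combination (e.g. the outer product
  # colvec %*% rowvec) also falls back to a general matrix
  "mat"
}

matmul_result_shape("mat", "colvec")     # "colvec"
matmul_result_shape("scalar", "rowvec")  # "rowvec"
matmul_result_shape("colvec", "rowvec")  # "mat"
```

Checking the scalar cases first means `T` can be any shape, including another scalar, without listing every pair explicitly.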

The relevant code is here:

armacmp/R/ast-classes.R, lines 830 to 858 (commit 5777a3f):

get_cpp_type = function() {
  arma_type <- find_arma_cpp_types(self$get_tail_elements())
  arma_cpp_types <- vapply(arma_type, function(x) x$get_cpp_type(), character(1L))
  if (length(arma_type) == 2L) {
    is_integer_arma_type <- function(type) {
      grepl("^arma::i", type)
    }
    is_unsigned_integer_arma_type <- function(type) {
      grepl("^arma::u", type)
    }
    is_double_arma_type <- function(type) {
      type %in% paste0("arma::", c("vec", "colvec", "rowvec", "mat"))
    }
    # TODO: revisit these rules and contemplate
    if (any(sapply(arma_cpp_types, is_double_arma_type))) {
      return("arma::mat")
    }
    if (any(sapply(arma_cpp_types, is_integer_arma_type))) {
      return("arma::imat")
    }
    if (all(sapply(arma_cpp_types, is_unsigned_integer_arma_type))) {
      return("arma::umat")
    }
  }
  if (length(arma_type) == 1L) {
    return(arma_type[[1L]]$get_cpp_type())
  }
  "arma::mat"
}
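Currently the function only decides the element type (double, signed or unsigned integer) and always returns a matrix when there are two Armadillo operands. One possible way to extend it, sketched here under assumptions, is to compute the element type separately and then combine it with a shape deduced by rules like the ones listed at the top of this issue. The helpers `elem_prefix()`, `promote_prefix()` and `arma_type_name()` are hypothetical and not part of armacmp; the precedence (any double wins, then any signed integer, otherwise unsigned) copies the rules in `get_cpp_type()`.

```r
# Hypothetical sketch (not armacmp's code): separate the element type
# from the shape, then assemble the Armadillo type name.

# Element-type prefix of an Armadillo type: "i" (signed integer),
# "u" (unsigned integer) or "" (double), using the same regular
# expressions as get_cpp_type().
elem_prefix <- function(type) {
  if (grepl("^arma::i", type)) return("i")
  if (grepl("^arma::u", type)) return("u")
  ""
}

# Same precedence as get_cpp_type(): any double operand gives double,
# otherwise any signed integer operand gives signed integer, otherwise
# (all unsigned) unsigned. No Armadillo operands defaults to double.
promote_prefix <- function(types) {
  if (length(types) == 0L) return("")
  prefixes <- vapply(types, elem_prefix, character(1L))
  if (any(prefixes == "")) return("")
  if (any(prefixes == "i")) return("i")
  "u"
}

# Build the name from a prefix and a shape ("mat", "colvec", "rowvec").
# Column vectors use the `vec` spelling: arma::vec, arma::ivec, arma::uvec.
arma_type_name <- function(prefix, shape) {
  base <- switch(shape, mat = "mat", colvec = "vec", rowvec = "rowvec",
                 stop("unknown shape: ", shape))
  paste0("arma::", prefix, base)
}

# e.g. imat %*% uvec: signed integer elements, column-vector shape
arma_type_name(promote_prefix(c("arma::imat", "arma::uvec")), "colvec")
# "arma::ivec"
```

Keeping the two concerns apart means the shape rules can grow (more shapes, scalar handling) without touching the element-type precedence, and vice versa.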

There is also a lot of room for refactoring and making this stuff smarter :)