systats / supermatch

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Supermatch

Packages

pacman::p_load(tidyverse, keras, purrr)

Functions

# Convert raw text into fixed-length integer sequences for the model.
#
# text    : character vector of strings to encode.
# tok     : a fitted keras text tokenizer.
# seq_len : length every sequence is padded/truncated to (passed to
#           keras::pad_sequences as `maxlen`).
#
# Returns the padded matrix produced by keras::pad_sequences.
tokenize_seq <- function(text, tok, seq_len = 150){
  word_ids <- keras::texts_to_sequences(tok, texts = text)
  keras::pad_sequences(word_ids, maxlen = seq_len)
}

# Score a batch of label pairs with the siamese model.
#
# label1, label2 : tokenized/padded inputs for the two branches of the model
#                  (as produced by tokenize_seq).
#
# Reads the file-level `model` object (loaded in the "Model" section).
# Returns the first column of the prediction matrix as a vector.
predict_labels <- function(label1, label2){
  model_inputs <- list(label1, label2)
  scores <- predict(model, model_inputs)
  scores[, 1]
}

# Add a `prob` column with the model's match score for each pair of labels.
#
# d       : a data frame / tibble containing the two label columns.
# label1, label2 : columns (or expressions) of `d` to compare, captured with
#           tidy evaluation ({{ }}); by default the columns named `label1`
#           and `label2`.
# seq_len : sequence length used when tokenizing both labels (default 30,
#           the value this function previously hard-coded).
#
# Uses the file-level `tok` tokenizer and, via predict_labels(), the
# file-level `model`. Returns `d` with an extra numeric `prob` column.
predict_score <- function(d, label1 = label1, label2 = label2, seq_len = 30){
  stopifnot(is.numeric(seq_len), length(seq_len) == 1L)

  d %>%
    dplyr::mutate(
      # `!!` injects the tokenizer and length as values so that columns of
      # `d` named `tok` or `seq_len` cannot shadow them inside the data mask.
      prob = predict_labels(
        label1 = tokenize_seq({{ label1 }}, tok = !!tok, seq_len = !!seq_len),
        label2 = tokenize_seq({{ label2 }}, tok = !!tok, seq_len = !!seq_len)
      )
    )
}

Model

# Load the trained siamese CNN-LSTM model and the tokenizer it was trained
# with; both are read from the local "models/" directory.
model <- keras::load_model_hdf5("models/keras_siamese_cnn_lstm_86")
tok <- keras::load_text_tokenizer("models/tok")

Prediction

# Example: score a single pair of club-name spellings.
predict_score(
  tibble(
    label1 = "Bayern Muenchen",
    label2 = "FC Bayern München"
  )
)
## # A tibble: 1 x 3
##   label1          label2             prob
##   <chr>           <chr>             <dbl>
## 1 Bayern Muenchen FC Bayern München 0.894

About