google-research / met


MET: Masked Encoding for Tabular Data

This repository is the official implementation of MET.

Disclaimer: This is not an officially supported Google product.

Architecture

Requirements

To run the experiments from the paper, first install the requirements (Python >= 3.7 is required):

git clone https://github.com/google-research/met
cd met
pip install -r requirements.txt

Standard Training (MET-S)

To train the MET-S model from the paper (i.e., the model without the adversarial training step) on the FashionMNIST dataset, run this command:

python3 train.py

The following hyper-parameters are available for train.py:

  • embed_dim : Embedding dimension
  • ff_dim : Feed-Forward dimension
  • num_heads : Number of heads
  • model_depth_enc : Depth of Encoder/ Number of transformers in Encoder stack
  • model_depth_dec : Depth of Decoder/ Number of transformers in Decoder stack
  • mask_pct : Masking Percentage
  • lr : Learning rate

Each of the above can be changed by adding --flag_name=flag_value to the train.py command. For example:

python3 train.py --model_depth_enc=1

By default, the trained model is saved to the saved_models directory.
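To make the mask_pct flag concrete, here is a minimal NumPy sketch (hypothetical, not taken from this repository) of the masking step MET performs: a fixed percentage of each row's feature coordinates is hidden before the encoder-decoder stack reconstructs them.

```python
import numpy as np

def mask_features(x, mask_pct, rng):
    """Zero out mask_pct percent of the coordinates of a 1-D feature
    vector; return the masked vector and the boolean mask.
    Illustrative helper only; the repo's actual masking code may differ."""
    n = x.shape[0]
    n_mask = int(round(n * mask_pct / 100.0))
    idx = rng.choice(n, size=n_mask, replace=False)  # coordinates to hide
    mask = np.zeros(n, dtype=bool)
    mask[idx] = True
    x_masked = np.where(mask, 0.0, x)  # masked entries set to zero
    return x_masked, mask

rng = np.random.default_rng(0)
x = rng.normal(size=784)  # one flattened FashionMNIST row
x_masked, mask = mask_features(x, mask_pct=70, rng=rng)
```

With mask_pct=70 (the value appearing in the saved-model names below), 549 of the 784 coordinates are hidden and the model is trained to reconstruct them.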

Adversarial Training (MET)

To train the MET model from the paper on the FashionMNIST dataset using adversarial training, run this command:

python3 train_adv.py

The following hyper-parameters are available for train_adv.py:

  • embed_dim : Embedding dimension
  • ff_dim : Feed-Forward dimension
  • num_heads : Number of heads
  • model_depth_enc : Depth of Encoder/ Number of transformers in Encoder stack
  • model_depth_dec : Depth of Decoder/ Number of transformers in Decoder stack
  • mask_pct : Masking Percentage
  • lr : Learning rate
  • radius : Radius of L2 norm ball around the input data point
  • adv_steps : Adversarial loop length
  • lr_adv : Adversarial Learning Rate

Each of the above can be changed by adding --flag_name=flag_value to the train_adv.py command. For example:

python3 train_adv.py --radius=14

By default, the trained model is saved to the saved_models directory.
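The three adversarial flags fit together as an inner perturbation loop: for adv_steps iterations, ascend the loss gradient with step size lr_adv while keeping the perturbed point inside an L2 ball of the given radius around the input. A hedged NumPy sketch (hypothetical function names; a toy quadratic loss stands in for the reconstruction loss):

```python
import numpy as np

def project_l2_ball(delta, radius):
    """Project a perturbation onto the L2 ball of the given radius."""
    norm = np.linalg.norm(delta)
    if norm > radius:
        delta = delta * (radius / norm)
    return delta

def adversarial_perturb(x, grad_fn, radius, adv_steps, lr_adv):
    """Illustrative inner loop for the --radius, --adv_steps and
    --lr_adv flags: gradient ascent on the loss, projected so the
    perturbed input stays within an L2 ball around x."""
    delta = np.zeros_like(x)
    for _ in range(adv_steps):
        delta = delta + lr_adv * grad_fn(x + delta)  # ascend the loss
        delta = project_l2_ball(delta, radius)       # stay inside the ball
    return x + delta

# Toy loss L(z) = ||z||^2 / 2 has gradient z, so ascent pushes the point
# outward until the perturbation saturates on the ball's surface.
x = np.ones(4)
x_adv = adversarial_perturb(x, grad_fn=lambda z: z,
                            radius=0.5, adv_steps=10, lr_adv=0.1)
```

In the actual training script the gradient would come from the reconstruction loss of the encoder-decoder, not a closed-form function.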

Adding a new dataset

You can try the model on any new dataset by creating a CSV file. The first column of the CSV file should be the class label, followed by the attributes. Sample CSV files are available in the data directory.
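For illustration, a minimal Python sketch that writes a two-row CSV in this layout (the filename and the attribute values are made up):

```python
import csv

# Label-first layout: class label in column 0, attributes after it.
rows = [
    [0, 0.12, 0.55, 0.98],
    [1, 0.87, 0.31, 0.04],
]
with open("my_dataset_train.csv", "w", newline="") as f:
    csv.writer(f).writerows(rows)
```

A file like this can then be passed to the training scripts via the flags below.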

To pass the CSV file to any of the training and evaluation scripts, use the following flags:

  • num_classes : Number of classes
  • model_kw : Keyword for the model (e.g., fmnist for FashionMNIST)
  • train_len : Number of rows in the train CSV
  • train_data_path : Path to the train CSV file
  • test_len : Number of rows in the test CSV
  • test_data_path : Path to the test CSV file

By default, models are stored in saved_models; you can change this with the model_path flag. A synthetic dataset can be created using get_2d_dataset.py; a pre-generated copy is available in the data directory.

Pre-trained Models

Pretrained models for FashionMNIST under the optimal adversarial-training setting are available in saved_models. You can extract the models using these commands:

7z e fmnist_saved.7z.001
7z e fmnist_saved_adv.7z.001

Evaluation

To evaluate the saved MET-S model, run:

python3 eval.py --model_path="./saved_models/fmnist_64_1_64_6_1_70_1e-05" --model_path_linear="./saved_models/fmnist_linear_64_1_64_6_1_70_1e-05"

To evaluate the saved MET model, run:

python3 eval.py --model_path="./saved_models/fmnist_adv_64_1_64_6_1_70_1e-05" --model_path_linear="./saved_models/fmnist_linear_adv_64_1_64_6_1_70_1e-05"

By default results are written to met.csv.

Results

The performance of our model across various multi-class classification datasets is shown below.


| Type | Method | FMNIST | CIFAR10 | MNIST | CovType | Income |
| --- | --- | --- | --- | --- | --- | --- |
| Supervised Baseline | MLP | 87.57 ± 0.13 | 16.47 ± 0.23 | 96.98 ± 0.1 | 65.45 ± 0.09 | 84.35 ± 0.11 |
| | RF | 87.19 ± 0.09 | 36.75 ± 0.17 | 97.62 ± 0.18 | 64.94 ± 0.12 | 84.6 ± 0.2 |
| | GBDT | 88.71 ± 0.07 | 45.7 ± 0.27 | 100 ± 0.0 | 72.96 ± 0.11 | 86.01 ± 0.06 |
| | RF-G | 89.84 ± 0.08 | 29.28 ± 0.16 | 97.63 ± 0.03 | 71.53 ± 0.06 | 85.57 ± 0.13 |
| | MET-R | 88.81 ± 0.12 | 28.97 ± 0.08 | 97.43 ± 0.02 | 69.68 ± 0.07 | 75.50 ± 0.04 |
| Self-Supervised Methods | VIME | 80.36 ± 0.02 | 34 ± 0.5 | 95.74 ± 0.03 | 62.78 ± 0.02 | 85.99 ± 0.04 |
| | DACL+ | 81.38 ± 0.03 | 39.7 ± 0.06 | 91.35 ± 0.075 | 64.17 ± 0.12 | 84.46 ± 0.03 |
| | SubTab | 87.58 ± 0.03 | 39.32 ± 0.04 | 98.31 ± 0.06 | 42.36 ± 0.03 | 84.41 ± 0.06 |
| Our Method | MET-S | 90.90 ± 0.06 | 47.96 ± 0.1 | 98.98 ± 0.05 | 74.13 ± 0.04 | 86.17 ± 0.08 |
| | MET | 91.68 ± 0.12 | 47.92 ± 0.13 | 99.17 ± 0.04 | 76.68 ± 0.12 | 86.21 ± 0.05 |

The performance of our model across various binary classification datasets is shown below.


| Dataset | Metric | MLP | RF | GBDT | RF-G | MET-R | DACL+ | VIME | SubTab | MET |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Obesity | Accuracy | 58.1 ± 0.07 | 65.99 ± 0.12 | 67.19 ± 0.04 | 58.39 ± 0.17 | 58.8 ± 0.59 | 62.34 ± 0.12 | 59.23 ± 0.17 | 67.48 ± 0.03 | 74.38 ± 0.13 |
| | AUROC | 52.3 ± 0.12 | 64.36 ± 0.07 | 64.4 ± 0.05 | 54.45 ± 0.08 | 53.2 ± 0.18 | 61.18 ± 0.07 | 57.27 ± 0.21 | 64.92 ± 0.06 | 71.84 ± 0.15 |
| Income | Accuracy | 84.35 ± 0.11 | 84.6 ± 0.2 | 86.01 ± 0.06 | 85.57 ± 0.13 | 75.50 ± 0.04 | 85.99 ± 0.24 | 84.46 ± 0.03 | 84.41 ± 0.06 | 86.21 ± 0.05 |
| | AUROC | 89.39 ± 0.2 | 91.53 ± 0.32 | 92.5 ± 0.08 | 90.09 ± 0.57 | 83.48 ± 0.23 | 89.01 ± 0.4 | 87.37 ± 0.07 | 88.95 ± 0.19 | 93.85 ± 0.33 |
| Criteo | Accuracy | 74.28 ± 0.32 | 71.09 ± 0.05 | 72.03 ± 0.03 | 74.62 ± 0.08 | 73.57 ± 0.12 | 69.82 ± 0.06 | 68.78 ± 0.13 | 73.02 ± 0.08 | 78.49 ± 0.05 |
| | AUROC | 79.82 ± 0.17 | 77.57 ± 0.1 | 78.77 ± 0.04 | 80.32 ± 0.16 | 79.17 ± 0.17 | 75.32 ± 0.27 | 74.28 ± 0.39 | 76.57 ± 0.05 | 86.17 ± 0.2 |
| Arrhythmia | Accuracy | 59.7 ± 0.02 | 68.18 ± 0.02 | 69.79 ± 0.12 | 60.6 ± 0.05 | 51.67 ± 0.1 | 57.81 ± 0.47 | 56.06 ± 0.04 | 60.1 ± 0.1 | 81.25 ± 0.12 |
| | AUROC | 72.23 ± 0.06 | 90.63 ± 0.08 | 92.19 ± 0.05 | 74.02 ± 0.12 | 58.36 ± 0.32 | 69.23 ± 0.98 | 67.03 ± 0.27 | 69.97 ± 0.07 | 98.75 ± 0.04 |
| Thyroid | Accuracy | 50 ± 0.0 | 94.94 ± 0.1 | 96.44 ± 0.07 | 50 ± 0.0 | 57.42 ± 0.37 | 60.03 ± 0.05 | 66.1 ± 0.19 | 59.9 ± 0.16 | 98.1 ± 0.08 |
| | AUROC | 62.3 ± 0.12 | 99.62 ± 0.03 | 99.34 ± 0.02 | 52.65 ± 0.13 | 82.03 ± 0.26 | 86.63 ± 0.1 | 94.87 ± 0.03 | 88.93 ± 0.12 | 99.81 ± 0.09 |

About

License: Apache License 2.0

