Towards self-supervised learning of global and object-centric representations
For the ICLR workshop on Objects, Structure, and Causality.
Setup
First, clone the repo:
git clone 'https://github.com/baldassarreFe/iclr-osc-22.git'
cd iclr-osc-22
Create an environment from scratch:
ENV_NAME='objects'
conda create -y -n "${ENV_NAME}" -c pytorch -c conda-forge \
  python black isort pytest dill pre-commit \
  hydra-core colorlog submitit fvcore tqdm wandb sphinx \
  numpy pandas matplotlib seaborn tabulate scikit-learn scikit-image \
  jupyterlab jupyterlab_code_formatter jupyter_console ipywidgets \
  pytorch tensorflow-gpu cudatoolkit-dev cudnn \
  torchvision einops opt_einsum
conda activate "${ENV_NAME}"
python -m pip install \
  better_exceptions \
  sphinx-rtd-theme sphinx-autodoc-typehints \
  hydra_colorlog hydra-submitit-launcher namesgenerator \
  tensorflow-datasets transformers datasets \
  'git+https://github.com/deepmind/multi_object_datasets' \
  'git+https://github.com/rwightman/pytorch-image-models'
conda env config vars set BETTER_EXCEPTIONS=1
pre-commit install
pre-commit autoupdate
python -m pip install --editable .
Or create an environment using the provided dependency file:
ENV_NAME='iclr-osc-22'
conda env create -n "${ENV_NAME}" -f 'environment.yaml'
conda activate "${ENV_NAME}"
pre-commit install
python -m pip install --editable .
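Either way, a quick sanity check confirms that the main dependencies resolved correctly. A minimal sketch, to be run inside the activated environment:

```python
# Sanity check: the core libraries import and a GPU is visible.
import torch
import tensorflow as tf

print("torch", torch.__version__, "CUDA:", torch.cuda.is_available())
print("tensorflow", tf.__version__, "GPUs:", tf.config.list_physical_devices("GPU"))
```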
Datasets
The project uses the "CLEVR with masks" dataset, which is part of the Multi-Object Datasets collection.
Download all the datasets from the Google Cloud bucket (see the original website for other options):
sudo apt install -y apt-transport-https ca-certificates gnupg
echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' |
sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
curl 'https://packages.cloud.google.com/apt/doc/apt-key.gpg' |
sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
sudo apt update
sudo apt install -y google-cloud-sdk
gsutil -m cp -r gs://multi-object-datasets "${HOME}/"
Notebooks for data loading and visualization are available in the repository.
Prepare the CLEVR dataset for training and evaluation by splitting the original TFRecords file into three parts (train and val contain only RGB images, test contains the full sample dict with object masks and attributes):
python -m osc.data.clevr_with_masks --data-root "${HOME}/multi-object-datasets"
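The downloaded TFRecords can be inspected with the DeepMind reader installed above. A minimal sketch, assuming the gsutil download location used earlier (the exact file name inside the bucket is an assumption):

```python
# Peek at one CLEVR-with-masks sample via the multi_object_datasets reader.
import os
from multi_object_datasets import clevr_with_masks

tfrecords = os.path.expanduser(
    "~/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"
)
for sample in clevr_with_masks.dataset(tfrecords).take(1):
    print(sample["image"].shape)  # RGB image
    print(sample["mask"].shape)   # per-object segmentation masks
```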
Training
A training run with default parameters can be launched by executing the following commands from the root of the repository (local training, single GPU):
export CUDA_VISIBLE_DEVICES=0
./train.py
The project uses Hydra as a configuration manager. All defaults can be listed as:
./train.py --cfg job
Individual parameters from the configuration can be changed on the command line:
./train.py \
  training.batch_size=16 \
  model.backbone.embed_dim=64 \
  model.backbone.patch_size='[4,4]' \
  model.backbone.num_heads=8 \
  model.backbone.{proj_drop,attn_drop}=0.2
All available configuration groups (e.g. different loss functions, attention types, learning rate schedules) can be found in the `configs` folder.
For example, to train with an object-wise contrastive loss that takes all object tokens from all images as negatives, while overfitting on a small subset of images:
./train.py losses/l_objects=ctr_all +overfit=overfit
Running a parameter sweep in a SLURM environment is also supported, for example:
./train.py --multirun hydra/launcher=submitit_slurm +slurm=slurm \
+losses=more_objects,more_global \
model=bb_obj_global \
model/obj_queries=sample \
model.backbone.embed_dim=64,128,256 \
logging.group='slurm_sweep' \
lr_scheduler=linear1_cosine4_x5 \
lr_scheduler.decay.end_lr=0.0003 \
optimizer.start_lr=0.0007 \
optimizer.weight_decay=0.0001 \
model.backbone.num_heads=4,8 \
model.backbone.num_layers=2,4,6 \
model.obj_fn.num_iters=1,2,4
Hyperparameters
The main hyperparameters that can be configured for the experiments are listed below. The corresponding configuration files can be found in the `configs` folder.
Architectures:
- `backbone-global_fn-global_proj`: global representation only. Backbone patch tokens can be aggregated either with global average pooling (`avg`) or with an extra CLS token (`cls`).
- `backbone(-slot_fn-slot_proj)-global_fn-global_proj`: after the backbone, two separate branches process global and object features. Backbone patch tokens can be aggregated either with global average pooling (`avg`) or with an extra CLS token (`cls`).
- `backbone-slot_fn(-global_fn-global_proj)-slot_proj`: after the backbone, the slot function extracts S object representations; these S feature tokens are further projected to yield object representations. Furthermore, these S tokens are average-pooled and processed to extract global features and projections (see the sketch after this list). The backbone pooling is set to `avg`, since an extra CLS token would not be ignored by the slot function.
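A minimal sketch of the third variant's data flow; the module names mirror the configuration groups above, but the stand-in implementations and shapes are assumptions, not the repository's exact API:

```python
# Data flow of backbone-slot_fn(-global_fn-global_proj)-slot_proj (sketch).
# B images, N patch tokens, S object slots, embedding dim D.
import torch
from torch import nn

B, N, S, D = 8, 196, 11, 64
slot_proj = nn.Linear(D, D)    # stand-ins for the real projection/MLP heads
global_fn = nn.Linear(D, D)
global_proj = nn.Linear(D, D)

def slot_fn(tokens):
    # Stand-in: the real module extracts S slots by attending over all tokens.
    return tokens[:, :S, :]

patch_tokens = torch.randn(B, N, D)             # backbone output (avg pooling)
slots = slot_fn(patch_tokens)                   # (B, S, D) object tokens
obj_projections = slot_proj(slots)              # object-level projections
global_feats = global_fn(slots.mean(dim=1))     # slots avg-pooled -> global
global_projections = global_proj(global_feats)  # global-level projections
```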
Object query implementations:
- `learned`: a fixed number of learned query tokens.
- `sample`: object queries are sampled either from a single Gaussian distribution with learned parameters, or from a mixture of Gaussians with uniform component weights (a sketch follows this list).
- `kmeans_euclidean`: object queries are initialized as the K-Means clustering of backbone features. The number of clusters can be chosen dynamically; the distance function is simply the Euclidean distance.
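A minimal sketch of the `sample` variant with a single learned Gaussian per query (the class name and parameterization are hypothetical):

```python
import torch
from torch import nn

class SampledObjectQueries(nn.Module):
    """Draws S object queries per image from a learned Gaussian (sketch)."""

    def __init__(self, num_queries: int, dim: int):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(num_queries, dim))
        self.log_sigma = nn.Parameter(torch.zeros(num_queries, dim))

    def forward(self, batch_size: int) -> torch.Tensor:
        # Reparameterized sampling: (B, S, D) queries, differentiable in mu/sigma.
        eps = torch.randn(batch_size, *self.mu.shape, device=self.mu.device)
        return self.mu + self.log_sigma.exp() * eps
```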
Object function implementations:
- `slot-attention`: slot-attention decoder (iterative; sketched below)
- `cross-attention`: cross-attention decoder
- co-attention decoder (not implemented yet)
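For concreteness, here is a minimal slot-attention iteration in the spirit of Locatello et al. (2020): a sketch under assumed interfaces rather than the repository's exact module (the per-iteration MLP residual of the original paper is omitted for brevity):

```python
import torch
from torch import nn

class SlotAttention(nn.Module):
    def __init__(self, dim: int, num_iters: int = 3, eps: float = 1e-8):
        super().__init__()
        self.num_iters, self.eps, self.scale = num_iters, eps, dim ** -0.5
        self.norm_tokens = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, tokens: torch.Tensor, slots: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N, D) backbone features; slots: (B, S, D) object queries.
        tokens = self.norm_tokens(tokens)
        k, v = self.to_k(tokens), self.to_v(tokens)
        for _ in range(self.num_iters):
            q = self.to_q(self.norm_slots(slots))
            logits = torch.einsum("bsd,bnd->bsn", q, k) * self.scale
            attn = logits.softmax(dim=1)  # softmax over slots: competition per token
            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
            updates = torch.einsum("bsn,bnd->bsd", attn, v)  # weighted mean of values
            slots = self.gru(
                updates.reshape(-1, updates.shape[-1]),
                slots.reshape(-1, slots.shape[-1]),
            ).reshape(slots.shape)
        return slots
```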
Loss functions:
- Global image representation:
  - Contrastive loss `ctr` (given one image, classify positively an augmented version of that image among B-2 other unrelated images in the batch); an illustrative sketch follows this list.
  - Cosine similarity loss `sim` (given one image and its augmented version, maximise the cosine similarity between their projected representations).
- Object representation:
  - Contrastive loss `ctr_all` (one token compared to all tokens in all images).
  - Contrastive loss `ctr_img` (one token compared to all tokens from its original image and the augmented version).
  - Cosine similarity loss `sim_img` (one token compared to all tokens from its original image and the augmented version).
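To illustrate the contrastive family, a minimal InfoNCE-style sketch of the global `ctr` loss; the temperature value and normalization details are assumptions:

```python
import torch
import torch.nn.functional as F

def global_contrastive_loss(z_a, z_b, temperature: float = 0.1):
    """InfoNCE over two views: matching (i, i) pairs are positives, the rest negatives."""
    z_a = F.normalize(z_a, dim=-1)  # (B, D) projections of view A
    z_b = F.normalize(z_b, dim=-1)  # (B, D) projections of view B
    logits = z_a @ z_b.T / temperature  # (B, B) scaled cosine similarities
    targets = torch.arange(z_a.shape[0], device=z_a.device)
    return F.cross_entropy(logits, targets)
```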
Embedding dimension:
- default 64 for everything, with a 2x factor for all MLP hidden layers
- 128 and 256 also work well, but require a smaller batch size, especially when using 8 heads. Safe (dim, batch) pairs: (64, 64), (128, 16), (256, 8)
- it is interesting to try a different size for the final projection head when using the matching cosine similarity loss
Documentation
The documentation is hosted on a GitHub Pages website.
The documentation is generated automatically using Sphinx. All sources are in the `docs/` folder, and a separate `docs` branch tracks the documentation builds.