Madhav-Kanda / ASTRA

"AI for Sustainability" Toolkit for Research and Analysis

ASTRA (अस्त्र) means "tool" or "weapon" in Sanskrit.

Install

Stable version:

pip install astra-lib

Latest development version (from GitHub):

pip install git+https://github.com/sustainability-lab/ASTRA

Contributing

Please go through the contributing guidelines before making a contribution.

Useful Code Snippets

Data

Load Data

from astra.torch.data import load_mnist, load_cifar_10
ds, ds_name = load_cifar_10()

Models

MLPs

from astra.torch.models import MLP

mlp = MLP(input_dim=100, hidden_dims=[128, 64], output_dim=10, activation="relu", dropout=0.1)
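For intuition, the `MLP` call above roughly corresponds to a stack of linear layers like the one below. This is a minimal plain-PyTorch sketch that assumes each hidden layer is `Linear -> ReLU -> Dropout` and that the output layer is a bare `Linear`; the exact layer ordering inside `astra.torch.models.MLP` may differ, and `make_mlp` is a hypothetical helper, not part of ASTRA.

```python
import torch
import torch.nn as nn


def make_mlp(input_dim, hidden_dims, output_dim, dropout=0.1):
    """Hypothetical helper: Linear -> ReLU -> Dropout per hidden layer, then a final Linear."""
    layers = []
    in_dim = input_dim
    for h in hidden_dims:
        layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
        in_dim = h
    layers.append(nn.Linear(in_dim, output_dim))  # no activation on the output logits
    return nn.Sequential(*layers)


mlp_sketch = make_mlp(input_dim=100, hidden_dims=[128, 64], output_dim=10, dropout=0.1)
x = torch.randn(4, 100)   # batch of 4 inputs with 100 features each
logits = mlp_sketch(x)
print(logits.shape)       # torch.Size([4, 10])
```

With these sizes the sketch has (100·128 + 128) + (128·64 + 64) + (64·10 + 10) = 21,834 parameters.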

CNNs

from astra.torch.models import CNN
cnn = CNN(image_dim=32, 
          kernel_size=5, 
          n_channels=3, 
          conv_hidden_dims=[32, 64], 
          dense_hidden_dims=[128, 64], 
          output_dim=10)

EfficientNets

from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.models import EfficientNet

model = EfficientNet(efficientnet_b0, EfficientNet_B0_Weights.DEFAULT, output_dim=10)

ViT

from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.models import ViT

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)

Training

Quick train a model

from astra.torch.utils import train_fn

# model, inputs, outputs, loss_fn, lr, n_epochs and batch_size must be defined beforehand
result = train_fn(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size, enable_tqdm=True)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
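`train_fn` returns both per-epoch and per-iteration losses. The sketch below shows what such a mini-batch loop typically looks like in plain PyTorch, using the same argument order as the call above. It assumes Adam as the optimizer and reshuffled mini-batches each epoch; it is not ASTRA's implementation, and `simple_train` is a hypothetical name.

```python
import torch
import torch.nn as nn


def simple_train(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size):
    """Hypothetical stand-in for a quick-train helper: Adam + shuffled mini-batches."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    n = inputs.shape[0]
    epoch_losses, iter_losses = [], []
    model.train()
    for _ in range(n_epochs):
        perm = torch.randperm(n)  # reshuffle the data every epoch
        running = 0.0
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            optimizer.zero_grad()
            loss = loss_fn(model(inputs[idx]), outputs[idx])
            loss.backward()
            optimizer.step()
            iter_losses.append(loss.item())
            running += loss.item() * len(idx)  # weight by batch size for a true epoch mean
        epoch_losses.append(running / n)
    return {"epoch_losses": epoch_losses, "iter_losses": iter_losses}


# Toy usage: 10-class classification on random features
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 10))
inputs = torch.randn(256, 100)
outputs = torch.randint(0, 10, (256,))
result = simple_train(model, inputs, outputs, nn.CrossEntropyLoss(), lr=1e-3, n_epochs=5, batch_size=32)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
```

Here 256 samples with `batch_size=32` give 8 iterations per epoch, so 5 epochs record 40 iteration losses and 5 epoch losses.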

Ad hoc

Count number of parameters in a model

from astra.torch.utils import count_params
n_params = count_params(mlp)
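Outside ASTRA, the usual plain-PyTorch equivalent sums `numel()` over a model's parameters. This sketch assumes `count_params` counts every parameter, trainable or not; the `trainable_only` switch and the `count_parameters` name are hypothetical additions for illustration.

```python
import torch.nn as nn


def count_parameters(model, trainable_only=False):
    """Sum the number of elements in every parameter tensor of `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)


layer = nn.Linear(100, 10)        # 100*10 weights + 10 biases
print(count_parameters(layer))    # 1010

layer.bias.requires_grad_(False)  # freeze the bias
print(count_parameters(layer, trainable_only=True))  # 1000
```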

Flatten/Unflatten the weights of a model

import torch
from astra.torch.models import ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.utils import ravel_pytree
import optree

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())

flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params) # returns the original params

# check if the tree structure is preserved
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)

# check if the values are preserved
for before_leaf, after_leaf in zip(optree.tree_leaves(params), optree.tree_leaves(unraveled_params)):
    assert torch.all(before_leaf == after_leaf)
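The idea behind flattening, as in JAX's `jax.flatten_util.ravel_pytree`, is to concatenate all leaves into one 1-D vector and return a function that splits that vector back into the original shapes. The minimal sketch below does this for the flat dict produced by `named_parameters()`, assuming every leaf is a tensor; `ravel_dict` is a hypothetical name and, unlike a general pytree utility, it does not handle nested containers.

```python
import torch
import torch.nn as nn


def ravel_dict(params):
    """Flatten a dict of tensors into one 1-D tensor; return it with an inverse function."""
    keys = list(params.keys())
    shapes = [params[k].shape for k in keys]
    sizes = [params[k].numel() for k in keys]
    flat = torch.cat([params[k].reshape(-1) for k in keys])

    def unravel(vec):
        # Split the vector back into chunks and restore each original shape
        chunks = torch.split(vec, sizes)
        return {k: c.reshape(s) for k, c, s in zip(keys, chunks, shapes)}

    return flat, unravel


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
params = {k: v.detach() for k, v in model.named_parameters()}
flat, unravel = ravel_dict(params)
print(flat.shape)  # torch.Size([23]), i.e. (4*3 + 3) + (3*2 + 2)
restored = unravel(flat)
assert all(torch.equal(params[k], restored[k]) for k in params)
```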

Languages

Jupyter Notebook 76.9%, Python 23.1%