torchlego

This is an early preview

High-quality Neural Networks built with reusable blocks in PyTorch


Photo by Ryan Quintal on Unsplash

Motivation

This library aims to create new components to make developing and writing neural networks faster and easier.

Installation

pip

You can install using pip

pip install git+https://github.com/FrancescoSaverioZuppichini/torchlego

Quick Tour


Building blocks

What follows is a list of small, useful components that make your code more readable and faster to develop. Just like Lego, when combined they can become anything!

Convs

Most of the time you will use a $3\times3$ or a $1\times1$ conv followed by batchnorm and an activation function.

from torchlego import conv3x3

conv3x3(32, 64)

Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
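Note that no padding is applied by default, so a plain $3\times3$ conv shrinks the feature map by one pixel on each side; extra arguments such as padding appear to be forwarded to nn.Conv2d (as the later examples with padding=1 suggest), which lets you keep the spatial size:

import torch
from torchlego import conv3x3

x = torch.rand((1, 32, 8, 8))
conv3x3(32, 64)(x).shape              # torch.Size([1, 64, 6, 6]): no padding
conv3x3(32, 64, padding=1)(x).shape   # torch.Size([1, 64, 8, 8]): spatial size preserved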


from torchlego import conv3x3_bn

conv3x3_bn(32, 64)

Sequential( (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )


from torchlego import conv3x3_bn_act

conv3x3_bn_act(32, 64)

Sequential( (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU() )


Optionally, you can pass your own activation function:

from torch.nn import SELU
conv3x3_bn_act(32, 64, act=SELU)

Sequential( (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): SELU() )

We also have conv1x1:

from torchlego.blocks import conv1x1

conv1x1(32, 64)

Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
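A $1\times1$ conv changes the number of channels while leaving the spatial size untouched, which is why it is commonly used as a projection in shortcuts (as in the ResNet block further down):

import torch
from torchlego.blocks import conv1x1

x = torch.rand((1, 32, 8, 8))
conv1x1(32, 64)(x).shape   # torch.Size([1, 64, 8, 8]): only the channels change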

InputForward

Often we need to pass an input to multiple modules and do something with the results. This is when InputForward comes in handy! aggr_func takes as input a list where each entry is the output of one of the modules.

import torch
from torchlego.blocks import InputForward, Lambda

blocks = [Lambda(lambda x: x), Lambda(lambda x: x)]
InputForward(blocks, aggr_func=lambda x: torch.tensor(x).sum())(torch.tensor([1]))

tensor(2)
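Since aggr_func simply receives the list of outputs, any aggregation works; for example, averaging the two branches (Lambda just wraps a plain Python function into an nn.Module):

import torch
from torchlego.blocks import InputForward, Lambda

blocks = [Lambda(lambda x: x * 2), Lambda(lambda x: x * 4)]
# both branches receive the same input; the outputs [2, 4] are averaged to 3
InputForward(blocks, aggr_func=lambda outs: torch.stack(outs).float().mean())(torch.tensor([1]))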


Cat

An InputForward instance that concatenates the outputs.

import torch
from torchlego.blocks import Cat, Lambda


blocks = [Lambda(lambda x: x), Lambda(lambda x: x)]

Cat(blocks)(torch.tensor([1]))

tensor([1, 1])
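The branches can of course compute different things from the same input; here two Lambdas stand in for arbitrary modules, and (as in the example above) the outputs are concatenated one after the other:

import torch
from torchlego.blocks import Cat, Lambda

blocks = [Lambda(lambda x: x * 2), Lambda(lambda x: x + 10)]
Cat(blocks)(torch.tensor([1]))   # tensor([ 2, 11])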


Residual Connection

Residual connections are a big thing. They are probably best known from the ResNet paper (even if Schmidhuber did something very similar a long time ago). You can use torchlego to easily create any residual connection you may need.

Residual

The main building block is the Residual class. It applies a function (res_func) to the output of the blocks and to the residual (the original input).

import torch
from torchlego.blocks import Residual, Lambda

x = torch.tensor([1])
block = Lambda(lambda x: x)


res = Residual(block, res_func=lambda x, res: x + res)
res(x)

tensor([2])
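res_func is free-form, so any way of combining the block output with the residual works; for example, a multiplicative residual:

import torch
from torchlego.blocks import Residual, Lambda

x = torch.tensor([3])
block = Lambda(lambda x: x + 1)

# block output (3 + 1 = 4) multiplied by the residual (3) -> 12
res = Residual(block, res_func=lambda x, res: x * res)
res(x)   # tensor([12])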


shortcut

We can also apply a function to the residual; this operation is called a shortcut. Here the shortcut doubles the residual, so the output is block(x) + 2 * x = 1 + 2 = 3.

res = Residual(block, res_func=lambda x, res: x + res, shortcut=lambda x: x * 2)
res(x)

tensor([3])


You can also omit res_func; in that case, the residual is passed as the second argument to the blocks.
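A minimal sketch of what that might look like, using a hypothetical AddResidual module of our own and assuming Residual then calls the block with the residual as the second argument (as the B layers in the next section do):

import torch
import torch.nn as nn
from torchlego.blocks import Residual

class AddResidual(nn.Module):
    # hypothetical block: receives the residual as a second argument and adds it
    def forward(self, x, res):
        return x + res

Residual(AddResidual())(torch.tensor([1]))   # expected: tensor([2])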

Multiple level residuals

If you pass a list of nn.ModuleList, we assume you want to pass a residual through each respective layer. In the following example, the A layers just add 1 to the input while the B layers add the residual to the current input. Notice that B.forward takes the residual as its second argument.

import torch.nn as nn

class A(nn.Module):
    def forward(self, x):
        return x + 1

class B(nn.Module):
    def forward(self, x, res):
        return x + res

down = nn.ModuleList([A(), A()])
up = nn.ModuleList([B(), B()])
res = Residual([down, up])
res(x)

tensor([8])


Be aware that only the first n residuals will be passed, where n is the length of the second list of blocks. It may sound confusing, so let's look at an example: x goes 1 → 2 → 3 through the two A layers, but only a single residual (2) reaches the single B layer, giving 3 + 2 = 5.

import torch.nn as nn


class B(nn.Module):
    def forward(self, x, res=None):
        x = x if res is None else x + res
        return x

down = nn.ModuleList([A(), A()])
up = nn.ModuleList([B()])
res = Residual([down, up])
res(x)

tensor([5])


Addition

torchlego comes with a handy ResidualAdd block, which is just a Residual that performs the addition automatically.

import torch
from torchlego.blocks import ResidualAdd, Lambda

layer = ResidualAdd([Lambda(lambda x: x)])
layer(torch.tensor([1]))

tensor([2])

A more complete example

import torch.nn as nn
from torchlego.blocks import ResidualAdd

x = torch.rand((1, 64, 8, 8))

block = nn.Sequential(conv3x3_bn_act(64, 64, padding=1))

layer = ResidualAdd([block])
x = layer(x)
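Since conv3x3_bn_act(64, 64, padding=1) keeps both the channels and the spatial size unchanged, the residual addition is well defined and the output shape matches the input:

x.shape   # torch.Size([1, 64, 8, 8])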


You can also pass multiple blocks; the addition is applied after each one, so with two identity blocks the input 1 becomes 1 + 1 = 2 and then 2 + 2 = 4.

from torchlego.blocks import ResidualAdd, Lambda

blocks = [Lambda(lambda x: x), Lambda(lambda x: x)]
layer = ResidualAdd(blocks)
layer(torch.tensor(1))

tensor(4)


Let's create a basic ResNet block

from torchlego.blocks import conv_bn

def resnet_basic_block(in_features, out_features):
    shortcut = conv_bn(in_features, out_features, kernel_size=1, stride=2, bias=False) if in_features != out_features else nn.Identity()
    stride = 2 if in_features != out_features else 1
    return nn.Sequential(
        ResidualAdd(nn.ModuleList([
            nn.Sequential(
                conv3x3_bn_act(in_features, out_features, stride=stride, padding=1, bias=False),
                conv3x3_bn(out_features, out_features, padding=1, bias=False))]),
            shortcut=shortcut),
        nn.ReLU())

resnet_basic_block(32, 64)

Sequential(
  (0): Residual(
    (blocks): ModuleList(
      (0): Sequential(
        (0): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (shortcut): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): ReLU()
)
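A quick sanity check on the shapes: when the number of features changes, both the main path and the shortcut use stride 2, so the spatial size is halved; otherwise the input shape is preserved.

x = torch.rand((1, 32, 8, 8))
resnet_basic_block(32, 64)(x).shape   # torch.Size([1, 64, 4, 4]): channels change, spatial size halved
resnet_basic_block(32, 32)(x).shape   # torch.Size([1, 32, 8, 8]): identity shortcut, same shape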


What about a full resnet? Easy peasy

def resnet(in_features, n_classes, sizes):
    return nn.Sequential(
        nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        nn.Sequential(*[resnet_basic_block(64, 64) for _ in range(sizes[0])]),
        resnet_basic_block(64, 128),
        nn.Sequential(*[resnet_basic_block(128, 128) for _ in range(sizes[1] - 1)]),
        resnet_basic_block(128, 256),
        nn.Sequential(*[resnet_basic_block(256, 256) for _ in range(sizes[2] - 1)]),
        resnet_basic_block(256, 512),
        nn.Sequential(*[resnet_basic_block(512, 512) for _ in range(sizes[3] - 1)]),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(512, n_classes)
    )

x = torch.rand((1, 3, 224, 224))

resnet34 = resnet(3, 1000, [3, 4, 6, 3])

resnet34(x).shape

torch.Size([1, 1000])
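The same constructor covers the other basic-block ResNet variants just by changing sizes; for instance, ResNet18 uses two blocks per stage:

resnet18 = resnet(3, 1000, [2, 2, 2, 2])
resnet18(x).shape   # torch.Size([1, 1000])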


Unet

What about Unet?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlego.blocks import ResidualCat2d, Lambda
from torchlego.blocks import Residual, conv3x3, conv3x3_bn_act, conv1x1

down = lambda in_features, out_features: nn.Sequential(
    nn.MaxPool2d(kernel_size=2, stride=2),
    conv3x3_bn_act(in_features, out_features, padding=1),
    conv3x3_bn_act(out_features, out_features, padding=1),
)

class up(nn.Module):
    def __init__(self, in_features, out_features, should_up=True, *args, **kwargs):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_features, out_features, kernel_size=2, stride=2)
        self.should_up = should_up
        self.blocks = nn.Sequential(
            conv3x3(out_features * 2, out_features, padding=1),
            conv3x3(out_features, out_features, padding=1),
        )

    def forward(self, x, res):
        if self.should_up: x = self.up(x)

        diffX = x.size()[2] - res.size()[2]
        diffY = x.size()[3] - res.size()[3]
        pad = (diffX // 2, int(diffX / 2), diffY // 2, int(diffY / 2))
        res = F.pad(res, pad)

        x = torch.cat([res, x], dim=1)
        out = self.blocks(x)
        return out

unet = nn.Sequential(
    Residual([
        nn.ModuleList([
            nn.Sequential(
                conv3x3_bn_act(3, 64),
                conv3x3_bn_act(64, 64)),
            down(64, 128),
            down(128, 256),
            down(256, 512),
            down(512, 1024),
        ]),
        nn.ModuleList([
            up(512 * 2, 512),
            up(256 * 2, 256),
            up(128 * 2, 128),
            up(64 * 2, 64),
        ])
    ]),
    conv1x1(64, 2)
)

x = torch.rand((1,3,256,256))

unet(x)

tensor([[[[ 0.1703,  0.1333,  0.1065,  ...,  0.0960,  0.0289,  0.2440],
          [ 0.0541,  0.1370, -0.0009,  ..., -0.0234,  0.1766,  0.1227],
          [ 0.0071,  0.0103, -0.0728,  ..., -0.0840,  0.0448,  0.0931],
          ...,
          [ 0.0387,  0.0116,  0.1317,  ...,  0.0861,  0.0679,  0.1378],
          [ 0.0253,  0.0583, -0.0304,  ..., -0.0590, -0.0112,  0.0985],
          [ 0.1501,  0.1165,  0.0767,  ...,  0.0516,  0.0438,  0.0853]],

         [[ 0.0734,  0.1631, -0.0388,  ..., -0.0313, -0.0182, -0.0158],
          [ 0.0150, -0.0157, -0.1313,  ...,  0.0417, -0.0827,  0.0312],
          [-0.0783, -0.1002, -0.0104,  ..., -0.0625,  0.1027, -0.0133],
          ...,
          [-0.0326, -0.0716, -0.0284,  ..., -0.0891, -0.0317, -0.0103],
          [ 0.1046, -0.0530, -0.0447,  ...,  0.0813, -0.0651,  0.0246],
          [-0.0351,  0.0342, -0.0398,  ..., -0.0206, -0.0515,  0.0425]]]],
       grad_fn=<ThnnConv2DBackward>)

License: MIT License

