liu3xing3long / pytorch_refinenet

PyTorch Implementation of RefineNet

This is a PyTorch implementation of the Multipath RefineNet architecture from the paper "RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation".

Installation

Install PyTorch by following the instructions on the PyTorch website. Then install this package:

pip install git+https://github.com/thomasjpfan/pytorch_refinenet.git

Implemented versions

  • Multi-path 4-Cascaded RefineNet: RefineNet4Cascade
  • Multi-path 4-Cascaded RefineNet With Improved Pooling: RefineNet4CascadePoolingImproved

Diagrams of both versions are available in the author's GitHub repo. The improved pooling version adds one more pooling/convolution layer and flips the order of the pooling and convolution layers in the Chained Residual Pooling block.

Usage

This implementation of the Multipath RefineNet has the following initialization:

class RefineNet4Cascade(nn.Module):

    def __init__(self,
                 input_shape,
                 num_classes=1,
                 features=256,
                 resnet_factory=models.resnet101,
                 pretrained=True,
                 freeze_resnet=True):
        ...

input_shape is a tuple of (channels, size) giving the number of channels in the input image and its width/height. The input is assumed to be a square image or patch, and for it to flow cleanly through the ResNet layers its size should be divisible by 32. For example, RefineNet4Cascade can be defined to take 3x224x224 images:

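import torch
from torch.autograd import Variable
# Assumes the installed package is importable as pytorch_refinenet.
from pytorch_refinenet import RefineNet4Cascade

# input_shape=(3, 224): 3 channels, 224x224 pixels (224 is divisible by 32)
net = RefineNet4Cascade((3, 224), num_classes=10)
# One random image as a batch of size 1
x_var = Variable(torch.randn(1, 3, 224, 224))
y = net(x_var)
y.size()
# torch.Size([1, 10, 56, 56])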

The number of output channels equals num_classes, and the output height and width are 1/4 of the input size, as described in the paper. You can upscale the output to get back to the original resolution.
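
The shape rules above (size divisible by 32, output at 1/4 resolution, num_classes output channels) can be sketched as a small check to run before building the network. This is only an illustration: expected_output_shape is a hypothetical helper and not part of this package, and it assumes the square-input convention described above.

```python
def expected_output_shape(input_shape, num_classes=1, batch_size=1):
    """Return the (N, C, H, W) output shape implied by the rules above.

    Hypothetical helper for illustration; not part of pytorch_refinenet.
    """
    channels, size = input_shape  # channels is unpacked only to mirror input_shape
    if size % 32 != 0:
        raise ValueError(f"input size {size} is not divisible by 32")
    out_size = size // 4  # output is 1/4 of the input width/height
    return (batch_size, num_classes, out_size, out_size)


print(expected_output_shape((3, 224), num_classes=10))
# (1, 10, 56, 56)
```

To get back to the input resolution, one option in PyTorch is torch.nn.functional.interpolate with scale_factor=4 applied to the network output.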

License:MIT License

