rayleizhu / segformer.pure

A cleaned-up version of the official SegFormer implementation. It removes the dependency on MMCV and MMSegmentation, which rely on deeply nested wrappers.


Requirements

  • pytorch>=1.0
  • timm>=0.5.4
  • gdown (optional; only required if you want official checkpoints to be downloaded automatically from a URL)

Features

  • written with the plain PyTorch API: no deep wrapping, easy to understand, and easy to modify
  • compatible with the officially released model weights
  • automatically downloads official checkpoints
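Weight compatibility is less straightforward when the class count differs from the checkpoint, as in the `num_classes=1` usage example. In plain PyTorch, `load_state_dict(strict=False)` tolerates missing and unexpected keys but still raises an error on tensors whose shapes differ, so a common workaround is to drop shape-mismatched entries before loading. The sketch uses a toy two-layer model and a hypothetical helper `filter_matching_state_dict`; it illustrates the general technique and is not necessarily how `load_official_state_dict` is implemented in this repository.

```python
import torch.nn as nn


def filter_matching_state_dict(model: nn.Module, state_dict: dict) -> dict:
    """Hypothetical helper: keep only entries whose key and shape match `model`."""
    model_state = model.state_dict()
    return {
        k: v
        for k, v in state_dict.items()
        if k in model_state and model_state[k].shape == v.shape
    }


def make_model(num_classes: int) -> nn.Module:
    # Toy stand-in for a segmentation network: a "backbone" conv plus a 1x1 class head.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, num_classes, 1))


pretrained = make_model(num_classes=150)  # checkpoint with a 150-class head
target = make_model(num_classes=1)        # binary head, shape differs from the checkpoint

# Passing pretrained.state_dict() directly would raise a size-mismatch error,
# even with strict=False; filtering first loads only the compatible layers.
filtered = filter_matching_state_dict(target, pretrained.state_dict())
result = target.load_state_dict(filtered, strict=False)
print(sorted(filtered))     # ['0.bias', '0.weight']
print(result.missing_keys)  # ['1.weight', '1.bias']
```

The head keeps its fresh initialization and is trained from scratch, while the shared layers start from the checkpoint.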

Example usage:

import torch
from networks.segformer import SegFormerB0, SegFormerB1

model1 = SegFormerB0(num_classes=150, encoder_weight=None)
print(model1.official_ckpts)  # list the officially released checkpoints
model1.load_official_state_dict('segformer.b0.512x512.ade.160k.pth', strict=True)  # load officially released weights

model2 = SegFormerB0(num_classes=1, encoder_weight=None)  # binary segmentation (single output channel)
model2.load_official_state_dict('segformer.b0.512x512.ade.160k.pth', strict=False)  # the final prediction layer is not loaded

model3 = SegFormerB1(num_classes=20, encoder_weight='imagenet')  # load only the ImageNet-pretrained backbone

x = torch.zeros((2, 3, 512, 512))
pred = model3(x)
print(pred.size())  # output resolution is (H/4, W/4), i.e. 128x128 for this 512x512 input


model4 = SegFormerB1(in_ch=6, num_classes=20, encoder_weight='imagenet') # change input channels
x = torch.zeros((2, 6, 512, 512))
y = model4(x)
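Since the prediction comes out at (H/4, W/4), a per-pixel label map at the input resolution needs an upsampling step. A minimal sketch, assuming the model returns raw logits of shape (batch, num_classes, H/4, W/4); a random tensor stands in for `model3(x)` so the snippet runs without the repository installed.

```python
import torch
import torch.nn.functional as F

# Stand-in for the model output: logits at 1/4 of the input resolution,
# shaped (batch, num_classes, H/4, W/4). Random values are used only so the
# example runs standalone.
x = torch.zeros((2, 3, 512, 512))
logits = torch.randn(2, 20, x.shape[2] // 4, x.shape[3] // 4)

# Bilinearly upsample the logits to the input size before taking argmax,
# so every input pixel gets a class label.
logits_full = F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
labels = logits_full.argmax(dim=1)  # (batch, H, W) integer class map

print(logits_full.shape)  # torch.Size([2, 20, 512, 512])
print(labels.shape)       # torch.Size([2, 512, 512])
```

Upsampling the logits before `argmax` is the usual choice; upsampling the integer label map with nearest-neighbour interpolation is cheaper but produces blockier boundaries.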

TODOs

  • MixVisionTransformer.load_official_state_dict()
  • Flexible input channels for ImageNet-pretrained MiT()
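For the second TODO, one possible approach (an assumption, not something the repository currently does) is to widen the backbone's first ImageNet-pretrained convolution: tile the RGB kernels across the new input channels and rescale them so activation magnitudes stay comparable. `expand_input_conv` is a hypothetical helper, and the 7x7, stride-4 toy layer only stands in for the real patch embedding.

```python
import torch
import torch.nn as nn


def expand_input_conv(conv: nn.Conv2d, in_ch: int) -> nn.Conv2d:
    """Hypothetical helper: copy `conv` but accept `in_ch` input channels.

    The pretrained kernels are tiled along the input-channel axis and scaled
    by old_in / in_ch. When in_ch is a multiple of old_in, a constant input
    yields the same activations as the original layer. Assumes groups == 1.
    """
    old_w = conv.weight  # shape (out_ch, old_in, kh, kw)
    old_in = old_w.shape[1]
    repeats = -(-in_ch // old_in)  # ceiling division
    new_conv = nn.Conv2d(
        in_ch,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        tiled = old_w.repeat(1, repeats, 1, 1)[:, :in_ch]
        new_conv.weight.copy_(tiled * (old_in / in_ch))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Toy stand-in for an ImageNet-pretrained 3-channel patch embedding.
rgb_conv = nn.Conv2d(3, 32, kernel_size=7, stride=4, padding=3)
conv6 = expand_input_conv(rgb_conv, in_ch=6)

x6 = torch.ones(1, 6, 64, 64)
x3 = torch.ones(1, 3, 64, 64)
print(conv6(x6).shape)  # torch.Size([1, 32, 16, 16])
print(torch.allclose(conv6(x6), rgb_conv(x3), atol=1e-5))  # True
```

For a constant input the widened layer reproduces the original layer's output up to rounding, which the last line checks; real multi-channel data will still need fine-tuning.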
