kavorite / convnext


ConvNeXt

Implementation of A ConvNet for the 2020s by Liu et al., in Haiku.

The authors restricted their architecture choices to ones whose FLOP counts at various spatial resolutions could be easily compared with their baselines, so that the improvements from the many architectural changes in their experiments could be accurately attributed. For my use case I replaced this default configuration with the one given by Brock et al. in High-Performance Large-Scale Image Recognition Without Normalization, which is instead motivated by a strong tradeoff between representational capacity and training latency observed during empirical architecture search.

API

The interface uses layer-wise normalization by default, as employed in ConvNeXt, and supports substituting other normalization schemes, including the baseline batch normalization. Specifying norm=None in the model constructor results in "norm-free" weight standardization as employed in NFNets.

import haiku as hk
import jax

from convnext import ConvNeXt


def ConvNeXtTiny(inputs, is_training=True):
    cnn = ConvNeXt(depths=[3, 3, 9, 3])
    logits = cnn(inputs, is_training)
    return logits


rng = hk.PRNGSequence(42)
inputs = jax.random.truncated_normal(next(rng), -1, 1, (1, 224, 224, 3))
# transform once; apply returns (logits, state), not logits alone
model = hk.transform_with_state(ConvNeXtTiny)
params, state = model.init(next(rng), inputs)
logits, state = model.apply(params, state, next(rng), inputs)

About

License: MIT License


Languages

Language: Python 100.0%