Flax: A neural network ecosystem for JAX designed for flexibility

Overview | What does Flax look like? | Documentation


Please check our full documentation website to learn everything you need to know about Flax.

NOTE: Flax is used by a growing community of researchers and engineers at Google for their daily research. The new Flax "Linen" module API is now stable, and we recommend it for all new projects. The old flax.nn API will be deprecated. Please report any feature requests, issues, questions, or concerns in our discussion forum, or just let us know what you're working on!

Expect changes to the API, but we'll use deprecation warnings when we can, and keep track of them in our Changelog.

In case you need to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop (see the training-step sketch after this list)

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
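
To give a feel for how these pieces fit together, here is a minimal sketch of a single training step pairing a flax.linen module with a flax.optim optimizer. The one-layer Dense model, squared-error loss, and train_step helper are illustrative assumptions, not part of the library:

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

model = nn.Dense(features=1)  # stand-in model; any flax.linen Module works here
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))['params']
optimizer = flax.optim.Adam(learning_rate=1e-3).create(params)

@jax.jit
def train_step(optimizer, x, y):
  def loss_fn(params):
    pred = model.apply({'params': params}, x)
    return jnp.mean((pred - y) ** 2)  # illustrative squared-error loss
  grads = jax.grad(loss_fn)(optimizer.target)  # optimizer.target holds the parameters
  return optimizer.apply_gradient(grads)  # returns an optimizer with updated parameters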

What does Flax look like?

Here we provide two examples of the Flax API: a simple multi-layer perceptron (MLP) and a CNN. To learn more about the Module abstraction, please check our docs.

from typing import Sequence

import flax.linen as nn

class SimpleMLP(nn.Module):
  """A simple MLP model."""
  features: Sequence[int]  # output width of each Dense layer

  @nn.compact
  def __call__(self, x):
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat)(x)
      if i != len(self.features) - 1:  # no activation after the final layer
        x = nn.relu(x)
    return x
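
A Module holds no parameters itself: you create them with init (from a PRNG key and an example input) and run the model with apply. A minimal usage sketch, with arbitrary layer sizes and batch shape:

import jax
import jax.numpy as jnp

model = SimpleMLP(features=[16, 8, 4])
x = jnp.ones((32, 10))  # dummy batch: 32 ten-dimensional inputs
variables = model.init(jax.random.PRNGKey(0), x)  # initialize the parameters
y = model.apply(variables, x)  # y.shape == (32, 4)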
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # downsample 2x
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # downsample 2x
    x = x.reshape((x.shape[0], -1))  # flatten all but the batch dimension
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)  # one logit per class
    x = nn.log_softmax(x)
    return x
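
The CNN is used the same way; nn.Conv expects NHWC inputs, so a dummy batch of MNIST-sized images (an assumed shape, purely for illustration) could be fed like this:

import jax
import jax.numpy as jnp

cnn = CNN()
imgs = jnp.ones((32, 28, 28, 1))  # NHWC: 32 grayscale 28x28 images
variables = cnn.init(jax.random.PRNGKey(0), imgs)
logits = cnn.apply(variables, imgs)  # logits.shape == (32, 10)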

Note

This is not an official Google product.

License: Apache License 2.0

