fastai-tf-fit

Fit your TensorFlow model using fastai and PyTorch

Installation

pip install git+https://github.com/fastai/tf-fit.git

Features

This project is an extension of fastai that allows TensorFlow models to be trained with an interface similar to fastai's own. It uses fastai DataBunch objects, so loading data works exactly the same way. For training, the TfLearner supports many of the same features as the fastai Learner. The currently supported features are listed below; a short workflow sketch follows the list.

  • Training TensorFlow models with a constant learning rate and weight decay
  • Training using the 1cycle policy
  • Learning rate finder
  • Fitting with callbacks that have access to hyperparameter updates
  • Discriminative learning rates
  • Freezing layers so their parameters are not trained
  • True weight decay option
  • L2 regularization (true_wd=False)
  • Option to remove weight decay from batchnorm layers (bn_wd=False)
  • Momentum
  • Option to train batchnorm layers even if the layer is frozen (train_bn=True)
  • Model saving and loading
  • Default image data format is channels first (channels x height x width)
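
A minimal sketch of the workflow looks like the following. The Examples section below walks through each step in full; the names used here (data, model, opt_fn, loss_fn, metrics) are defined there, and the keyword arguments shown there can also be passed to TfLearner.

# Workflow sketch: setup objects are defined in the Examples section below
learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics)
learn.lr_find()                       # pick a learning rate from the plot
learn.fit_one_cycle(10, max_lr=3e-3)  # train with the 1cycle policy
learn.save('cnn-1')                   # save the trained weights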

To do

This project is a work in progress, so there may be missing features or obscure bugs.

  • Get predictions function
  • Tensorflow train/eval functionality for dropout and batchnorm in eager mode
  • Pip and conda packages

Examples

Setup

Set up the fastai DataBunch, optimizer, loss function, and metrics.

from fastai.vision import *
from fastai_tf_fit import *
import tensorflow as tf

# Download CIFAR-10 and build a fastai DataBunch with pad/crop and horizontal-flip augmentations
path = untar_data(URLs.CIFAR)
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512).normalize(cifar_stats)

# TensorFlow optimizer and loss function
opt_fn = tf.train.AdamOptimizer
loss_fn = tf.losses.sparse_softmax_cross_entropy

# Accuracy metric for integer (sparse) class labels
def categorical_accuracy(y_pred, y_true):
    return tf.keras.backend.mean(tf.keras.backend.equal(y_true, tf.keras.backend.argmax(y_pred, axis=-1)))
metrics = [categorical_accuracy]
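
Additional metrics can be added with the same (y_pred, y_true) signature. For instance, a top-3 accuracy metric could look roughly like the sketch below; it is not part of the original example and assumes integer class labels, using tf.keras.backend.in_top_k.

# Hypothetical extra metric: fraction of samples whose true class is in the top 3 predictions
def top3_accuracy(y_pred, y_true):
    in_top3 = tf.keras.backend.in_top_k(y_pred, tf.keras.backend.cast(y_true, 'int32'), 3)
    return tf.keras.backend.mean(tf.keras.backend.cast(in_top3, 'float32'))

metrics = [categorical_accuracy, top3_accuracy]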

Using tf.keras.Model

class Simple_CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Three stride-2 conv/batchnorm blocks; the image data format is
        # channels first (see Features), so BatchNormalization uses axis=1
        self.conv1 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv2 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv3 = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=1)

    def call(self, xb):
        xb = tf.nn.relu(self.bn1(self.conv1(xb)))
        xb = tf.nn.relu(self.bn2(self.conv2(xb)))
        xb = tf.nn.relu(self.bn3(self.conv3(xb)))
        # Average-pool the remaining 4x4 spatial grid and flatten to (batch, 10) logits
        xb = tf.nn.pool(xb, (4,4), 'AVG', 'VALID', data_format="NCHW")
        xb = tf.reshape(xb, (-1, 10))
        return xb

model = Simple_CNN()

Using the Keras functional API

# Channels-first input (see Features): CIFAR images are 3 x 32 x 32
inputs = tf.keras.layers.Input(shape=(3,32,32))
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(inputs)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), padding='same')(x)
x = tf.keras.layers.Reshape((10,))(x)
# Output raw logits: tf.losses.sparse_softmax_cross_entropy applies softmax internally
predictions = tf.keras.layers.Dense(10)(x)
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)

Training

Create a TfLearner object.

learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics, true_wd=True, bn_wd=True, wd=defaults.wd, train_bn=True)
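
The keyword arguments map onto the weight-decay options in the feature list. For example, to use plain L2 regularization instead of true weight decay and to exclude batchnorm parameters from decay, something like the following should work (the wd value here is illustrative, not from the original example):

learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics,
                  true_wd=False, bn_wd=False, wd=1e-4, train_bn=True)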

Run the learning rate finder and plot the result.

learn.lr_find()
learn.recorder.plot()

Train the model for 3 epochs with a learning rate of 3e-3 and weight decay of 0.4.

learn.fit(3, lr=3e-3, wd=0.4)

Fit the model with the 1cycle policy for a cycle length of 10 epochs, using discriminative learning rates.

learn.fit_one_cycle(10, max_lr=slice(6e-3, 3e-3))

Freeze the model's layer groups, unfreeze them, or freeze everything up to a given layer group.

learn.freeze()
learn.unfreeze()
learn.freeze_to(-1)
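
These calls combine into the usual fastai fine-tuning pattern, sketched below (epoch counts and learning rates are illustrative, not taken from the original example):

learn.freeze()                                      # train only the last layer group first
learn.fit_one_cycle(3, max_lr=3e-3)
learn.unfreeze()                                    # then fine-tune all layer groups
learn.fit_one_cycle(3, max_lr=slice(1e-4, 1e-3))    # discriminative learning rates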

Save and load model weights.

learn.save('cnn-1')
learn.load('cnn-1')

Metrics

Plot learning rate and momentum schedules.

learn.recorder.plot_lr(show_moms=True)

Plot train and validation losses.

learn.recorder.plot_losses()

Plot metrics.

learn.recorder.plot_metrics()

License

Apache License 2.0