OverLordGoldDragon / keras-adamw

Keras/TF implementation of AdamW, SGDW, NadamW, Warm Restarts, and Learning Rate multipliers

Actual lr seems fixed during training

eypros opened this issue · comments

I am a bit confused about the optimizer's actual lr at each batch.

I have noticed that there is a (now closed) issue regarding Usage & concept questions where you refer to the actual lr (learning rate) being lr*eta_t.

But if I use your example as a basis and plot the lr at each batch, there does not appear to be any fluctuation in the actual lr, regardless of the values eta_t is assigned.

import os
os.environ["TF_KERAS"] = '1'
os.environ["TF_EAGER"] = '0'

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l1, l2, l1_l2

import numpy as np
import matplotlib.pyplot as plt

from keras_adamw import AdamW
from keras_adamw.utils import K_eval

USE_CPU = True

if USE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

ipt = Input(shape=(120, 4))
x = LSTM(60, activation='relu', name='lstm_1',
         kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)

lr_multipliers = {'lstm_1': 0.5}

optimizer = AdamW(lr=1e-4, model=model, lr_multipliers=lr_multipliers,
                  use_cosine_annealing=True, total_iterations=24)
model.compile(optimizer, loss='binary_crossentropy')

eta_history = []
lr_history = []
for epoch in range(3):
    for iteration in range(24):
        x = np.random.rand(10, 120, 4)  # dummy data
        y = np.random.randint(0, 2, (10, 1))  # dummy labels
        loss = model.train_on_batch(x, y)
        eta_t = K_eval(model.optimizer.eta_t, K)
        eta_history.append(eta_t)
        t_cur = K_eval(model.optimizer.t_cur, K)
        lr = K_eval(model.optimizer.lr, K)  # K.eval(model.optimizer.lr)
        lr_history.append(lr)
        eta_max = K_eval(model.optimizer.eta_max, K)
        eta_min = K_eval(model.optimizer.eta_min, K)

        print('Iter {} t_cur: {} - lr: {} - eta_max: {} - eta_min: {}'.format(iteration + 1, t_cur, lr, eta_max, eta_min))
        print("Iter {} loss: {} - eta_t: {}".format(iteration + 1, "%.3f" % loss, eta_t))
        if iteration == (24 - 2):
            K.set_value(model.optimizer.t_cur, -1)  # WARM RESTART
    print("EPOCH {} COMPLETED\n".format(epoch + 1))

plt.plot(eta_history, linewidth=2)
plt.xlim(0, len(eta_history))
plt.ylim(0, 1.05)
plt.ylabel('eta_t', weight='bold', fontsize=15)
plt.xlabel('Train iterations', weight='bold', fontsize=15)
plt.gcf().set_size_inches(10, 5)
plt.show()
plt.close()

plt.plot(lr_history, linewidth=2)
plt.xlim(0, len(lr_history))
plt.ylim(0.9*np.min(lr_history), 1.1*np.max(lr_history))
plt.ylabel('lr', weight='bold', fontsize=15)
plt.xlabel('Train iterations', weight='bold', fontsize=15)
plt.gcf().set_size_inches(10, 5)
plt.show()
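
Output (first epoch):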

Iter 1 t_cur: 1 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 1 loss: 0.691 - eta_t: 0.9953429698944092
Iter 2 t_cur: 2 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 2 loss: 0.694 - eta_t: 0.9814586639404297
Iter 3 t_cur: 3 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 3 loss: 0.704 - eta_t: 0.9586056470870972
Iter 4 t_cur: 4 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 4 loss: 0.689 - eta_t: 0.927209734916687
Iter 5 t_cur: 5 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 5 loss: 0.682 - eta_t: 0.8878556489944458
Iter 6 t_cur: 6 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 6 loss: 0.708 - eta_t: 0.8412765264511108
Iter 7 t_cur: 7 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 7 loss: 0.684 - eta_t: 0.788340151309967
Iter 8 t_cur: 8 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 8 loss: 0.691 - eta_t: 0.7300325036048889
Iter 9 t_cur: 9 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 9 loss: 0.701 - eta_t: 0.6674398183822632
Iter 10 t_cur: 10 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 10 loss: 0.690 - eta_t: 0.6017280220985413
Iter 11 t_cur: 11 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 11 loss: 0.699 - eta_t: 0.5341211557388306
Iter 12 t_cur: 12 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 12 loss: 0.699 - eta_t: 0.46587878465652466
Iter 13 t_cur: 13 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 13 loss: 0.687 - eta_t: 0.39827197790145874
Iter 14 t_cur: 14 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 14 loss: 0.713 - eta_t: 0.3325602114200592
Iter 15 t_cur: 15 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 15 loss: 0.709 - eta_t: 0.2699674367904663
Iter 16 t_cur: 16 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 16 loss: 0.688 - eta_t: 0.21165981888771057
Iter 17 t_cur: 17 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 17 loss: 0.692 - eta_t: 0.15872341394424438
Iter 18 t_cur: 18 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 18 loss: 0.687 - eta_t: 0.1121443510055542
Iter 19 t_cur: 19 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 19 loss: 0.684 - eta_t: 0.07279029488563538
Iter 20 t_cur: 20 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 20 loss: 0.693 - eta_t: 0.04139435291290283
Iter 21 t_cur: 21 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 21 loss: 0.699 - eta_t: 0.018541336059570312
Iter 22 t_cur: 22 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 22 loss: 0.699 - eta_t: 0.00465703010559082
Iter 23 t_cur: 23 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 23 loss: 0.678 - eta_t: 0.0
Iter 24 t_cur: 0 - lr: 9.999999747378752e-05 - eta_max: 1.0 - eta_min: 0.0
Iter 24 loss: 0.696 - eta_t: 1.0
EPOCH 1 COMPLETED

@eypros Thanks for the report.

This is intended behavior. The "actual" LR is, in fact, not lr; LR is scaled by the betas (as in regular Adam), then by eta_t. Unlike the tf.keras optimizers, the keras implementations do have an lr_t to track the true LR. It was a design decision to omit it from tf.keras due to performance concerns - but admittedly, it is a useful feature, and the performance impact might be negligible. I'll consider it for the next release.
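To make the scaling concrete: below is a minimal sketch of the effective step size, assuming the standard Adam bias correction; the cosine formula (with t_cur / (total_iterations - 1) as the fraction) reproduces the eta_t values in your log. The names are illustrative, not the library's internals.

import numpy as np

def cosine_eta_t(t_cur, total_iterations, eta_min=0.0, eta_max=1.0):
    # Cosine annealing (Loshchilov & Hutter): eta_t sweeps eta_max -> eta_min
    frac = t_cur / (total_iterations - 1)
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * frac))

def effective_step_size(lr, beta_1, beta_2, t, eta_t):
    # Bias-corrected Adam step size, then scaled by eta_t;
    # lr itself stays constant - only this product anneals
    lr_t = lr * np.sqrt(1 - beta_2 ** t) / (1 - beta_1 ** t)
    return eta_t * lr_t

for t_cur in (1, 12, 23):
    eta_t = cosine_eta_t(t_cur, total_iterations=24)
    print(t_cur, eta_t, effective_step_size(1e-4, 0.9, 0.999, t_cur, eta_t))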

You can verify that eta_t is effective with the code below. I'll pin this issue for now in case anyone else wonders; feel free to re-open if any further concerns arise (or just comment).

Code
import random
random.seed(0)
import numpy as np
np.random.seed(1)
import tensorflow as tf
tf.compat.v1.set_random_seed(2)  # graph-level seed
if tf.__version__[0] == '2':
    tf.random.set_seed(3)  # global seed
else:
    tf.set_random_seed(3)  # global seed

import os
os.environ['TF_KERAS'] = '1'

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from keras_adamw import AdamW

#%%##########################################################################
ipt = Input(batch_shape=(32, 4))
out = Dense(4)(ipt)
model = Model(ipt, out)
opt = AdamW(lr=1e-3, use_cosine_annealing=True, total_iterations=25)
model.compile(opt, 'mse')

x = y = np.random.randn(*model.input_shape)
K.set_value(opt.eta_t, 0)
K.set_value(opt.t_cur, opt.total_iterations - 2)

W_pre  = model.get_weights()
model.train_on_batch(x, y)
W_post = model.get_weights()

#%%##########################################################################
print("PRE-TRAIN:\n%s"  % W_pre)
print("POST_TRAIN:\n%s" % W_post)
print("DIFF:\n{}\n{}".format(W_post[0] - W_pre[0], W_post[1] - W_pre[1]))

P.S., setting the TF_EAGER environment variable is redundant; it's used in testing to control eager vs. graph behavior in the tests directory, but keras_adamw detects the execution mode automatically.

Actually... you'll see the bias weights do change. In fact, it'll always be the very last weight in the network. This is a legitimate bug, and I'll fix it soon (Issue here); in the meantime, you can apply the fix below in your local install:

Rearrange the code in _resource_apply_dense and _resource_apply_sparse as follows (keep var_update as-is; move the others below it):

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Learning rate multipliers & cosine annealing (now applied after var_update)
(iteration_done, t_cur_update, eta_t_update
 ) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if iteration_done and not self._init_notified:
    self._init_notified = True
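
The idea is that the updates to t_cur and eta_t should only take effect after the current variable's update has been assigned; otherwise the very last variable in the network computes its update against the already-advanced eta_t - hence the bias change above.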

Fixed in v1.32, and added lr_t. See updated example.py.
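For reference, a minimal sketch of reading it, assuming lr_t is exposed on the optimizer the same way lr and eta_t are (see example.py for the exact usage); model is the compiled model from the snippets above:

from tensorflow.keras import backend as K
from keras_adamw.utils import K_eval

# effective (bias-corrected, eta_t-scaled) LR for the last-run iteration
lr_t = K_eval(model.optimizer.lr_t, K)
print("effective lr:", lr_t)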

I will examine the changes you inserted, but as a first comment: I was setting TF_EAGER explicitly because in my case it's unset, and the code complains when it checks for the actual value.

@eypros That's strange - what's the "complaint", a warning? And which TF version?