tensorflow / lattice

Lattice methods in TensorFlow

monotonic calibration not working.

vishnuapp opened this issue

I am trying to train a calibration for some signals to incorporate into TF Ranking.

The relevant code for the calibration is

```python
import tensorflow as tf
import tensorflow_lattice as tfl

num_keypoints = 26
kp_inits = tfl.uniform_keypoints_for_signal(
    num_keypoints=num_keypoints,
    input_min=0.0,
    input_max=1.0,
    output_min=0.0,
    output_max=1.0)

# Define the input layer: flatten the six picked features and concatenate them.
picked_input = [
    tf.layers.flatten(group_features[name])
    for name in ['36', '32', '35', '33', '38', '39']
]

input_layer = tf.concat(picked_input, 1)
cur_layer = tfl.calibration_layer(
    input_layer,
    num_keypoints=num_keypoints,
    keypoints_initializers=kp_inits,
    bound=True,
    monotonic=[1 for _ in range(6)],
    name="calibration")

# Only the first element of the returned tuple (the calibrated tensor) is used here.
logits = tf.layers.dense(cur_layer[0], units=1, name='linear_layer', activation="elu")
```

The learned model isn't monotonic. Here are some of the calibration tensors it has learned:

tensor_name: group_score/pwl_calibration/signal_2_bound_max
1.0

tensor_name: group_score/pwl_calibration/signal_2_bound_min
0.0

tensor_name: group_score/pwl_calibration/signal_2_keypoints_inputs
[0. 0.04 0.08 0.12 0.16 0.19999999 0.24 0.28 0.32 0.35999998 0.39999998 0.44 0.48 0.52 0.56 0.59999996 0.64 0.68 0.71999997 0.76 0.79999995 0.84 0.88 0.91999996 0.96 1. ]

tensor_name: group_score/pwl_calibration/signal_2_keypoints_inputs/Adagrad
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]

tensor_name: group_score/pwl_calibration/signal_2_keypoints_outputs
[ 0.5595347 0.00848915 -0.02862659 0.44848698 0.3586025 0.40749145 0.35288998 0.38407487 0.38621387 0.47819927 0.6856117 0.60562074 0.59473854 0.5449814 0.43999994 0.61086124 0.72133946 0.64237064 0.66826046 0.7117335 0.6590987 0.662649 0.5869861 0.87017834 0.7034538 1.2272371 ]

tensor_name: group_score/pwl_calibration/signal_2_keypoints_outputs/Adagrad
[4.567583 0.34649372 0.2375099 0.2630496 0.22509426 0.19528154 0.1826403 0.19447225 0.1917207 0.21152268 0.17799918 0.18089467 0.2096777 0.18614963 0.17668937 0.1913786 0.23144016 0.23107207 0.2278506 0.21568052 0.26991028 0.24701497 0.287972 0.36811396 0.62489855 2.2491465 ]

The bounds aren't respected either.
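Both observations can be checked directly against the dumped `signal_2_keypoints_outputs` values. The snippet below is a minimal sketch in plain Python; the helper names `is_non_decreasing` and `out_of_bounds` are made up for illustration and are not part of TensorFlow Lattice.

```python
# Values copied from the signal_2_keypoints_outputs dump above.
keypoints_outputs = [
    0.5595347, 0.00848915, -0.02862659, 0.44848698, 0.3586025, 0.40749145,
    0.35288998, 0.38407487, 0.38621387, 0.47819927, 0.6856117, 0.60562074,
    0.59473854, 0.5449814, 0.43999994, 0.61086124, 0.72133946, 0.64237064,
    0.66826046, 0.7117335, 0.6590987, 0.662649, 0.5869861, 0.87017834,
    0.7034538, 1.2272371,
]


def is_non_decreasing(values):
    """Return True if every value is >= the one before it."""
    return all(a <= b for a, b in zip(values, values[1:]))


def out_of_bounds(values, lower, upper):
    """Return the values that fall outside the closed interval [lower, upper]."""
    return [v for v in values if v < lower or v > upper]


print(is_non_decreasing(keypoints_outputs))        # False
print(out_of_bounds(keypoints_outputs, 0.0, 1.0))  # [-0.02862659, 1.2272371]
```

The second keypoint output is already lower than the first, and two outputs lie outside the requested range of 0.0 to 1.0.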

It looks like you have not applied the projections. Note that calibration_layer returns:

A tuple of:
* calibrated tensor of shape [batch_size, ...], the same shape as
  uncalibrated.
* list of projection ops, that must be applied at each step (or every so
  many steps) to project the model to a feasible space: used for bounding
  the outputs or for imposing monotonicity. Empty if none are requested.
* None or tensor with regularization loss.

So you will need to apply the projection ops that are returned after each batch update. If you are writing your own loop, you can just add a session.run to apply the projection ops. With estimators, this can be done with a SessionRunHook. See the Base estimator for an example:
https://github.com/tensorflow/lattice/blob/master/tensorflow_lattice/python/estimators/base.py
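As a rough illustration of the hook approach, the sketch below runs a list of ops after every `session.run` call. It is not the code from the linked base estimator. `ProjectionHook` and `run_demo` are hypothetical names, the clipping op is only a stand-in for the projection ops that `calibration_layer` returns as the second element of its tuple, and `tf.compat.v1` is assumed so that the graph-mode APIs are also available on TensorFlow 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1


class ProjectionHook(tf1.train.SessionRunHook):
    """Hypothetical hook: runs the given projection ops after each run call."""

    def __init__(self, projection_ops):
        self._projection_ops = list(projection_ops)

    def after_run(self, run_context, run_values):
        # run_context.session is the underlying session, so this call does
        # not trigger the hooks again.
        if self._projection_ops:
            run_context.session.run(self._projection_ops)


def run_demo():
    """Toy graph: a fake training step followed by a stand-in projection."""
    with tf.Graph().as_default():
        weights = tf1.get_variable(
            "weights", initializer=tf.constant([0.5, 0.5]))
        # Stand-in for a training step that pushes the weights outside [0, 1].
        train_op = tf1.assign_add(weights, [1.0, -1.0])
        # Stand-in for the projection ops returned by calibration_layer
        # (assumption: here we simply clip the weights back into [0, 1]).
        projection_ops = [
            tf1.assign(weights, tf.clip_by_value(weights, 0.0, 1.0))
        ]
        hook = ProjectionHook(projection_ops)
        with tf1.train.MonitoredSession(hooks=[hook]) as session:
            session.run(train_op)
            return session.run(weights).tolist()


print(run_demo())
```

In an Estimator `model_fn`, a hook like this would typically be passed through the `training_hooks` argument of `tf.estimator.EstimatorSpec`, with `projection_ops` taken from `cur_layer[1]` in the code from the question.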

That worked well. Since I was using an estimator, I set up a SessionRunHook, and the model now obeys the constraints.
I am still having some trouble using the projection ops when I save and restore a model during checkpointing, but I think I'll work through it.

Thanks for the help.

On a related note, is there a reason the tensor names don't include any part of the name passed to calibration_layer? That would make it difficult to make multiple calls without name collisions.

You can always use tf.name_scope for that. In any case, TF will add a suffix to avoid collisions if you recreate the layer in the same scope. Marking as closed.