tensorflow / lattice

Lattice methods in TensorFlow

Calibrations are not Monotonic

PoorvaRane opened this issue · comments

I'm trying to train a calibration layer to incorporate into TF Ranking, similar to #35.

My code for the calibration looks like this:

    import numpy as np
    import tensorflow as tf
    import tensorflow_lattice as tfl

    # Keypoint configuration shared by all calibrators.
    kp_inits = {
        'num_keypoints': 101,
        'input_min':     0.0,
        'input_max':     25.0,
        'output_min':    0.0,
        'output_max':    1.0,
    }

    # `group_features` comes from the TF Ranking groupwise scoring function.
    sample_input = [
        tf.compat.v1.layers.flatten(group_features[name])
        for name in ['1', '2', '3', '4', '5', '6']
    ]
    sample_layer = tf.concat(sample_input, 1)

    # One PWL calibrator per concatenated feature column, all sharing the
    # same keypoints and output range.
    calib_layer = tfl.layers.PWLCalibration(
        input_keypoints=np.linspace(
            kp_inits['input_min'],
            kp_inits['input_max'],
            num=kp_inits['num_keypoints'],
            dtype=np.float32),
        units=len(sample_input),
        output_min=kp_inits['output_min'],
        output_max=kp_inits['output_max'],
        monotonicity='increasing',
        name='sample_calib')(sample_layer)

The calibration kernel learned by this model is shown below. The output bounds are respected, but the calibrations do not appear to be monotonic. Any insight into what could be missing?

tensor_name:  groupwise_dnn_v2/group_score/final_calib/pwl_calibration_kernel
array([[0.00226, 0.00571, 0.00000, 0.00020, 0.00000, 0.00122],
       [0.00093, 0.00571, 0.00000, 0.00030, 0.00151, 0.00084],
       [0.00060, 0.00571, 0.00479, 0.00057, 0.00214, 0.00158],
       [0.00058, 0.00571, 0.06334, 0.00070, 0.00306, 0.00708],
       [0.00049, 0.00571, 0.06504, 0.00077, 0.00456, 0.00930],
       [0.00034, 0.00571, 0.04417, 0.00113, 0.00586, 0.00959],
       [0.00027, 0.00571, 0.02631, 0.00113, 0.01959, 0.00967],
       [0.00022, 0.00571, 0.02858, 0.00125, 0.02807, 0.00976],
       [0.00015, 0.00571, 0.01595, 0.00147, 0.02679, 0.00968],
       [0.00006, 0.00571, 0.02027, 0.00158, 0.02347, 0.00969],
       [0.00006, 0.00571, 0.02155, 0.00156, 0.01759, 0.00968],
       [0.00007, 0.00571, 0.02454, 0.00170, 0.01365, 0.00969],
       [0.00027, 0.00571, 0.02354, 0.00218, 0.01264, 0.00967],
       [0.00072, 0.00571, 0.01969, 0.00236, 0.00865, 0.00966],
       [0.00072, 0.00571, 0.01615, 0.00239, 0.00707, 0.00962],
       [0.00073, 0.00571, 0.01533, 0.00250, 0.00666, 0.00966],
       [0.00074, 0.00569, 0.01530, 0.00243, 0.01226, 0.00971],
       [0.00075, 0.00562, 0.01565, 0.00255, 0.01212, 0.00972],
       [0.00087, 0.00562, 0.01165, 0.00242, 0.01296, 0.00973],
       [0.00077, 0.00562, 0.00930, 0.00264, 0.01398, 0.00981],
       [0.00065, 0.00562, 0.00851, 0.00263, 0.01098, 0.00998],
       [0.00055, 0.00562, 0.00652, 0.01464, 0.00634, 0.01010],
       [0.00418, 0.00562, 0.00747, 0.03795, 0.00606, 0.01015],
       [0.09591, 0.00562, 0.00930, 0.04783, 0.00609, 0.01021],
       [0.08237, 0.00562, 0.01206, 0.05157, 0.00618, 0.01027],
       [0.06946, 0.00562, 0.01499, 0.05142, 0.00673, 0.01040],
       [0.04949, 0.00563, 0.01555, 0.04344, 0.00697, 0.01057],
       [0.03420, 0.00563, 0.01226, 0.03337, 0.00713, 0.01063],
       [0.01487, 0.00563, 0.01086, 0.02088, 0.00765, 0.01063],
       [0.01598, 0.00563, 0.00994, 0.01747, 0.00851, 0.01065],
       [0.01073, 0.00560, 0.00936, 0.01679, 0.00775, 0.01066],
       [0.03031, 0.00569, 0.00805, 0.01175, 0.00889, 0.01059],
       [0.03343, 0.00575, 0.00763, 0.01121, 0.01002, 0.01051],
       [0.02606, 0.00581, 0.00820, 0.02078, 0.00908, 0.01043],
       [0.02077, 0.00583, 0.00856, 0.02880, 0.00955, 0.01044],
       [0.01382, 0.00591, 0.01037, 0.03544, 0.00865, 0.01046],
       [0.03073, 0.00594, 0.00828, 0.03041, 0.00840, 0.01050],
       [0.03391, 0.00594, 0.00746, 0.02282, 0.00790, 0.01044],
       [0.01764, 0.00594, 0.00724, 0.01575, 0.00704, 0.01045],
       [0.00958, 0.00594, 0.00875, 0.01957, 0.00712, 0.01046],
       [0.01312, 0.00594, 0.00840, 0.02152, 0.00719, 0.01043],
       [0.00948, 0.00594, 0.00937, 0.02334, 0.00870, 0.01038],
       [0.00919, 0.00594, 0.00861, 0.02835, 0.01373, 0.01038],
       [0.00858, 0.00594, 0.00837, 0.03308, 0.01630, 0.01039],
       [0.00821, 0.00594, 0.00813, 0.03074, 0.01830, 0.01035],
       [0.00769, 0.00594, 0.00778, 0.02386, 0.01803, 0.01038],
       [0.00692, 0.00594, 0.00723, 0.02606, 0.01497, 0.01042],
       [0.00617, 0.00594, 0.00752, 0.02047, 0.01534, 0.01046],
       [0.00491, 0.00594, 0.00549, 0.02386, 0.01614, 0.01054],
       [0.00356, 0.00594, 0.00430, 0.02269, 0.01515, 0.01060],
       [0.00272, 0.00594, 0.00469, 0.02590, 0.01450, 0.01066],
       [0.00229, 0.00594, 0.00458, 0.02557, 0.01211, 0.01068],
       [0.01078, 0.00594, 0.00434, 0.01946, 0.01062, 0.01071],
       [0.01128, 0.00594, 0.00421, 0.02055, 0.00987, 0.01077],
       [0.01048, 0.00594, 0.00449, 0.01359, 0.01097, 0.01075],
       [0.00929, 0.00594, 0.00493, 0.00661, 0.01122, 0.01074],
       [0.00908, 0.00594, 0.00520, 0.00450, 0.01201, 0.01077],
       [0.00915, 0.00594, 0.00539, 0.00457, 0.01152, 0.01070],
       [0.00865, 0.00594, 0.00550, 0.00395, 0.01111, 0.01106],
       [0.00780, 0.00594, 0.00559, 0.00321, 0.01133, 0.01122],
       [0.00660, 0.00594, 0.00597, 0.00328, 0.01170, 0.01175],
       [0.00500, 0.00594, 0.00572, 0.00270, 0.01158, 0.01207],
       [0.00403, 0.00594, 0.00572, 0.00256, 0.01027, 0.01225],
       [0.00307, 0.00594, 0.00582, 0.00232, 0.00874, 0.01247],
       [0.00296, 0.00594, 0.00540, 0.00255, 0.00742, 0.01260],
       [0.00251, 0.00594, 0.00485, 0.00243, 0.00769, 0.01298],
       [0.00195, 0.00594, 0.00418, 0.00194, 0.00740, 0.01293],
       [0.00176, 0.00594, 0.00366, 0.00158, 0.00730, 0.01336],
       [0.00180, 0.00594, 0.00359, 0.00158, 0.00700, 0.01317],
       [0.00117, 0.00594, 0.00410, 0.00157, 0.00688, 0.01324],
       [0.00110, 0.00594, 0.00426, 0.00138, 0.00663, 0.01354],
       [0.00103, 0.00594, 0.00443, 0.00102, 0.00638, 0.01395],
       [0.00091, 0.00594, 0.00450, 0.00076, 0.00621, 0.01255],
       [0.00082, 0.00594, 0.00471, 0.00051, 0.00602, 0.01161],
       [0.00078, 0.00594, 0.00476, 0.00048, 0.00611, 0.01280],
       [0.00064, 0.00594, 0.00502, 0.00042, 0.00627, 0.01283],
       [0.00061, 0.00594, 0.00540, 0.00068, 0.00639, 0.01050],
       [0.00055, 0.00594, 0.00563, 0.00036, 0.00642, 0.00967],
       [0.00049, 0.00594, 0.00592, 0.00014, 0.00462, 0.00966],
       [0.00041, 0.00594, 0.00623, 0.00008, 0.00465, 0.00946],
       [0.00027, 0.00594, 0.00633, 0.00000, 0.00444, 0.00900],
       [0.00017, 0.00594, 0.00646, 0.00000, 0.00494, 0.00867],
       [0.00012, 0.00594, 0.00646, 0.00000, 0.00551, 0.00840],
       [0.00006, 0.00594, 0.00628, 0.00000, 0.00601, 0.00850],
       [0.00000, 0.00594, 0.00638, 0.00000, 0.00579, 0.00836],
       [0.00000, 0.00594, 0.00674, 0.00000, 0.00466, 0.00855],
       [0.00000, 0.00594, 0.00692, 0.00000, 0.00417, 0.00834],
       [0.00000, 0.00594, 0.00722, 0.00000, 0.00405, 0.00664],
       [0.00000, 0.00594, 0.00784, 0.00000, 0.00371, 0.00354],
       [0.00000, 0.00594, 0.00866, 0.00000, 0.00367, 0.00315],
       [0.00000, 0.00594, 0.00929, 0.00000, 0.00349, 0.00210],
       [0.00000, 0.00594, 0.01070, 0.00000, 0.00290, 0.00230],
       [0.00000, 0.00594, 0.01128, 0.00000, 0.00200, 0.00137],
       [0.00000, 0.00594, 0.01080, 0.00000, 0.00205, 0.00143],
       [0.00000, 0.00594, 0.01057, 0.00000, 0.00189, 0.00127],
       [0.00000, 0.00594, 0.00886, 0.00000, 0.00189, 0.00000],
       [0.00000, 0.00594, 0.00513, 0.00000, 0.00190, 0.00000],
       [0.00000, 0.00594, 0.00323, 0.00000, 0.00092, 0.00000],
       [0.00000, 0.00594, 0.00260, 0.00000, 0.00074, 0.00038],
       [0.00000, 0.00594, 0.00145, 0.00000, 0.00047, 0.00065],
       [0.00000, 0.00594, 0.00030, 0.00000, 0.00000, 0.00032]],
      dtype=float32)

The kernel for PWLCalibration encodes the deltas between consecutive keypoint outputs (which are all non-negative for a monotonically increasing layer), not the outputs themselves. To recover the actual calibration curve you need to take a cumulative sum over the kernel, or simply use the keypoints_outputs() method on the layer: https://www.tensorflow.org/lattice/api_docs/python/tfl/layers/PWLCalibration#keypoints_outputs
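
To make that concrete, here is a small sketch (the checkpoint path is hypothetical; the tensor name is the one dumped above) that recovers and checks the per-keypoint outputs from the raw kernel:

    import numpy as np
    import tensorflow as tf

    # Read the raw kernel from an estimator-style checkpoint.
    reader = tf.train.load_checkpoint('/path/to/model_dir')  # hypothetical path
    kernel = reader.get_tensor(
        'groupwise_dnn_v2/group_score/final_calib/pwl_calibration_kernel')

    # Row 0 is the output at the first keypoint; each later row is a delta to the
    # next keypoint's output, so the cumulative sum gives the calibration curve.
    keypoint_outputs = np.cumsum(kernel, axis=0)  # shape (num_keypoints, units)

    # With monotonicity='increasing' every column should be non-decreasing
    # (up to numerical noise).
    print(np.all(np.diff(keypoint_outputs, axis=0) >= -1e-6))

If you still have the layer object in Python, calling its keypoints_outputs() method returns these per-keypoint outputs directly, which is usually the most convenient thing to plot.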