The definition of bins in the Predicted Aligned Error (PAE) head may be wrong
xiergo opened this issue
I am confused about the definition of the bins in the Predicted Aligned Error (PAE) head. The breaks
are defined as [0, 0.5, 1, ..., 31]:
# self.config.max_error_bin=31, self.config.num_bins=64
breaks = jnp.linspace(
    0., self.config.max_error_bin, self.config.num_bins - 1)
and the centers are [0.25, 0.75, ..., 30.75, 31.25, 31.75], according to:
def _calculate_bin_centers(breaks: np.ndarray):
  """Gets the bin centers from the bin edges.
  Args:
    breaks: [num_bins - 1] the error bin edges.
  Returns:
    bin_centers: [num_bins] the error bin centers.
  """
  step = (breaks[1] - breaks[0])
  # Add half-step to get the center
  bin_centers = breaks + step / 2
  # Add a catch-all bin at the end.
  bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
                               axis=0)
  return bin_centers
Then the 64 bins are [0, 0.5], [0.5, 1], ..., [31, 31.5], [31.5, +inf].
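To double-check these numbers, here is a minimal NumPy sketch of the two steps above. NumPy stands in for `jnp`, the constants are the config values quoted in this issue, and the names `make_breaks` and `calculate_bin_centers` are my own illustrative names, not AlphaFold's.

```python
import numpy as np

# Config values as quoted in this issue (assumed, not read from the real config).
MAX_ERROR_BIN = 31.0
NUM_BINS = 64


def make_breaks(max_error_bin: float = MAX_ERROR_BIN,
                num_bins: int = NUM_BINS) -> np.ndarray:
    """Returns the num_bins - 1 bin edges, mirroring the linspace call above."""
    return np.linspace(0.0, max_error_bin, num_bins - 1)


def calculate_bin_centers(breaks: np.ndarray) -> np.ndarray:
    """Same arithmetic as the quoted _calculate_bin_centers."""
    step = breaks[1] - breaks[0]
    centers = breaks + step / 2  # shift each edge by half a step
    # Append one extra center for the catch-all bin.
    return np.concatenate([centers, [centers[-1] + step]], axis=0)


breaks = make_breaks()
centers = calculate_bin_centers(breaks)

print(len(breaks), breaks[:3].tolist(), breaks[-1])  # 63 [0.0, 0.5, 1.0] 31.0
print(len(centers), centers[-3:].tolist())           # 64 [30.75, 31.25, 31.75]
```

With 63 edges spaced 0.5 apart, the first 63 centers sit in the middle of [0, 0.5], ..., [31, 31.5], and the last center belongs to the catch-all bin.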
But the bins defined in the PAE loss are [-inf, 0], [0, 0.5], ..., [31, +inf], which are shifted left by one bin. This follows from the definition in alphafold/alphafold/model/modules.py, line 1200:
sq_breaks = jnp.square(breaks)  # breaks = [0, 0.5, ..., 31]
true_bins = jnp.sum((
    error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)
errors = softmax_cross_entropy(
    labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)
For example, error_dist=0.75 should fall into the second bin, [0.5, 1]. But (0.75 > breaks).sum()
is 2, so the one-hot label is [0, 0, 1, 0, ..., 0], with the third entry set to 1, which is incorrect.
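The mismatch can be reproduced with a small NumPy sketch. It assumes the breaks quoted above and uses NumPy in place of `jnp`. `loss_bin_index` mirrors the `true_bins` computation from the loss, while `center_bin_index` is a hypothetical helper (not part of AlphaFold) that picks the bin whose center is closest to the error.

```python
import numpy as np

# Breaks as quoted in this issue: [0, 0.5, ..., 31], 63 values.
breaks = np.linspace(0.0, 31.0, 63)
step = breaks[1] - breaks[0]
# Bin centers built the same way as in the quoted _calculate_bin_centers.
centers = np.concatenate([breaks + step / 2, [breaks[-1] + 1.5 * step]])


def loss_bin_index(error_dist: np.ndarray) -> np.ndarray:
    """Label index as in the loss: count of squared breaks below the squared error."""
    error_dist2 = np.square(error_dist)
    return np.sum(error_dist2[..., None] > np.square(breaks), axis=-1)


def center_bin_index(error_dist: np.ndarray) -> np.ndarray:
    """Hypothetical reference: index of the nearest bin center."""
    return np.argmin(np.abs(error_dist[..., None] - centers), axis=-1)


errors = np.array([0.25, 0.75, 10.1, 40.0])
print(loss_bin_index(errors).tolist())    # [1, 2, 21, 63]
print(center_bin_index(errors).tolist())  # [0, 1, 20, 63]
```

Under these assumptions the loss label is one index higher than the bin implied by the centers for every error inside the covered range. The two only agree once the error is large enough to land in the last bin.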