[BUG] utils/main/pad_at_dim - recommend refactoring to use torch.nn.functional.pad
evelynmitchell opened this issue
The original function as written:

```python
def pad_at_dim(t, pad, dim=-1, value=0.0):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)
```
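For reference, here is a minimal runnable sketch of what the helper appears intended to do for in-range `dim` values. The function body is copied from the issue; the toy shapes and usage lines are my own assumptions:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim=-1, value=0.0):
    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so one (0, 0) pair is prepended per dimension to the right of `dim`.
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)

x = torch.ones(2, 3)
print(pad_at_dim(x, (1, 2), dim=0).shape)   # torch.Size([5, 3])
print(pad_at_dim(x, (1, 1), dim=-1).shape)  # torch.Size([2, 5])
```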
This assumes overly simple behavior. The reference implementations have more complex, correct behavior:
- PyTorch: torch.nn.functional.pad (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html)
- TensorFlow: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/array_ops.py#L3452-L3508
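To illustrate the richer contract of the upstream op, a short sketch of `torch.nn.functional.pad` behaviors that a thin constant-mode wrapper does not surface (the example tensor and shapes are mine):

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(1, 6)

# Constant-mode padding, the only case the wrapper exercises.
print(F.pad(x, (2, 1), value=0.0).shape)         # torch.Size([1, 9])

# Negative padding crops elements instead of adding them.
print(F.pad(x, (-1, -1)).shape)                  # torch.Size([1, 4])

# Other modes reuse edge values rather than a constant.
print(F.pad(x, (1, 1), mode='replicate').shape)  # torch.Size([1, 8])
```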
I noticed this because none of the tests in test_pad_at_dim.py are passing.
pad_at_dim is used in:
- playground/models/stacked_mm_bitnet
- nn/biases/alibi.py
- nn/modules/shift_tokens.py
- zeta/structs/transformer.py