kyegomez / zeta

Build high-performance AI models with modular building blocks

Home Page: https://zeta.apac.ai

[BUG] utils/main/pad_at_dim - recommend refactoring to use torch.nn.functional.pad

evelynmitchell opened this issue

The original function as written:

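import torch.nn.functional as F

def pad_at_dim(t, pad, dim=-1, value=0.0):
    # F.pad consumes (before, after) pairs starting from the LAST dimension,
    # so count how many dimensions sit to the right of `dim`
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right  # no padding on those trailing dims
    return F.pad(t, (*zeros, *pad), value=value)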
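For reference, here is what the function computes: F.pad reads its pad tuple as (before, after) pairs starting from the last dimension, so pad_at_dim prepends one (0, 0) pair for every dimension to the right of dim. The sketch below assumes only that PyTorch is installed and re-declares the quoted function so it runs on its own; the tensor shapes and pad amounts are illustrative choices, not taken from the zeta tests.

```python
import torch
import torch.nn.functional as F


def pad_at_dim(t, pad, dim=-1, value=0.0):
    # Same logic as the function quoted in this issue
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


t = torch.ones(2, 3, 4)

# Pad dim 1 (size 3) with 1 element before and 2 after.
# dims_from_right = 3 - 1 - 1 = 1, so the tuple handed to F.pad is
# (0, 0, 1, 2): nothing on the last dim, then (1, 2) on dim 1.
out = pad_at_dim(t, (1, 2), dim=1)
manual = F.pad(t, (0, 0, 1, 2))

print(out.shape)                 # torch.Size([2, 6, 4])
print(torch.equal(out, manual))  # True

# Negative dims resolve the same way: dim=-2 is dim 1 for a 3-D tensor.
print(torch.equal(pad_at_dim(t, (1, 2), dim=-2), out))  # True
```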

It assumes simple inputs (for example, a valid dim and a single (before, after) pair) and does no validation of its own. The PyTorch and TensorFlow implementations:

torch.nn.functional.pad (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html)
https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/array_ops.py#L3452-L3508

have more complex and correct behavior.
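One concrete gap in the simple version: a positive dim greater than or equal to t.ndim makes dims_from_right negative, (0, 0) multiplied by a negative number is an empty tuple, and the padding silently lands on the last dimension. Below is a minimal sketch of a refactor that still delegates to torch.nn.functional.pad but validates its inputs first. The name pad_at_dim_checked and the choice to require exactly one (before, after) pair are assumptions for illustration, not zeta's API.

```python
import torch
import torch.nn.functional as F


def pad_at_dim(t, pad, dim=-1, value=0.0):
    # The function as quoted in this issue
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


def pad_at_dim_checked(t, pad, dim=-1, value=0.0):
    """Hypothetical stricter variant: constant-pads exactly one dimension of t."""
    pad = tuple(pad)
    if len(pad) != 2:
        raise ValueError(f"pad must be a (before, after) pair, got {pad!r}")
    if not -t.ndim <= dim < t.ndim:
        raise IndexError(f"dim {dim} out of range for a {t.ndim}-D tensor")
    dim = dim % t.ndim                       # map negative dims to 0..ndim-1
    dims_from_right = t.ndim - dim - 1       # F.pad's pairs start at the last dim
    full_pad = (0, 0) * dims_from_right + pad
    return F.pad(t, full_pad, mode="constant", value=value)


t = torch.zeros(2, 3)

# dim=2 does not exist on a 2-D tensor, but the original silently pads dim 1:
# dims_from_right is -1, and (0, 0) * -1 == ().
print(pad_at_dim(t, (1, 1), dim=2).shape)          # torch.Size([2, 5])

# The checked variant rejects it instead.
try:
    pad_at_dim_checked(t, (1, 1), dim=2)
except IndexError as err:
    print(err)                                      # dim 2 out of range for a 2-D tensor

print(pad_at_dim_checked(t, (1, 1), dim=0).shape)  # torch.Size([4, 3])
```

Whether to also accept an int (symmetric) pad or to expose F.pad's other modes ('reflect', 'replicate', 'circular') is left open here; those modes only support padding the last one to three dimensions, depending on the input's rank, so they would need their own checks.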

I noticed this because none of the tests in test_pad_at_dim.py pass.
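The contents of test_pad_at_dim.py are not reproduced in this issue, so the following is a hypothetical, pytest-style set of checks for the basic contract only: pad a single dimension and fill with value. It inlines a copy of the quoted function so it runs on its own. Checks like these can help separate failures in that basic contract from failures on other inputs.

```python
import torch
import torch.nn.functional as F


def pad_at_dim(t, pad, dim=-1, value=0.0):
    # Copy of the function under test, inlined so this file is self-contained
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


def test_pads_last_dim_by_default():
    out = pad_at_dim(torch.ones(2, 3), (1, 2))
    assert out.shape == (2, 6)
    # Padded positions hold the fill value; original values are untouched
    assert torch.equal(out[:, 0], torch.zeros(2))
    assert torch.equal(out[:, 1:4], torch.ones(2, 3))


def test_positive_and_negative_dim_agree():
    t = torch.randn(2, 3, 4)
    assert torch.equal(pad_at_dim(t, (1, 1), dim=0), pad_at_dim(t, (1, 1), dim=-3))


def test_custom_fill_value():
    out = pad_at_dim(torch.zeros(1, 2), (0, 1), dim=0, value=5.0)
    assert out.shape == (2, 2)
    assert torch.equal(out[1], torch.full((2,), 5.0))


if __name__ == "__main__":
    test_pads_last_dim_by_default()
    test_positive_and_negative_dim_agree()
    test_custom_fill_value()
    print("all pad_at_dim checks passed")
```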

pad_at_dim is used in:

playground/models/stacked_mm_bitnet
nn/biases/alibi.py
nn/modules/shift_tokens.py
zeta/structs/transformer.py
