jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code

Home Page: https://eagerpy.jonasrauber.de

Equivalent of `np.diag`?

Holt59 opened this issue

Is there an equivalent of `np.diag` in eagerpy? If not, what would be the proper way to create a diagonal matrix from a vector?

Not yet, but that's easy to add (PyTorch, TensorFlow, and JAX all seem to support it). Do you want to make a PR?
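For reference, the native calls are indeed very similar across the four frameworks; a quick sketch (parameter names to the best of my knowledge — note that `tf.linalg.diag` only constructs a matrix, while extracting a diagonal is `tf.linalg.diag_part`):

```python
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp

# Build a matrix with [0, 1, 2, 3] on the diagonal shifted up by one.
np.diag(np.arange(4), k=1)                # NumPy
torch.diag(torch.arange(4), diagonal=1)   # PyTorch (offset is called `diagonal`)
tf.linalg.diag(tf.range(4), k=1)          # TensorFlow
jnp.diag(jnp.arange(4), k=1)              # JAX
```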

Thanks for the answer. I can try to make a PR, but I've never used JAX.

Side question: is there a way to extend `eagerpy.*` using, e.g., entry points or some other Python mechanism? I currently have an `ep_utils` package with custom functions (e.g. `diag`) that I import, but it would be nice to just install an "extension" package and then be able to do `eagerpy.diag`.

In principle, extending packages with other packages is possible, but when I did that for the development version of Foolbox Native, it was quite a pain (it required certain pip versions, not everything worked, etc.).
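As a low-tech alternative to entry points, you could simply attach your helper to the eagerpy module at import time. A minimal sketch, assuming your hypothetical `ep_utils` package and a NumPy round-trip as a stopgap implementation:

```python
# ep_utils/__init__.py -- hypothetical extension package (a sketch,
# not an official eagerpy extension mechanism)
import numpy as np
import eagerpy as ep
from eagerpy import Tensor


def diag(t: Tensor, k: int = 0) -> Tensor:
    # Round-trip through NumPy, then convert back to t's framework.
    # ep.from_numpy builds a tensor of the same framework as its first
    # argument. Note this detaches gradients, so it is only a stopgap
    # until a native ep.diag lands.
    return ep.from_numpy(t, np.diag(t.numpy(), k=k))


# Monkey-patch so callers can write ep.diag(...) after importing ep_utils:
ep.diag = diag
```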

So I started writing a PR for this, but I'm a bit confused by your test setup... Currently I have something like this:

```python
import eagerpy as ep
from eagerpy import Tensor

# compare_all is the decorator from eagerpy's test suite that runs the
# test body once per framework and compares each result against NumPy.


@compare_all
def test_diag_1(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 4).float32()
    return ep.diag(t)


@compare_all
def test_diag_2(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 4).float32()
    return ep.diag(t, k=2)


@compare_all
def test_diag_3(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 9).float32().reshape((3, 3))
    return ep.diag(t)


@compare_all
def test_diag_4(dummy: Tensor) -> Tensor:
    t = ep.arange(dummy, 9).float32().reshape((3, 3))
    return ep.diag(t, k=2)
```

But...

  1. If I understand `compare_all` correctly, it simply checks consistency between the frameworks, not the actual results. Should I write a separate (non-decorated) `test_diag`?

  2. When I run `pytest`, it skips everything... Any idea why? I just found out about `make test`...

  1. You are right: `compare_all` and the other decorators like it simply check consistency by comparing each framework's result to the NumPy implementation (see the sketch after this list). I consider that sufficient because the underlying frameworks are themselves responsible for correctness (EagerPy provides a consistent wrapper, not custom implementations). No need to write a non-decorated `test_diag`.

  2. You already found the solution. The reason is that we have to run the tests for each framework in a separate process, because the (GPU) frameworks sometimes break each other.
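For reference, here is a minimal sketch of how a consistency-checking decorator like `compare_all` could work. This is illustrative, not eagerpy's actual implementation, and it assumes the test runner supplies a framework-specific `dummy` tensor (e.g. via a pytest fixture selected with a backend option, which is also why plain `pytest` with no backend selected skips everything):

```python
import functools

import numpy as np
import eagerpy as ep
from eagerpy import Tensor


def compare_all(f):
    # Hypothetical sketch: run the test body once with the
    # framework-specific dummy tensor and once with a NumPy dummy,
    # then assert that both results agree elementwise.
    @functools.wraps(f)
    def wrapper(dummy: Tensor) -> None:
        framework_result = f(dummy)
        numpy_dummy = ep.astensor(np.zeros_like(dummy.numpy()))
        numpy_result = f(numpy_dummy)
        np.testing.assert_allclose(
            framework_result.numpy(), numpy_result.numpy(), rtol=1e-6
        )

    return wrapper
```

Comparing against NumPy keeps each test body framework-agnostic: the same function is executed twice, and only the numerical outputs are compared.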