rust-ndarray / ndarray

ndarray: an N-dimensional array with array views, multidimensional slicing, and efficient operations

Home Page: https://docs.rs/ndarray/

Reshape isn't producing same output as numpy?

Heidar-An opened this issue

Can't tell if I'm making a mistake somewhere, but take a look at this:
(Note: `k`/`a` in Rust and `K`/`A` in Python have the same values; I just didn't show that output here.)

```rust
println!("anchors before reshape {}", anchors);
let anchors = anchors.into_shape((k * a, 4)).unwrap();
println!("anchors after reshape {}", anchors);
```

Produces:

```
anchors before reshape [[[[-8, -8, 23, 23],
   [0, 0, 15, 15]],

  [[0, -8, 31, 23],
   [8, 0, 23, 15]],

  [[8, -8, 39, 23],
   [16, 0, 31, 15]],

  ...,

  [[752, -8, 783, 23],
   [760, 0, 775, 15]],

  [[760, -8, 791, 23],
   [768, 0, 783, 15]],

  [[768, -8, 799, 23],
   [776, 0, 791, 15]]],


 [[[-8, 0, 23, 31],
   [0, 8, 15, 23]],

  [[0, 0, 31, 31],
   [8, 8, 23, 23]],

  [[8, 0, 39, 31],
   [16, 8, 31, 23]],

  ...,

  [[752, 0, 783, 31],
   [760, 8, 775, 23]],

  [[760, 0, 791, 31],
   [768, 8, 783, 23]],

  [[768, 0, 799, 31],
   [776, 8, 791, 23]]],


 [[[-8, 8, 23, 39],
   [0, 16, 15, 31]],

  [[0, 8, 31, 39],
   [8, 16, 23, 31]],

  [[8, 8, 39, 39],
   [16, 16, 31, 31]],

  ...,

  [[752, 8, 783, 39],
   [760, 16, 775, 31]],

  [[760, 8, 791, 39],
   [768, 16, 783, 31]],

  [[768, 8, 799, 39],
   [776, 16, 791, 31]]],


 ...,


 [[[-8, 408, 23, 439],
   [0, 416, 15, 431]],

  [[0, 408, 31, 439],
   [8, 416, 23, 431]],

  [[8, 408, 39, 439],
   [16, 416, 31, 431]],

  ...,

  [[752, 408, 783, 439],
   [760, 416, 775, 431]],

  [[760, 408, 791, 439],
   [768, 416, 783, 431]],

  [[768, 408, 799, 439],
   [776, 416, 791, 431]]],


 [[[-8, 416, 23, 447],
   [0, 424, 15, 439]],

  [[0, 416, 31, 447],
   [8, 424, 23, 439]],

  [[8, 416, 39, 447],
   [16, 424, 31, 439]],

  ...,

  [[752, 416, 783, 447],
   [760, 424, 775, 439]],

  [[760, 416, 791, 447],
   [768, 424, 783, 439]],

  [[768, 416, 799, 447],
   [776, 424, 791, 439]]],


 [[[-8, 424, 23, 455],
   [0, 432, 15, 447]],

  [[0, 424, 31, 455],
   [8, 432, 23, 447]],

  [[8, 424, 39, 455],
   [16, 432, 31, 447]],

  ...,

  [[752, 424, 783, 455],
   [760, 432, 775, 447]],

  [[760, 424, 791, 455],
   [768, 432, 783, 447]],

  [[768, 424, 799, 455],
   [776, 432, 791, 447]]]]
```
-----------------------------------------------

```
anchors after reshape [[-8, -8, 23, 23],
 [-8, 0, 23, 31],
 [-8, 8, 23, 39],
 [-8, 16, 23, 47],
 [-8, 24, 23, 55],
 ...,
 [776, 400, 791, 415],
 [776, 408, 791, 423],
 [776, 416, 791, 431],
 [776, 424, 791, 439],
 [776, 432, 791, 447]]
```

Python version:

```python
print(f"anchors before reshape{anchors}")
anchors = anchors.reshape((K * A, 4))
print(f"anchors after reshape {anchors}")
```

Produces:

```
anchors before reshape[[[[ -8. -8. 23. 23.]
[ 0. 0. 15. 15.]]

[[ 0. -8. 31. 23.]
[ 8. 0. 23. 15.]]

[[ 8. -8. 39. 23.]
[ 16. 0. 31. 15.]]

...

[[752. -8. 783. 23.]
[760. 0. 775. 15.]]

[[760. -8. 791. 23.]
[768. 0. 783. 15.]]

[[768. -8. 799. 23.]
[776. 0. 791. 15.]]]

[[[ -8. 0. 23. 31.]
[ 0. 8. 15. 23.]]

[[ 0. 0. 31. 31.]
[ 8. 8. 23. 23.]]

[[ 8. 0. 39. 31.]
[ 16. 8. 31. 23.]]

...

[[752. 0. 783. 31.]
[760. 8. 775. 23.]]

[[760. 0. 791. 31.]
[768. 8. 783. 23.]]

[[768. 0. 799. 31.]
[776. 8. 791. 23.]]]

[[[ -8. 8. 23. 39.]
[ 0. 16. 15. 31.]]

[[ 0. 8. 31. 39.]
[ 8. 16. 23. 31.]]

[[ 8. 8. 39. 39.]
[ 16. 16. 31. 31.]]

...

[[752. 8. 783. 39.]
[760. 16. 775. 31.]]

[[760. 8. 791. 39.]
[768. 16. 783. 31.]]

[[768. 8. 799. 39.]
[776. 16. 791. 31.]]]

...

[[[ -8. 408. 23. 439.]
[ 0. 416. 15. 431.]]

[[ 0. 408. 31. 439.]
[ 8. 416. 23. 431.]]

[[ 8. 408. 39. 439.]
[ 16. 416. 31. 431.]]

...

[[752. 408. 783. 439.]
[760. 416. 775. 431.]]

[[760. 408. 791. 439.]
[768. 416. 783. 431.]]

[[768. 408. 799. 439.]
[776. 416. 791. 431.]]]

[[[ -8. 416. 23. 447.]
[ 0. 424. 15. 439.]]

[[ 0. 416. 31. 447.]
[ 8. 424. 23. 439.]]

[[ 8. 416. 39. 447.]
[ 16. 424. 31. 439.]]

...

[[752. 416. 783. 447.]
[760. 424. 775. 439.]]

[[760. 416. 791. 447.]
[768. 424. 783. 439.]]

[[768. 416. 799. 447.]
[776. 424. 791. 439.]]]

[[[ -8. 424. 23. 455.]
[ 0. 432. 15. 447.]]

[[ 0. 424. 31. 455.]
[ 8. 432. 23. 447.]]

[[ 8. 424. 39. 455.]
[ 16. 432. 31. 447.]]

...

[[752. 424. 783. 455.]
[760. 432. 775. 447.]]

[[760. 424. 791. 455.]
[768. 432. 783. 447.]]

[[768. 424. 799. 455.]
[776. 432. 791. 447.]]]]


anchors after reshape [[ -8. -8. 23. 23.]
[ 0. 0. 15. 15.]
[ 0. -8. 31. 23.]
...
[768. 432. 783. 447.]
[768. 424. 799. 455.]
[776. 432. 791. 447.]]

</details>

What is going on here? Am I doing something dumb? As you can see, the two results are no longer the same.
I also checked the `.sum()` of both arrays and the sums matched. (Yes, I know that's not a perfect equality check, but since every printed line is the same and the `.sum()` is the same too, the arrays are probably identical.)
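A sum is a single aggregate, so it cannot tell two arrays apart when they hold the same values in a different order. The minimal, dependency-free sketch below makes that concrete; `ORIGINAL`, `REORDERED` and `total` are hypothetical names, and the values are borrowed from the first two anchor rows printed above.

```rust
// Row-major contents of two hypothetical 2x4 arrays: same rows, different order.
const ORIGINAL: [i64; 8] = [-8, -8, 23, 23, 0, 0, 15, 15];
const REORDERED: [i64; 8] = [0, 0, 15, 15, -8, -8, 23, 23];

/// Sum of all elements; blind to the order in which they appear.
fn total(xs: &[i64]) -> i64 {
    xs.iter().sum()
}

fn main() {
    // Equal sums: the reordering is invisible to this check.
    println!("sums equal: {}", total(&ORIGINAL) == total(&REORDERED)); // sums equal: true
    // An elementwise comparison does notice the reordering.
    println!("elementwise equal: {}", ORIGINAL == REORDERED); // elementwise equal: false
}
```

For the arrays themselves, ndarray's `ArrayBase` implements `PartialEq` (shapes and all elements must match), so `a == b` performs the same elementwise check; NumPy offers `np.array_equal`.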

Are you sure both `anchors` arrays are in the same memory order (C or F)?

how do I check / set this?

ndarray is not a reimplementation of NumPy. Comparing the two is still useful, but a difference from NumPy is not, by itself, evidence of a bug.

`into_shape` has some wonky behaviour with respect to how it handles memory layout. The newer `to_shape()` fixes those problems, so you should use `to_shape()` if you can.

(Elaboration: `into_shape` doesn't do it *wrong*, but when it succeeds, it preserves the memory layout: C order in gives C order out, and F order in gives F order out. If you don't know which memory layout you have going in, `into_shape` can surprise you.)
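The difference between "reshape in logical order" and "keep the memory order" can be sketched without ndarray at all. This is a minimal illustration under stated assumptions, not ndarray's implementation: `Layout`, `offset`, `flatten_logical` and `flatten_keep_memory_order` are hypothetical helpers for a contiguous 2-D buffer. Flattening is simply a reshape to one dimension, so the same reasoning carries over to a reshape to `(k * a, 4)`.

```rust
/// How a logical (row, col) index maps to a position in a contiguous buffer.
#[derive(Clone, Copy)]
enum Layout {
    RowMajor,    // C order: the last index varies fastest in memory
    ColumnMajor, // F order: the first index varies fastest in memory
}

/// Buffer offset of logical element (i, j) in a `rows x cols` matrix.
fn offset(layout: Layout, (rows, cols): (usize, usize), (i, j): (usize, usize)) -> usize {
    match layout {
        Layout::RowMajor => i * cols + j,
        Layout::ColumnMajor => i + j * rows,
    }
}

/// Flatten in *logical* row-major order, whatever the storage layout.
/// This is the order NumPy's default `reshape` uses.
fn flatten_logical(data: &[i32], layout: Layout, shape: (usize, usize)) -> Vec<i32> {
    let (rows, cols) = shape;
    let mut out = Vec::with_capacity(rows * cols);
    for i in 0..rows {
        for j in 0..cols {
            out.push(data[offset(layout, shape, (i, j))]);
        }
    }
    out
}

/// Flatten by reusing the buffer as-is, i.e. keeping the memory order.
/// For a column-major input this is the "F in, F out" behaviour.
fn flatten_keep_memory_order(data: &[i32]) -> Vec<i32> {
    data.to_vec()
}

fn main() {
    // The logical matrix [[1, 2, 3], [4, 5, 6]] stored column-major.
    let shape = (2, 3);
    let f_data = [1, 4, 2, 5, 3, 6];
    println!("{:?}", flatten_logical(&f_data, Layout::ColumnMajor, shape)); // [1, 2, 3, 4, 5, 6]
    println!("{:?}", flatten_keep_memory_order(&f_data)); // [1, 4, 2, 5, 3, 6]

    // The same matrix stored row-major: both approaches agree.
    let c_data = [1, 2, 3, 4, 5, 6];
    println!("{:?}", flatten_logical(&c_data, Layout::RowMajor, shape)); // [1, 2, 3, 4, 5, 6]
    println!("{:?}", flatten_keep_memory_order(&c_data)); // [1, 2, 3, 4, 5, 6]
}
```

`flatten_logical` corresponds to what `to_shape` does with its default `Order::RowMajor` (and to NumPy's default `reshape`), while `flatten_keep_memory_order` mirrors the "F in gives F out" behaviour described for `into_shape`. In the issue, the rows after `into_shape` step along the first axis first (`[-8, -8, 23, 23]`, `[-8, 0, 23, 31]`, `[-8, 8, 23, 39]`, ...), which looks consistent with a column-major input; that is an inference from the printed output only.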

I've been planning to update `into_shape()` with the same ideas (sharing the implementation). I have the code lying around, but it was never finished. It would be good to get it in, but it didn't land together with `to_shape` because it would be a breaking change.

tl;dr: don't use `reshape`, don't use `into_shape`. Use `to_shape`, or get @bluss to fix the `into_shape` implementation.

@Heidar-An compare `.strides()` on the ndarray array with `.strides` on the array in Python. NB: ndarray expresses strides in units of elements, while NumPy uses bytes. ndarray also shows the layout of the array in the `{:?}` (Debug) formatter.
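Because of that unit difference, the same layout produces different-looking numbers in the two libraries. The small dependency-free sketch below (hypothetical helpers `c_strides`, `f_strides`, `to_byte_strides`) computes contiguous strides in elements and converts them to bytes. It uses `usize` for simplicity, whereas ndarray's `strides()` returns `&[isize]` because strides can be negative. The `f64` element type is an assumption: the NumPy output above prints values like `-8.`, which suggests a floating-point dtype, but the thread does not say which one.

```rust
/// Row-major (C order) strides for `shape`, counted in elements.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk from the second-to-last axis back to the first.
    for k in (0..shape.len().saturating_sub(1)).rev() {
        strides[k] = strides[k + 1] * shape[k + 1];
    }
    strides
}

/// Column-major (F order) strides for `shape`, counted in elements.
fn f_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for k in 1..shape.len() {
        strides[k] = strides[k - 1] * shape[k - 1];
    }
    strides
}

/// Element strides (ndarray's unit) -> byte strides (NumPy's unit).
fn to_byte_strides<T>(elem_strides: &[usize]) -> Vec<usize> {
    elem_strides.iter().map(|s| s * std::mem::size_of::<T>()).collect()
}

fn main() {
    let shape = [2, 3];
    let c = c_strides(&shape);
    let f = f_strides(&shape);
    // C order: [3, 1] elements, [24, 8] bytes for f64
    println!("C order: {:?} elements, {:?} bytes for f64", c, to_byte_strides::<f64>(&c));
    // F order: [1, 2] elements, [8, 16] bytes for f64
    println!("F order: {:?} elements, {:?} bytes for f64", f, to_byte_strides::<f64>(&f));
}
```

So for an `f64` array of shape `(2, 3)`, NumPy reporting `(24, 8)` and ndarray reporting `[3, 1]` both describe C order, while `(8, 16)` versus `[1, 2]` would indicate F order.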

To make it easy for us to read the bug report, could you fill in the shape of the arrays before and after reshape in both scenarios? 🙂

> how do I check / set this?

`arr.flags` in Python, `arr.is_standard_layout()` in Rust. The standard is C order in both NumPy and ndarray, so it's probably not the problem you have. Follow @bluss's suggestion.
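Both checks boil down to comparing the strides with the ones a contiguous C-order (or F-order) array of the same shape would have. NumPy exposes this as `arr.flags['C_CONTIGUOUS']` and `arr.flags['F_CONTIGUOUS']`, ndarray as `is_standard_layout()`. To set the layout, ndarray's `as_standard_layout()` returns a C-order version of an array (copying only if needed) and NumPy has `np.ascontiguousarray`. The sketch below is a simplified, dependency-free version of the check; `is_c_contiguous` and `is_f_contiguous` are hypothetical helpers that skip length-1 axes and ignore edge cases such as empty arrays.

```rust
/// Simplified C-contiguity check from shape and element strides.
/// Axes of length 1 are skipped: their stride never affects addressing.
fn is_c_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    let mut expected: isize = 1;
    // Row-major: the last axis must have stride 1, the one before it len(last), ...
    for (&len, &stride) in shape.iter().zip(strides).rev() {
        if len != 1 && stride != expected {
            return false;
        }
        expected *= len as isize;
    }
    true
}

/// Same idea for column-major: the first axis must have stride 1.
fn is_f_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    let mut expected: isize = 1;
    for (&len, &stride) in shape.iter().zip(strides) {
        if len != 1 && stride != expected {
            return false;
        }
        expected *= len as isize;
    }
    true
}

fn main() {
    let shape: [usize; 2] = [2, 3];
    let c_order: [isize; 2] = [3, 1];
    let f_order: [isize; 2] = [1, 2];
    // c_order: C=true F=false
    println!("c_order: C={} F={}", is_c_contiguous(&shape, &c_order), is_f_contiguous(&shape, &c_order));
    // f_order: C=false F=true
    println!("f_order: C={} F={}", is_c_contiguous(&shape, &f_order), is_f_contiguous(&shape, &f_order));
}
```

An array with a single non-trivial axis (for example shape `(1, 3)`) passes both checks, which matches NumPy reporting it as both C- and F-contiguous.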

Thanks so much, everyone (especially @bluss)! Using `to_shape` instead worked.
I was super surprised by how quickly you all responded. Love this crate and its contributors!