xl0 / lovely-tensors

Tensors, ready for human consumption

Home Page: https://xl0.github.io/lovely-tensors

Showing images with matplotlib backend

OscarPellicer opened this issue

I've just discovered this library today and I really like it! I use VS Code, and I wanted to use lovely-tensors' image plotting capabilities within a standard Python debug session. Since matplotlib works while debugging, at the moment I am defining this hacky function, ltplot, in every Python file where I want to do tensor debugging:

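import io
import lovely_tensors as lt
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
lt.monkey_patch()

def ltplot(img):
    'Plot an lt image object such as x.chans, x.rgb, x.plt'
    # Render the object to PNG bytes with the same hook Jupyter uses for display
    fp = io.BytesIO(img._repr_png_())
    # Decode the PNG back into an array that matplotlib can draw
    with fp:
        arr = mpimg.imread(fp, format='png')
    plt.imshow(arr)
    plt.show()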

And then I would do something like:

ltplot(x.chans)

As a suggestion, maybe we could have something like: lt.monkey_patch(backend='matplotlib'), so that it automatically defaults to plotting using matplotlib.

Thank you! 💕

I forgot that not everyone has converted their workflow to the blessed ways of nbdev. :) (And it totally works in VS Code, that's how I use it.)

I think it's a reasonable request. It could be an argument to monkey_patch, as you suggested, or, maybe better, I could rework (and rename) PRINT_OPTIONS to cover the image backend too. I should be able to get on it in around a week, but if you'd like to implement it, PRs are very welcome.

I didn't know about nbdev, but it looks very interesting. I will look into it, and I may convert too!

I second this feature; it would be nice to be able to do tensor.plt_show(name="optional window name") for those of us not using notebooks (or VS Code). Are PRs welcome?

Yep, very much!

.plt() can already take an ax= argument for the matplotlib axis, but the image methods can't, since they are not based on matplotlib.

As for the exact interface, I'm not completely certain what it should look like.
Maybe tensor.plt.to_mpl(name="optional window name")?

I'm also thinking about a better way to normalize the image values, which would be particularly useful for .chans.

For example, in this code https://xl0.github.io/Visualize-Understand-CNNs/impl.html , I end up with functions like sigmas() and pos_sigmas() all over the place. They bring values into the desired [0, 1] or [0.5, 1] range, and I should probably add a normalization function to lovely for this purpose, for example:

tensor.normalize("sigma", 3) -> torch.Tensor rescaled in a way that maps values μ +/- 3σ into the [0, 1] interval.

So the full call would be along the lines of tensor.normalize("sigma", 3).chans.to_mpl("Window Name")
Or, since nobody got time to type, t.n("sigma", 3).chans.mpl(...)

I'd rather have this chained interface than functions with a dozen partially overlapping parameters. Do you have any thoughts on this, or other ideas for the interface?
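A minimal NumPy sketch of the "sigma" normalizer proposed above, assuming it is a plain linear rescale that sends μ - nσ to 0 and μ + nσ to 1. Values further than nσ from the mean are left outside [0, 1] rather than clipped. `sigma_normalize` is a hypothetical name, not part of lovely-tensors.

```python
import numpy as np

def sigma_normalize(x, n=3.0):
    """Hypothetical helper: linearly map [mu - n*sigma, mu + n*sigma] onto [0, 1]."""
    x = np.asarray(x, dtype=np.float64)
    mu, sigma = x.mean(), x.std()
    if sigma == 0:
        # Constant input has no spread; place every value in the middle.
        return np.full_like(x, 0.5)
    return (x - (mu - n * sigma)) / (2 * n * sigma)

x = np.array([-1.0, 1.0])      # mean 0, standard deviation 1
print(sigma_normalize(x, 3))   # [0.33333333 0.66666667]
```

With n = 3 the interval [0, 1] spans -3σ to +3σ, so values at ±1σ land a third of the way in from each end.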

Sounds good! Prefer if it's easy to remember and quick to type because I'll often be exploring tensors interactively in pdb. If there's only one backend, then even just tensor.plt.show() would be nice. Don't know what other backends there would be, but maybe that could be set elsewhere, when initializing the library or something, just to keep the API simple. When I want to display a plot in a window I don't really care much how it's done, just that it gets done :)

Yep. I think another possible backend is visdom, https://github.com/fossasia/visdom
It's pretty fast and interactive - for example, you can stream training (or generated) images for every iteration with very minimal hit on performance.

But that's a maybe for the distant future, and for now, I think it's going to be just matplotlib. Agree, .show is a better name.

I can jump on this pretty soon. Or would you like to work on it?

You can go ahead!

Very interesting suggestions. I was away for a few days, but wanted to add some more ideas for discussion:

  • The plot functions should automatically be able to handle values outside the [0, 1] range without overflowing or artifacts. This could easily be done by always applying basic normalization before plotting: i = (i - i.min()) / (i.max() - i.min()). The rationale is that you never want to see an image with artifacts; I would always want to see it normalized in some way.
  • I would select the backend on calling the monkey patching, as I suggested: lt.monkey_patch(backend='matplotlib'), that way we can use exactly the same interface as the library now has. The rationale is that I would only want to use one backend at a time, so it should not be specified for every plotting call, but rather on importing / setting up the library.
  • At least for the matplotlib backend, I would like the option to show a colorbar with the actual range of values (not the normalized ones). This is not achievable unless (de)normalization is performed within the plotting function (see next point). I would add it as something like input.rgb(title='', bar=True).
  • Since (de)normalization is always used for plotting purposes, I would not do something like t.n("sigma", 3), as I do not care about that intermediate (de)normalized tensor, I just want to plot it! So I would rather have something like we have now: input.rgb(denorm=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])).
  • I would provide different normalization strategies, the default one being my first point in this list, and others being percentile-based (see code below, very useful in practice to get rid of extreme values, although slower to compute), or z-score based. I would like an interface like: input.rgb(norm='minmax'), where norm can be 'minmax', threshold number(s), or 'zscore', and the denorm parameter can stay as is.
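import numpy as np

def threshold_normalization(image, thres=(0.1, 99.9)):
    # Percentile limits discard a few extreme values at each end
    val_l, val_h = np.percentile(image, thres)
    # np.clip returns a new array, so the caller's image is not modified
    image = np.clip(image, val_l, val_h).astype(np.float32)
    return (image - val_l) / (val_h - val_l + 1e-6)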
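One way the proposed norm= parameter could dispatch between the strategies in the list above, sketched with NumPy. The `normalize` name and the convention that a pair of numbers means percentile thresholds are assumptions for illustration. Note that "zscore" output is centred on 0 and is not confined to [0, 1].

```python
import numpy as np

def normalize(image, norm="minmax", eps=1e-6):
    """Hypothetical dispatcher for the norm= strategies discussed above.

    "minmax": (x - min) / (max - min), result in [0, 1]
    "zscore": (x - mean) / std, result NOT bounded to [0, 1]
    (lo_pct, hi_pct): clip to those percentiles, then rescale to [0, 1]
    """
    x = np.asarray(image, dtype=np.float32)
    if isinstance(norm, str):
        if norm == "minmax":
            lo, hi = x.min(), x.max()
            return (x - lo) / (hi - lo + eps)
        if norm == "zscore":
            return (x - x.mean()) / (x.std() + eps)
        raise ValueError(f"unknown norm: {norm!r}")
    # Anything else is treated as a pair of percentile thresholds.
    lo, hi = np.percentile(x, norm)
    return (np.clip(x, lo, hi) - lo) / (hi - lo + eps)

x = np.arange(11)
print(normalize(x, (10, 90))[:3])   # first two values are clipped to the 10th percentile
```

Keeping the percentile branch separate from the named strategies means a threshold pair never has to be compared against a string.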

Hmm, good points, let me go one by one.

  1. I'm not sure what's best as default - fit the whole range, or highlight overflowing bits. But it's a matter of choosing the default, this can be done later, based on use cases.

  2. I've implemented the config functionality, including a context manager:
    https://xl0.github.io/lovely-tensors/utils.config.html . I think the backend selection can go here, as it does not strictly depend on monkey patching. Agree, should not have to pass it on every call.

  3. Good point, I think it's a good idea to at least have the option for a colorbar. It should be possible to do so using .n(...) too. We can't modify the original tensor values in place, so they have to stick around in memory anyway. n(...) will produce a tensor with an added attribute pointing to the original, and the downstream functions can check whether the attribute is there to draw the colorbar. 🐵

  4. I actually did find it very useful to plot the result of the normalization. If .chans/.rgb on normalized data do not make total sense to me, I replace the call with .plt to see the distribution in relation to the [0,1] interval. It was definitely a useful debugging tool, and I'd like to keep it.
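A sketch of the attribute idea from item 3, using NumPy. Plain `np.ndarray` instances do not accept new attributes, so a small subclass is needed. `Normalized`, `minmax_keep_original` and `colorbar_range` are hypothetical names; the point is only that `__array_finalize__` lets slices and views keep the link to the original data, so a plotting function can label a colorbar with the original value range.

```python
import numpy as np

class Normalized(np.ndarray):
    """Hypothetical ndarray subclass that remembers the un-normalized data."""

    def __new__(cls, values, original):
        obj = np.asarray(values).view(cls)
        obj.original = original
        return obj

    def __array_finalize__(self, obj):
        # Runs for views, slices and ufunc outputs; carry the reference along.
        if obj is None:
            return
        self.original = getattr(obj, "original", None)

def minmax_keep_original(x):
    x = np.asarray(x, dtype=np.float64)
    lo, hi = x.min(), x.max()
    scale = (hi - lo) or 1.0   # avoid dividing by zero on constant input
    return Normalized((x - lo) / scale, original=x)

def colorbar_range(arr):
    """A plotting function could use this to label its colorbar."""
    src = getattr(arr, "original", None)
    if src is None:
        return float(arr.min()), float(arr.max())
    return float(src.min()), float(src.max())

n = minmax_keep_original(np.array([10.0, 20.0, 30.0]))
print(colorbar_range(n))   # (10.0, 30.0)
```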

By the way, I've extended the interval to [-1, 1]. Here are my comments on the justification:
xl0/lovely-numpy#8

Would you have any comments on this choice?

  5. I like the percentile idea.
    For large input, we'd have to randomly pick a limited number of samples from the distribution, like it's done in .plt(). It will definitely be one of the options for normalization.
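A minimal sketch of percentile estimation on a random subsample, assuming sampling with replacement is good enough for picking display limits. `approx_percentiles` and its `max_samples` default are hypothetical and not taken from the `.plt()` code; they only illustrate the same general idea.

```python
import numpy as np

def approx_percentiles(x, q=(0.1, 99.9), max_samples=100_000, seed=0):
    """Hypothetical helper: estimate percentiles from at most max_samples values."""
    flat = np.asarray(x).ravel()
    if flat.size > max_samples:
        rng = np.random.default_rng(seed)
        # Sampling indices with replacement is O(max_samples), regardless of input size.
        idx = rng.integers(0, flat.size, size=max_samples)
        flat = flat[idx]
    return np.percentile(flat, q)

small = np.arange(101)
print(approx_percentiles(small, (10, 90)))   # [10. 90.]  (exact: below max_samples)
```

Small inputs are used as-is, so the result only becomes an estimate once the tensor exceeds `max_samples` elements.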

For the default - let's see. I'll make sure it's trivial to choose the right strategy, and for the default, I was thinking of making it dynamic, depending on the data distribution.

If the values have wild outliers (which seems to be the case for NN activations for example), using minmax normalization will just squish most values into a very narrow band around 0.

On the other hand, if the values are in the [0,1] interval, it does not make sense to re-scale them to [-1, 1], as it will no longer be obvious that they were all positives, and it will look counter-intuitive with a divergent colormap.

So, with no arguments, the normalization function will look at the data and pick a sensible method, while allowing a quick override.
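The dynamic default could be a small heuristic like the one below. Everything here is an assumption for illustration: the function name, the returned labels, and the 6σ outlier threshold. It leaves data already in [0, 1] alone and switches away from min-max when a few extreme values would squash the rest.

```python
import numpy as np

def pick_normalization(x, outlier_sigmas=6.0):
    """Hypothetical heuristic: choose a normalization strategy from the data."""
    x = np.asarray(x, dtype=np.float64)
    finite = x[np.isfinite(x)]
    if finite.size == 0:
        return "none"
    lo, hi = finite.min(), finite.max()
    if lo >= 0 and hi <= 1:
        return "none"      # already in [0, 1]; keep it obviously non-negative
    mu, sigma = finite.mean(), finite.std()
    if sigma > 0 and max(abs(lo - mu), abs(hi - mu)) > outlier_sigmas * sigma:
        return "sigma"     # wild outliers: min-max would squash the bulk
    return "minmax"

bulk = np.linspace(-1.0, 1.0, 1000)
print(pick_normalization(bulk))                        # minmax
print(pick_normalization(np.append(bulk, 1000.0)))     # sigma
print(pick_normalization(np.array([0.0, 0.25, 1.0])))  # none
```

In the second case min-max would map the 1000 bulk values into a band about 0.002 wide (2 / 1001), which is the squashing described above.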

Hmm?

  • The plot functions should automatically be able to handle values outside the [0,1] range without overflowing or artifacts. This could be easily done by always applying basic normalization before plotting: i= (i - i.min()) / (i.max() - i.min()).

Strongly disagree. Normalization silently hides important information (images with different normalizations will appear identical). The default should never hide information. (This is especially important when plotting color images. Less so for channels and single-channel heatmaps.)
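To make the objection concrete, a small NumPy example: min-max normalization maps two clearly different images to exactly the same output, so the plot alone cannot tell them apart. `minmax` here is just a local helper for the demonstration.

```python
import numpy as np

def minmax(x):
    x = np.asarray(x, dtype=np.float64)
    return (x - x.min()) / (x.max() - x.min())

a = np.array([[0.1, 0.2], [0.3, 0.4]])   # values in [0.1, 0.4]
b = 10 * a + 5                           # a very different image: values in [6, 9]

# Both become [[0, 1/3], [2/3, 1]] after normalization.
print(np.allclose(minmax(a), minmax(b)))  # True
```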

In my opinion, we should replace the values that would overflow with some distinct color and print a warning in a style similar to the NaN/Inf warnings. This should be the default.
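A minimal NumPy sketch of that default, assuming a single-channel image with a valid range [vmin, vmax]: in-range values become gray levels, and each class of invalid value gets its own highlight colour plus a warning. The palette, the `HIGHLIGHT` dict and `to_gray_with_highlights` are hypothetical, not lovely-tensors' actual colormap.

```python
import warnings
import numpy as np

# Hypothetical highlight colours (RGB in [0, 1]).
HIGHLIGHT = {
    "below": (0.0, 0.0, 1.0),   # finite values < vmin
    "above": (1.0, 0.0, 0.0),   # finite values > vmax
    "-inf": (0.0, 1.0, 1.0),
    "+inf": (1.0, 1.0, 0.0),
    "nan": (1.0, 0.0, 1.0),
}

def to_gray_with_highlights(x, vmin=0.0, vmax=1.0):
    """Map a 2-D array to an (H, W, 3) RGB image without hiding invalid values."""
    x = np.asarray(x, dtype=np.float64)
    # Clipping only affects pixels that are overwritten by a highlight below.
    gray = (np.clip(x, vmin, vmax) - vmin) / (vmax - vmin)
    rgb = np.repeat(gray[..., None], 3, axis=-1)
    finite = np.isfinite(x)
    masks = {
        "below": finite & (x < vmin),
        "above": finite & (x > vmax),
        "-inf": np.isneginf(x),
        "+inf": np.isposinf(x),
        "nan": np.isnan(x),
    }
    for name, mask in masks.items():
        if mask.any():
            rgb[mask] = HIGHLIGHT[name]
            warnings.warn(f"highlighted {int(mask.sum())} value(s) in category {name!r}")
    return rgb

img = to_gray_with_highlights(np.array([[0.5, -0.2], [np.nan, np.inf]]))
```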

I am not opposed to having automatic normalization as an optional parameter (via minmax or via quantiles or maybe even histogram normalization).

@xl0 @RuRo I agree with both your comments!

Indeed, if we go the .n() route, calling .rgb or .chans should not normalize the data in any way. I also didn't know about the color coding for NaN/Inf/out-of-range values, but as it is now, it is probably the best way of doing it! Furthermore, I would also add this color coding to .rgb, and briefly talk about it in the Index notebook, which is what I (and, I suspect, most users) look at when first encountering the library.

Finally, it might make sense to still be able to pass the arguments of set_config directly to monkey_patch(**kwargs), so that we need one less line of code for setting up the library. It is probably a bit messier than it is now, so I am not sure if it is worth it...
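A toy sketch of how a single setup call could forward options to the config, including a context manager for temporary overrides. All names and option keys here (`_CONFIG`, `set_config`, `get_config`, `config`, `monkey_patch`, "backend", "fig_show") are hypothetical stand-ins, not the real lovely-tensors implementation; see the utils.config docs linked in the thread for the actual interface.

```python
from contextlib import contextmanager

# Hypothetical global settings, for illustration only.
_CONFIG = {"backend": "matplotlib", "fig_show": False}

def set_config(**kwargs):
    unknown = set(kwargs) - set(_CONFIG)
    if unknown:
        raise TypeError(f"unknown config option(s): {sorted(unknown)}")
    _CONFIG.update(kwargs)

def get_config():
    return dict(_CONFIG)

@contextmanager
def config(**kwargs):
    """Temporarily override settings, restoring the previous values on exit."""
    saved = get_config()
    set_config(**kwargs)
    try:
        yield
    finally:
        _CONFIG.clear()
        _CONFIG.update(saved)

def monkey_patch(**kwargs):
    """Forward any config options so setup is a single call."""
    set_config(**kwargs)
    # Patching of the tensor class would happen here.

monkey_patch(backend="matplotlib", fig_show=True)
```

Validating option names in `set_config` keeps a typo passed to `monkey_patch` from being silently ignored.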

Yep, that's what it's like right now:
https://xl0.github.io/lovely-numpy/utils.colormap.html

Highlighting colors for values < -1, values > 1, +/-inf and NaN.
That's for .chans. For RGB, I'm not 100% sure what would be the best approach. Right now the values just roll over, which actually does make it visually clear that an overflow is happening. But I could also assign specific colours to the various invalid intervals. Of course, with RGB, the highlight colours will always overlap with valid ones. Have you seen any examples of this done well?

Hi!
I'm finally done with the matplotlib rework. Please take a look:
https://xl0.github.io/lovely-tensors/matplotlib.html#without-jupyter

I'll give Lovely JAX some love now, and will be back to work on the normalization stuff in about a week.

@xl0

  1. For me
assert lovely_tensors.__version__ == '0.1.12'
lovely_tensors.set_config(fig_close=False, fig_show=True)
torch.randn(3, 1024, 1024).chans()

still doesn't open the plot (I am running in an IPython shell btw). But manually calling .show()

torch.randn(3, 1024, 1024).chans.fig.show()

does work. The numpy version also works as expected:

assert lovely_numpy.__version__ == '0.2.6'
lovely_numpy.set_config(fig_close=False, fig_show=True)
lovely_numpy.lo(np.random.randn(1024, 1024, 3)).chans()
  2. What do you think about dynamically checking whether the user is running in a non-Jupyter environment (i.e. IPython or a regular script) and making fig_show=True the default?

  3. You might want to do fig.show() instead of plt.show(). The latter will also show all other figures that happen to be open at the time, which might not be desirable.
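A best-effort sketch of the dynamic check from point 2, using `IPython.get_ipython()` and the shell class name. `in_jupyter` and `FIG_SHOW_DEFAULT` are hypothetical names; the class-name test is a common heuristic rather than an official API contract, and front-ends that use their own shell classes may not be detected.

```python
def in_jupyter():
    """Return True when running inside a Jupyter kernel (heuristic)."""
    try:
        from IPython import get_ipython
    except ImportError:
        return False      # IPython is not installed, so this cannot be Jupyter
    shell = get_ipython()
    if shell is None:
        return False      # IPython is installed, but this is a regular script
    # Jupyter kernels use ZMQInteractiveShell; terminal IPython uses TerminalInteractiveShell.
    return type(shell).__name__ == "ZMQInteractiveShell"

# Show figures automatically everywhere except in notebooks.
FIG_SHOW_DEFAULT = not in_jupyter()
```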

  1. I'll see what's causing it tomorrow.

  2. I'm for it.

  3. My understanding is, fig.show() won't work without a call to plt.show() in a GUI backend:

https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.show

Is this correct?

@xl0 Ah, my bad. Calling fig.show() works in both interactive IPython and interactive Python shells, but it does nothing in a standalone script. So a GUI isn't strictly required, but standalone scripts still need to call plt.show().

Bizarrely, some surface-level googling seems to suggest that matplotlib doesn't have a portable way to "show this specific figure only, don't touch the other ones". Huh.

@RuRo I fixed the issue with the figures not showing. Could you try the current git?

pip install git+https://github.com/xl0/lovely-tensors

I don't think there is a solution for fig.show(); I might play with it later.

I'll get back and add the dynamic fig_show default later. Or you can have a go at it.

@xl0 yup, current master works as expected.