scikit-hep / hist

Histogramming for analysis powered by boost-histogram

Home Page: https://hist.readthedocs.io

[FEATURE] Splom plot

nsmith- opened this issue

Describe the problem, if any, that your feature request is related to

To visualize a high-dimensional histogram, it would be nice to be able to conveniently create a so-called "splom" (scatter-plot matrix): a matrix of 2D projections of the histogram onto each pair of axes, with the 1D projections along the diagonal.
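For reference, each panel is just a marginal of the N-D bin contents, i.e. the contents summed over all the other axes. A minimal NumPy sketch of that bookkeeping, assuming a plain array of counts with no flow bins (as `np.histogramdd` returns); `marginals` is a hypothetical helper, not part of hist:

```python
import numpy as np


def marginals(counts):
    """Return the 1D and 2D marginal projections of an N-D array of counts.

    diag[i] is the 1D projection onto axis i. pairs[(j, i)] with j < i is the
    2D projection onto axes j (x) and i (y), i.e. one lower-triangle panel.
    """
    n = counts.ndim
    diag = []
    pairs = {}
    for i in range(n):
        # Sum over every axis except i.
        others = tuple(k for k in range(n) if k != i)
        diag.append(counts.sum(axis=others))
        for j in range(i):
            # Sum over every axis except i and j; numpy keeps the remaining
            # axes in their original order, so the result is indexed (j, i).
            others = tuple(k for k in range(n) if k not in (i, j))
            pairs[(j, i)] = counts.sum(axis=others)
    return diag, pairs


# Same binning as the example below; entries outside the range are dropped.
rng = np.random.default_rng(0)
data = rng.multivariate_normal([0.2, 0.0, 0.7], np.eye(3), size=10_000)
counts, edges = np.histogramdd(data, bins=(10, 30, 10), range=[(0, 1), (-3, 3), (0, 1)])
diag, pairs = marginals(counts)
print([d.shape for d in diag])  # [(10,), (30,), (10,)]
print(pairs[(0, 1)].shape)  # (10, 30)
```

The `(j, i)` keys correspond to the lower-triangle panels in the loops below, with axis `j` on x and axis `i` on y.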

Describe the feature you'd like

An example plot could be produced as follows:

import numpy as np
import matplotlib.pyplot as plt
import hist

h = (
    hist.Hist.new
    .Reg(10, 0, 1, name="var 1")
    .Reg(30, -3, 3, name="var 2")
    .Reg(10, 0, 1, name="var 3")
    .Double()
    .fill(*np.random.multivariate_normal([0.2, 0.0, 0.7], np.eye(3), size=100_000).T)
)

naxes = len(h.axes)
fig, axes = plt.subplots(naxes, naxes, figsize=(4*naxes, 4*naxes), facecolor="w")
for i, axrow in enumerate(axes):
    for j, ax in enumerate(axrow):
        if j > i:
            ax.axis("off")
        elif j == i:
            h.project(h.axes[i].name).plot(ax=ax)
        else:
            h.project(h.axes[j].name, h.axes[i].name).plot(ax=ax)

fig.tight_layout()

which produces
[image: the resulting splom plot, not preserved]

Some work is still needed to align the axes well; commenting out fig.tight_layout() makes it look even worse.

Describe alternatives, if any, you've considered

Alternative solutions exist in the scientific Python ecosystem, e.g. pandas (pandas.plotting.scatter_matrix) or corner, but they require unbinned inputs.

Improved example (axis limits set to the bin edges, no per-panel colorbars):

import numpy as np
import matplotlib.pyplot as plt
import hist

h = (
    hist.Hist.new
    .Reg(10, 0, 1, name="var 1")
    .Reg(30, -3, 3, name="var 2")
    .Reg(10, 0, 1, name="var 3")
    .Double()
    .fill(*np.random.multivariate_normal([0.2, 0.0, 0.7], np.diag([0.4, 1.0, 0.1]), size=100_000).T)
)

naxes = len(h.axes)
fig, axes = plt.subplots(naxes, naxes, figsize=(4*naxes, 4*naxes), facecolor="w")
for i, axrow in enumerate(axes):
    for j, ax in enumerate(axrow):
        if j > i:
            ax.axis("off")
        elif j == i:
            hp = h.project(h.axes[i].name)
            hp.plot(ax=ax)
            ax.set_xlim(hp.axes[0].edges[0], hp.axes[0].edges[-1])
        else:
            hp = h.project(h.axes[j].name, h.axes[i].name)
            hp.plot(ax=ax, cbar=False)
            ax.set_xlim(hp.axes[0].edges[0], hp.axes[0].edges[-1])
            ax.set_ylim(hp.axes[1].edges[0], hp.axes[1].edges[-1])

fig.tight_layout()

I think the colorbar is probably optional in this setting.
A question: does h.project include the overflow bins of the axes being summed over? Either way, I think there should probably be a flag to enable or disable including overflow.
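I haven't checked what h.project actually does with flow bins, so the following is only a sketch of what such a flag could control. It assumes every axis carries one underflow bin at index 0 and one overflow bin at index -1 (the layout `values(flow=True)` gives for the Regular axes above); `project_with_flow` and `include_flow` are hypothetical names, not existing hist API:

```python
import numpy as np


def project_with_flow(values, keep, include_flow=True):
    """Project N-D bin contents (with flow bins) onto the axes in `keep`.

    `values` is assumed to have one underflow bin (index 0) and one overflow
    bin (index -1) on every axis. Flow bins of the kept axes are always
    stripped from the result; `include_flow` decides whether the flow bins of
    the summed-over axes contribute. Kept axes stay in their original order.
    """
    inner = slice(1, -1)
    if not include_flow:
        # Drop the flow bins on every axis before summing.
        values = values[(inner,) * values.ndim]
    others = tuple(k for k in range(values.ndim) if k not in keep)
    projected = values.sum(axis=others)
    if include_flow:
        # Flow entries were summed in; now strip the kept axes' own flow bins.
        projected = projected[(inner,) * projected.ndim]
    return projected


# 3 axes with 2 regular bins each, plus flow bins: shape (4, 4, 4).
v = np.ones((4, 4, 4))
print(project_with_flow(v, keep=(0,), include_flow=True))  # [16. 16.]
print(project_with_flow(v, keep=(0,), include_flow=False))  # [4. 4.]
```

The difference between the two results is exactly the content sitting in the flow bins of the summed-over axes, which is what a user would want to be able to toggle.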

It was also suggested that the corner authors might be willing to provide a pre-binned API. I'm not sure how much work that would be, or how much of the resulting functionality would overlap with mplhep.
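Roughly what a pre-binned splom could look like, as a sketch only: it takes N-D counts plus per-axis bin edges, draws with plain matplotlib (`Axes.stairs`, `Axes.pcolormesh`), and uses `sharex="col"` plus explicit limits taken from the edges to keep each column aligned. `splom` and its arguments are hypothetical, not an existing hist, mplhep, or corner API:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def splom(counts, edges, names=None, size=3.0):
    """Draw a splom from pre-binned N-D counts and per-axis bin edges.

    Lower triangle: 2D projections (x = axis j, y = axis i).
    Diagonal: 1D projections. Upper triangle: hidden.
    """
    n = counts.ndim
    names = names or [f"axis {k}" for k in range(n)]
    # sharex="col" keeps every panel in a column on the same x range.
    fig, axes = plt.subplots(
        n, n, figsize=(size * n, size * n), sharex="col", squeeze=False
    )
    for i in range(n):
        for j in range(n):
            ax = axes[i, j]
            if j > i:
                ax.axis("off")
                continue
            if j == i:
                others = tuple(k for k in range(n) if k != i)
                ax.stairs(counts.sum(axis=others), edges[i])
            else:
                others = tuple(k for k in range(n) if k not in (i, j))
                proj = counts.sum(axis=others)  # indexed (j, i)
                # pcolormesh expects C with shape (ny, nx), hence the transpose.
                ax.pcolormesh(edges[j], edges[i], proj.T)
                ax.set_ylim(edges[i][0], edges[i][-1])
            ax.set_xlim(edges[j][0], edges[j][-1])
            if i == n - 1:
                ax.set_xlabel(names[j])
            if j == 0 and i > 0:
                ax.set_ylabel(names[i])
    fig.tight_layout()
    return fig, axes


rng = np.random.default_rng(0)
data = rng.multivariate_normal([0.2, 0.0, 0.7], np.diag([0.4, 1.0, 0.1]), size=100_000)
counts, edges = np.histogramdd(data, bins=(10, 30, 10), range=[(0, 1), (-3, 3), (0, 1)])
fig, axes = splom(counts, edges, names=["var 1", "var 2", "var 3"])
```

With a hist object, the inputs would presumably be h.values(), [ax.edges for ax in h.axes] and [ax.name for ax in h.axes]; flow bins are ignored in this sketch, and the colorbar is left out as discussed above.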