understandable-machine-intelligence-lab / Quantus

Quantus is an eXplainable AI toolkit for responsible evaluation of neural network explanations

Home Page: https://quantus.readthedocs.io/

Possible error in the application of aggregate_func in the 'Metric' class found in the 'base.py' script, especially when using NumPy methods.

Shreyas-Gururaj opened this issue

[Attached screenshot: Screenshot from 2023-01-10 01-02-41]

Please confirm whether the mean should be taken over axis=0 only, so that one aggregated value is kept for each step (fraction of pixels flipped).

Essentially, line 235 would be replaced by line 239 in the attached reference picture.

Thanks,
Shreyas Gururaj.

@Shreyas-Gururaj Thanks for your feedback! Could you provide more detail about the metric you were trying to use, and maybe a code snippet that would help us to recreate the error? That would be very helpful.

@dilyabareeva I was trying to use the pixel flipping metric, with the 'explain_func' argument set to 'generate_zennit_explanation' and my own LRP composite and canonizer passed via the 'explain_func_kwargs' argument. A code snippet that reproduces the error would be very large.

In short, the aggregate function np.mean applied in the Metric class returns a NumPy array with a single element instead of one element for each step of the pixel flipping. So my guess was to take np.mean for each individual step across the batch (i.e. along axis=0).

@Shreyas-Gururaj Quantus provides flexibility in terms of aggregate function choice. For your case, I would suggest passing aggregate_func=lambda x: numpy.mean(x, axis=0) to the metric. Changing the base class to include axis=0 as you suggest might not make sense for all the metrics. Hope this helps.
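To illustrate the difference discussed above, here is a minimal NumPy-only sketch. It does not call Quantus; the `scores` array and its shape (3 samples by 4 flipping steps) are made up for illustration, and the assumption is that the per-sample results reach the aggregate function stacked as samples along axis 0 and steps along axis 1.

```python
import numpy as np

# Hypothetical per-sample scores (not produced by Quantus):
# 3 samples (rows) x 4 pixel-flipping steps (columns).
scores = np.array([
    [0.9, 0.7, 0.4, 0.1],
    [0.8, 0.6, 0.5, 0.2],
    [1.0, 0.5, 0.3, 0.0],
])

# Plain np.mean averages over every axis and collapses to a single value.
overall = np.mean(scores)

# A callable of the form suggested above: average over samples only (axis=0),
# which keeps one value per flipping step.
aggregate_func = lambda x: np.mean(x, axis=0)
per_step = aggregate_func(scores)

print(np.shape(overall))  # ()
print(per_step.shape)     # (4,)
```

If the results are laid out this way, a callable like `aggregate_func` above is what would be passed to the metric's `aggregate_func` argument, leaving the base class default untouched for other metrics.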

@dilyabareeva Hi, thanks for the update. Yes, that might break other running code. I will make the necessary changes.