TSR generates error
hugo-rddi opened this issue
Hello,
Some errors occur during TSR execution. We tried different configurations (all of the available saliency methods: FO, IG, SVS, ...).
Context:
The following cell worked.
```python
int_mod = TSR(
    model,
    train_x.shape[-1],
    train_x.shape[-2],
    method='IG',
    mode='feat'
)
```
But this one didn't:
```python
exp = int_mod.explain(
    input_x,
    labels=class_target,
    TSR=True
)
```
Error:
```
147 elif self.method == "IG":
148 base = baseline_single
--> 149 attributions = self.Grad.attribute(
150 input, baselines=baseline_single, target=labels
151 )
152 elif self.method == "DL":
153 base = baseline_single
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
...
527 )
528 else:
--> 529 raise AssertionError(f"Target type {type(target)} is not valid.")
AssertionError: Target type is not valid.
```
Any idea what the problem could be here?
The model, train_x, labels, and so on are not the source of the problem: they follow the expected format and worked when running NativeGuide.
Thanks in advance !
Hi,
Did you explicitly cast the label to `int`? For example: `int(np.argmax(y_target, axis=1)[0])`
The underlying library (Captum) is unfortunately quite strict about that.
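A minimal sketch of why the cast matters, assuming a one-hot `y_target` for a single sample: `np.argmax` returns a NumPy integer scalar, which is not an instance of Python's built-in `int`. The `is_valid_scalar_target` function is a hypothetical, simplified stand-in for the kind of type check Captum applies to `target`; it is not Captum's actual code and only models the scalar case.

```python
import numpy as np

# Assumed input: one-hot target for a single sample, shape (1, n_classes).
y_target = np.array([[0.0, 0.0, 1.0]])

raw_label = np.argmax(y_target, axis=1)[0]  # NumPy integer scalar (e.g. np.int64)
label = int(raw_label)                      # plain Python int


def is_valid_scalar_target(target):
    """Hypothetical, simplified check: plain Python ints (and tuples) are
    accepted as a scalar target, NumPy integer scalars are not."""
    return isinstance(target, (int, tuple))


print(type(raw_label).__name__, is_valid_scalar_target(raw_label))
print(type(label).__name__, is_valid_scalar_target(label))
```

Passing `label` rather than `raw_label` as `labels=` is the cast suggested above; `raw_label` is the kind of NumPy scalar that can trigger the `AssertionError`.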
I really appreciate your feedback; it helps us improve the library. I will add some preprocessing for that in the next release so that others do not run into the same issue.
I added some preprocessing that should solve it:
```
pip install https://github.com/fzi-forschungszentrum-informatik/TSInterpret/archive/refs/heads/main.zip
```
As soon as I have more time, I will also improve the error handling (I know there is still a lot to do ;)).
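The label preprocessing mentioned above could look roughly like the following. This is a hypothetical sketch, not the code that was actually added to TSInterpret: `normalize_target` is an illustrative name, and it only covers a few common label formats for a single sample (a plain `int`, a NumPy integer scalar, a one-element array, or a one-hot row of shape `(1, n_classes)`). Anything else raises a `ValueError`.

```python
import numpy as np


def normalize_target(labels):
    """Hypothetical helper (not TSInterpret's actual code): coerce the class
    label of a single sample into a plain Python int before it is used as a
    Captum `target`."""
    # Plain Python ints pass through unchanged.
    if isinstance(labels, int):
        return labels
    # NumPy integer scalars such as np.int64 are not `int` instances.
    if isinstance(labels, np.integer):
        return int(labels)
    arr = np.asarray(labels)
    # One-hot (or score) row for one sample, shape (1, n_classes): take the argmax.
    if arr.ndim == 2 and arr.shape[0] == 1:
        return int(np.argmax(arr, axis=1)[0])
    # A single value in a 0-d or one-element 1-d array, e.g. np.array([2]).
    if arr.ndim <= 1 and arr.size == 1:
        return int(arr.reshape(-1)[0])
    raise ValueError(f"Cannot turn a target of shape {arr.shape} into one class index.")


print(normalize_target(np.int64(2)))
print(normalize_target(np.array([[0.0, 1.0, 0.0]])))
```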
Again, you were right! It needed to be strictly cast to int.
Thanks, have a nice day
Nice.