Error in `predict_profile` when a Dataframe with MultiIndex is Used
CahidArda opened this issue · comments
I am calling the `predict_profile` method with a dataframe that uses multiple columns as its index (a `MultiIndex`). This results in an error.
How to Replicate
This occurs with pandas 1.3.0 and later; it does not occur with 1.2.5 and earlier.
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import dalex as dx
from dalex.model_explanations import AggregatedProfiles

data = np.array([[242, 902, 3, 435],
                 [125, 684, 3, 143],
                 [162, 284, 3, 124],
                 [712, 844, 3, 145],
                 [122, 864, 3, 114],
                 [155, 100, 3, 25]])
target = np.array([723, 554, 932, 543, 654, 345])

# Use two of the columns as a MultiIndex
data = pd.DataFrame(data, columns=[f"col{i}" for i in range(4)])
data = data.set_index(["col1", "col2"])

clf = DecisionTreeClassifier()
clf.fit(data, target)

exp = dx.Explainer(clf, data, target)
pred_profile = exp.predict_profile(data, y=target)  # error is raised here

aggregated_profile = AggregatedProfiles()
aggregated_profile.fit(pred_profile)
aggregated_profile.plot(show=True)
Error Message
Calculating ceteris paribus: 0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
File "file.py", line 20, in <module>
exp.predict_profile(data, y=target)
File "C:\Users\user\Documents\GitHub\DALEX\python\dalex\dalex\_explainer\object.py", line 404, in predict_profile
_predict_profile.fit(self, new_observation, y, verbose)
File "C:\Users\user\Documents\GitHub\DALEX\python\dalex\dalex\predict_explanations\_ceteris_paribus\object.py", line 126, in fit
self.result, self.new_observation = utils.calculate_ceteris_paribus(
File "C:\Users\user\Documents\GitHub\DALEX\python\dalex\dalex\predict_explanations\_ceteris_paribus\utils.py", line 23, in calculate_ceteris_paribus
profiles = calculate_variable_profile(explainer.predict_function,
File "C:\Users\user\Documents\GitHub\DALEX\python\dalex\dalex\predict_explanations\_ceteris_paribus\utils.py", line 66, in calculate_variable_profile
profile.append(single_variable_profile(predict_function, model, data, variable, split_points))
File "C:\Users\user\Documents\GitHub\DALEX\python\dalex\dalex\predict_explanations\_ceteris_paribus\utils.py", line 112, in single_variable_profile
new_data.loc[:, '_ids_'] = ids
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\indexing.py", line 723, in __setitem__
iloc._setitem_with_indexer(indexer, value, self.name)
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\indexing.py", line 1667, in _setitem_with_indexer
self.obj[key] = value
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\frame.py", line 3607, in __setitem__
self._set_item(key, value)
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\frame.py", line 3779, in _set_item
value = self._sanitize_column(value)
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\frame.py", line 4505, in _sanitize_column
return sanitize_array(value, self.index, copy=True, allow_2d=True)
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\construction.py", line 500, in sanitize_array
data = extract_array(data, extract_numpy=True)
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\construction.py", line 423, in extract_array
obj = obj.array
File "C:\Users\user\anaconda3\envs\dalex\lib\site-packages\pandas\core\indexes\multi.py", line 725, in array
raise ValueError(
ValueError: MultiIndex has no single backing array. Use 'MultiIndex.to_numpy()' to get a NumPy array of tuples.
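For context (this is my reading of the traceback, not code from dalex): the failing line builds `ids` with `np.repeat` on the frame's index. `np.repeat` dispatches to the object's own `repeat` method, so with a `MultiIndex` the result is still a `MultiIndex` rather than a plain NumPy array, and pandas 1.3's column assignment cannot extract a single backing array from it. A minimal sketch with a made-up two-level index:

```python
import numpy as np
import pandas as pd

# Hypothetical two-level index, standing in for data.index in dalex
idx = pd.MultiIndex.from_tuples([(1, "a"), (2, "b")])

# np.repeat dispatches to Index.repeat, so the result is still
# a MultiIndex, not a plain NumPy array:
ids = np.repeat(idx, 3)
print(type(ids).__name__)  # MultiIndex
```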
Solution
I found the same error reported in the issues of another project, where downgrading pandas to 1.2.4 was suggested. Downgrading to 1.2.4 solved the issue for me as well.
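For reference, the workaround amounts to pinning pandas to the older version:

```shell
pip install "pandas==1.2.4"
```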
I also noticed, while looking at the `single_variable_profile` function in the `predict_explanations._ceteris_paribus.utils` file, that changing the following line from:
ids = np.repeat(data.index, split_points.shape[0])
to:
ids = np.repeat(data.index.values, split_points.shape[0])
makes it work in both pandas 1.2.5 and 1.3.0.
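A quick check of the suggested change, outside of dalex: `.values` converts the `MultiIndex` to a NumPy object array of tuples before repeating, so `np.repeat` returns a plain ndarray that pandas can assign as a column. A sketch using a made-up index (not the dalex data):

```python
import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_tuples([(1, "a"), (2, "b")])

# .values yields a NumPy object array of tuples, so np.repeat
# returns a plain ndarray instead of a MultiIndex:
ids = np.repeat(idx.values, 3)
print(type(ids).__name__)  # ndarray
print(list(ids[:3]))       # [(1, 'a'), (1, 'a'), (1, 'a')]

# Assigning it as a column now works:
df = pd.DataFrame({"x": range(6)})
df.loc[:, "_ids_"] = ids
print(df["_ids_"].iloc[0])  # (1, 'a')
```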
Hi @CahidArda, would you like to make a pull request with the solution?
Hello @hbaniecki, I have created a PR for the change.