rachtibat / zennit-crp

An eXplainable AI toolkit with Concept Relevance Propagation and Relevance Maximization

Home Page: https://www.nature.com/articles/s42256-023-00711-8

run_distributed method does not consider batch size

maxdreyer opened this issue

Hi @rachtibat,

the run_distributed method of the FeatureVisualization class does not take the actual batch_size into account in the multi-target case, so the full broadcast batch is attributed in a single forward pass.

Maybe include something like:

import math

# split the broadcast inputs into chunks of at most batch_size samples
if n_samples > batch_size:
    batches_ = math.ceil(len(conditions) / batch_size)
else:
    batches_ = 1

for b_ in range(batches_):
    data_broadcast_ = data_broadcast[b_ * batch_size: (b_ + 1) * batch_size]
    conditions_ = conditions[b_ * batch_size: (b_ + 1) * batch_size]
    # dict_inputs is linked to FeatHooks
    dict_inputs["sample_indices"] = sample_indices[b_ * batch_size: (b_ + 1) * batch_size]
    dict_inputs["targets"] = targets[b_ * batch_size: (b_ + 1) * batch_size]

    # composites are already registered before
    self.attribution(data_broadcast_, conditions_, None, exclude_parallel=False)

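For reference, a minimal standalone sketch of the same chunking arithmetic; iter_batches is an illustrative helper, not part of the zennit-crp API:

import math

def iter_batches(items, batch_size):
    # number of slices needed to cover all items, at most batch_size each
    n_batches = math.ceil(len(items) / batch_size) if len(items) > batch_size else 1
    for b in range(n_batches):
        yield items[b * batch_size: (b + 1) * batch_size]

# e.g. 10 conditions with batch_size 4 are processed as chunks of 4, 4 and 2
assert [len(c) for c in iter_batches(list(range(10)), 4)] == [4, 4, 2]
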
This would fix a GPU out-of-memory issue I am running into.

Best,
Max