google-parfait / tensorflow-federated

An open-source framework for machine learning and other computations on decentralized data.

Exploding memory while training a federated model on the FLAIR dataset

christymarc opened this issue

I am training a federated model on the FLAIR dataset and hitting the limit of my machine's allotted CPU memory. CPU memory usage grows with every round of federated training: over 500 rounds of FLAIR training it climbs from about 1GB in the 1st round to 50GB in the 500th round. Additionally, if I train multiple FLAIR models in a loop in the same script, usage keeps accumulating beyond 50GB. This is the function I use to track memory usage:

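import os
import psutil


def memory(memory_log):
    """Appends this process's resident memory (RSS) to memory_log."""
    process = psutil.Process(os.getpid())
    memory_use = process.memory_info().rss / 2.**30  # resident memory in GB (2**30 bytes)
    memory_log.write(f'memory use:{memory_use}\n')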
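To turn the log written by `memory()` into a per-round growth rate, a small parser is enough. This is a minimal sketch that assumes the log contains one `memory use:<GB>` line per training round, in the format `memory()` writes; the helper names `read_memory_log` and `average_growth_per_round` are hypothetical.

```python
import re
from typing import Iterable, List

# Matches lines written by memory(), e.g. "memory use:1.25"
_LINE = re.compile(r"memory use:\s*([0-9.eE+-]+)")


def read_memory_log(lines: Iterable[str]) -> List[float]:
    """Returns the logged memory values (GB) in file order, skipping other lines."""
    values = []
    for line in lines:
        match = _LINE.match(line.strip())
        if match:
            values.append(float(match.group(1)))
    return values


def average_growth_per_round(values: List[float]) -> float:
    """Average increase in GB between consecutive log entries (0.0 if fewer than two)."""
    if len(values) < 2:
        return 0.0
    return (values[-1] - values[0]) / (len(values) - 1)
```

With the figures from the report (about 1GB in round 1 and 50GB in round 500), this works out to (50 - 1) / 499 ≈ 0.098GB, i.e. roughly 100MB retained per round.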

I have isolated the memory issue to the following line of code, where I create the client_data:

import os

import numpy as np
import tensorflow_federated as tff

cache_dir = os.path.join(os.path.expanduser('~'), '.tff')
train_set, valid_set, _ = tff.simulation.datasets.flair.load_data(cache_dir=cache_dir)

...

sampled_clients = np.random.choice(
    train_set.client_ids, size=NUM_CLIENTS, replace=False
)

# HERE IS THE PROBLEM
client_data = [train_set.create_tf_dataset_for_client(x) for x in sampled_clients]

Even when I comment out all the other code related to training the federated model, the rapid memory growth persists as long as the client_data line is left in.

I am wondering if there is anything I can do to stop memory usage from accumulating like this so I can train federated models on the FLAIR dataset.

Additional Info:
I have tried other datasets such as CelebA and see CPU memory usage accumulate there as well, but much more slowly (I assume because of the difference in dataset sizes). I have also tried reusing the same list of sampled_clients every round, but that does not alleviate the memory issue either.

Environment:

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Pop!_OS 22.04 LTS
  • Python package versions (e.g., TensorFlow Federated, TensorFlow): tff 0.63.0, tensorflow-gpu 2.13.1, tensorflow-privacy 0.8.11
  • Python version: 3.9.18
  • Bazel version (if building from source): N/A
  • CUDA/cuDNN version: cudatoolkit 11.8.0, cudnn 8.8.0.121

Thank you for the details! I have a hypothesis about this, looking into it now.

Just to clarify: the memory blowup happens if you do something like this:

sampled_clients = ...  # Fix a set of clients
for _ in range(100):
  client_data = [train_set.create_tf_dataset_for_client(x) for x in sampled_clients]
  # Do some model training with client_data

Does it happen if you fix the client datasets across all rounds, e.g.:

sampled_clients = ...  # Fix a set of clients
client_data = [train_set.create_tf_dataset_for_client(x) for x in sampled_clients]
for _ in range(100):
  ...  # Do some model training with client_data
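The second variant above suggests a stopgap while the fix is unavailable: create each client's dataset at most once and reuse it whenever that client is sampled again. The sketch below is a hypothetical helper (`DatasetCache` is not a TFF API); it accepts any callable, such as `train_set.create_tf_dataset_for_client`, so it assumes nothing about TFF internals.

```python
from typing import Callable, Dict, Generic, Iterable, List, TypeVar

T = TypeVar("T")


class DatasetCache(Generic[T]):
    """Builds each client's dataset once and returns the same object afterwards."""

    def __init__(self, create_fn: Callable[[str], T]):
        # e.g. train_set.create_tf_dataset_for_client
        self._create_fn = create_fn
        self._cache: Dict[str, T] = {}

    def get(self, client_id: str) -> T:
        # Only call create_fn the first time a client is requested.
        if client_id not in self._cache:
            self._cache[client_id] = self._create_fn(client_id)
        return self._cache[client_id]

    def get_many(self, client_ids: Iterable[str]) -> List[T]:
        return [self.get(cid) for cid in client_ids]

    def __len__(self) -> int:
        return len(self._cache)
```

In a training loop, `cache = DatasetCache(train_set.create_tf_dataset_for_client)` is built once and each round's list comprehension becomes `client_data = cache.get_many(sampled_clients)`. The trade-off is that every dataset ever sampled stays referenced, so this only helps when the number of distinct clients touched across rounds is modest; it is a workaround, not a substitute for the upstream fix.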

@christymarc I believe that commit 1ab31a2 should resolve your issue.

If you'd like to use it with an earlier TFF version, you should be able to do something like the following:

First, subclass FilePerUserClientData to avoid the issue fixed in 1ab31a2:

from typing import Callable, Mapping

import tensorflow as tf
import tensorflow_federated as tff


class StreamlinedFilePerUserClientData(tff.simulation.datasets.FilePerUserClientData):

  def __init__(
      self,
      client_ids_to_files: Mapping[str, str],
      dataset_fn: Callable[[str], tf.data.Dataset],
  ):
    # Keep plain-Python references so datasets can be built directly from file paths.
    self._client_ids_to_files = client_ids_to_files
    self._dataset_fn = dataset_fn
    super().__init__(client_ids_to_files, dataset_fn)

  def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset:
    if client_id not in self.client_ids:
      raise ValueError(
          'ID [{i}] is not a client in this ClientData. See '
          'property `client_ids` for the list of valid ids.'.format(i=client_id)
      )
    # Look up the client's file in Python and call dataset_fn on it directly,
    # instead of going through the base-class implementation.
    return self._dataset_fn(self._client_ids_to_files[client_id])

Then, to load FLAIR, download and preprocess the data as usual:

tff.simulation.datasets.flair.download_data(data_dir)
tff.simulation.datasets.flair.write_data(data_dir, cache_dir)

Then follow the load_data implementation in TFF's flair module, but replace FilePerUserClientData with the StreamlinedFilePerUserClientData class above.

Please let me know if that solves your problem.

@christymarc All tests I've run so far suggest that this has been fixed by the aforementioned change. I'm marking this as resolved for now, but if you continue to hit this issue even with the change (or workaround above), please feel free to re-open.