ar4 / deepwave

Wave propagation modules for PyTorch.

CUDA out of memory when running Reverse-Time Migration of Marmousi example

XuVV opened this issue · comments

commented

Hi, Doctor. Sorry to bother you again.

When I tried to run the Reverse-Time Migration of Marmousi example using the following code, it reported that CUDA is out of memory.

# Run optimisation/inversion
n_epochs = 1
n_batch = 60
n_shots_per_batch = (n_shots + n_batch - 1) // n_batch
for epoch in range(n_epochs):
    epoch_loss = 0
    # optimiser.zero_grad()
    for batch in range(n_batch):
        print(batch)
        optimiser.zero_grad()
        batch_start = batch * n_shots_per_batch
        batch_end = min(batch_start + n_shots_per_batch, n_shots)
        if batch_end <= batch_start:
            continue
        s = slice(batch_start, batch_end)

        simulated_data = scalar_born(v_mig.detach(), scatter, dx, dt,
                                     source_amplitudes=source_amplitudes[s].detach(),
                                     source_locations=source_locations[s].detach(),
                                     receiver_locations=receiver_locations[s].detach(),
                                     pml_freq=freq)
        loss = (1e9 * loss_fn(simulated_data[-1] * mask[s], observed_scatter_masked[s]))
        epoch_loss += loss.item()
        loss.backward()
        optimiser.step()
        # del simulated_data
        # torch.cuda.empty_cache()
    print(epoch_loss)

I found that the 1st batch runs fine: the call to scalar_born takes about 11002 MB of GPU memory. But at the 2nd batch, instead of releasing that 11002 MB (or reusing the same space), scalar_born takes another roughly 11002 MB. From the 3rd batch onwards, the memory usage does not grow any further.

I am so confused about why the 2nd batch takes additional GPU memory. Is this normal?
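For reference, a minimal sketch of how the per-batch allocation can be watched from inside the loop (torch.cuda.memory_allocated and torch.cuda.max_memory_allocated report PyTorch's own allocations, which can differ from what nvidia-smi shows because of the caching allocator; the report_gpu_memory helper here is just for illustration):

import torch

def report_gpu_memory(tag):
    # Bytes currently held by PyTorch tensors on the default CUDA device.
    allocated = torch.cuda.memory_allocated() / 2**20
    # Peak allocation since the start (or the last reset_peak_memory_stats()).
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: allocated {allocated:.0f} MiB, peak {peak:.0f} MiB")

# Called, for example, right after the scalar_born call and again after
# loss.backward() in each batch.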

Hello again,

PyTorch's caching allocator, the timing of when it decides to free unused memory, and the use of optimizers that accumulate state over iterations all make memory usage hard to predict. I see that you have already tried to force the cache to be emptied. You might try expanding that a bit to something like this to see if it helps (you will need to add import gc):

del loss, simulated_data
gc.collect()
torch.cuda.empty_cache()
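Dropped into your loop, that would look something like this (just a sketch of where the cleanup goes; the forward call and loss are unchanged from your code):

import gc

for batch in range(n_batch):
    optimiser.zero_grad()
    batch_start = batch * n_shots_per_batch
    batch_end = min(batch_start + n_shots_per_batch, n_shots)
    if batch_end <= batch_start:
        continue
    s = slice(batch_start, batch_end)
    simulated_data = scalar_born(v_mig.detach(), scatter, dx, dt,
                                 source_amplitudes=source_amplitudes[s].detach(),
                                 source_locations=source_locations[s].detach(),
                                 receiver_locations=receiver_locations[s].detach(),
                                 pml_freq=freq)
    loss = 1e9 * loss_fn(simulated_data[-1] * mask[s], observed_scatter_masked[s])
    epoch_loss += loss.item()
    loss.backward()
    optimiser.step()
    # Drop the last references to the output and the graph, collect the
    # freed Python objects, then return unused cached blocks to the driver.
    del loss, simulated_data
    gc.collect()
    torch.cuda.empty_cache()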

If that is not sufficient, then you will probably have to use smaller batch sizes. You can still perform the optimizer step with gradients from the same number of shots, if you wish, by accumulating the gradients over multiple batches before performing a step. You could do that with something like this (where I accumulate the gradients over two batches, each half the size of yours, before performing an optimizer step):

n_epochs = 1
n_batch = 120
n_shots_per_batch = (n_shots + n_batch - 1) // n_batch
for epoch in range(n_epochs):
    epoch_loss = 0
    for outer_batch in range(n_batch//2):
        optimiser.zero_grad()
        # Accumulate gradients from two half-size batches before each optimiser step.
        for inner_batch in range(2):
            batch = outer_batch * 2 + inner_batch
            print(batch)
            batch_start = batch * n_shots_per_batch
            batch_end = min(batch_start + n_shots_per_batch, n_shots)
            if batch_end <= batch_start:
                continue
            s = slice(batch_start, batch_end)

            simulated_data = scalar_born(v_mig.detach(), scatter, dx, dt,
                                         source_amplitudes=source_amplitudes[s].detach(),
                                         source_locations=source_locations[s].detach(),
                                         receiver_locations=receiver_locations[s].detach(),
                                         pml_freq=freq)
            loss = (1e9 * loss_fn(simulated_data[-1] * mask[s], observed_scatter_masked[s]))
            epoch_loss += loss.item()
            loss.backward()
        optimiser.step()
    print(epoch_loss)

If you are going to perform an optimizer step after each batch (whether with larger batches, as you were doing, or accumulating over multiple smaller batches as in my example above), then I suggest randomising which shots are in each batch between epochs.
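One way to do that, as a sketch (assuming every per-shot tensor, including mask and observed_scatter_masked, has the shot dimension first), is to draw a random permutation of shot indices at the start of each epoch and index the batches with it:

for epoch in range(n_epochs):
    epoch_loss = 0
    # A fresh random ordering of the shots each epoch, so each batch
    # contains a different mix of shots from one epoch to the next.
    shot_order = torch.randperm(n_shots)
    for batch in range(n_batch):
        optimiser.zero_grad()
        batch_start = batch * n_shots_per_batch
        batch_end = min(batch_start + n_shots_per_batch, n_shots)
        if batch_end <= batch_start:
            continue
        idx = shot_order[batch_start:batch_end]
        simulated_data = scalar_born(v_mig.detach(), scatter, dx, dt,
                                     source_amplitudes=source_amplitudes[idx].detach(),
                                     source_locations=source_locations[idx].detach(),
                                     receiver_locations=receiver_locations[idx].detach(),
                                     pml_freq=freq)
        loss = 1e9 * loss_fn(simulated_data[-1] * mask[idx], observed_scatter_masked[idx])
        epoch_loss += loss.item()
        loss.backward()
        optimiser.step()
    print(epoch_loss)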

commented

Thank you very much, Doctor. Using del loss, simulated_data followed by gc.collect() and torch.cuda.empty_cache() solved the problem.