facebookresearch / theseus

A library for differentiable nonlinear optimization

Possible memory leak: Theseus doesn't release GPU memory after finishing

EXing opened this issue.

❓ Questions and Help

from typing import Optional

import omegaconf
import theseus as th

# Map and Problem are defined in the reporter's own code (not shown).


def bundle_adjustment(global_map: Map, fix_pose: bool, cfg: omegaconf.DictConfig,
                      start_timestamp: Optional[float] = None,
                      window_size: Optional[int] = None):
    """
    Args:
        global_map: global map
        fix_pose: fix pose or not during BA
        cfg: configuration (reads cfg.inner_optim.verbose)
        start_timestamp: start point
        window_size: window size for local bundle adjustment
    """
    problem = Problem(global_map.k, global_map.p_unit_sphere_or_plane, fix_pose, global_map.depth_parameterization)
    if window_size is None:
        window_size = 10_000
    if start_timestamp is None:
        # dict views are not subscriptable in Python 3, so take the first key directly
        start_timestamp = next(iter(global_map.keyframes))
    global_map.bfs_edge(start_timestamp, window_size, problem.add_node_cost_function, problem.add_edge_cost_function)

    optimizer = th.LevenbergMarquardt(
        problem.objective,
        linear_solver_cls=th.LUDenseSolver,
        linearization_cls=th.DenseLinearization,
    )
    theseus_layer = th.TheseusLayer(optimizer)
    with global_map.keyframes_mutex:
        theseus_outputs, info = theseus_layer.forward(optimizer_kwargs={
            "verbose": cfg.inner_optim.verbose,
            "adaptive_damping": True})

Before bundle_adjustment(global_map, True, cfg, timestamp, 2):

Sat Sep 16 13:54:46 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 ...    On  | 00000000:01:00.0 Off |                  N/A |
| 30%   34C    P2              49W / 250W |    406MiB /  8192MiB |      5%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1391      G   /usr/lib/xorg/Xorg                           45MiB |
|    0   N/A  N/A      2768      G   /usr/lib/xorg/Xorg                          109MiB |
|    0   N/A  N/A      3132      G   /usr/bin/gnome-shell                         31MiB |
|    0   N/A  N/A   2774370      G   /usr/lib/rustdesk/rustdesk                   14MiB |
|    0   N/A  N/A   2777179    C+G   ...kun/.virtualenvs/theseus/bin/python      191MiB |
+---------------------------------------------------------------------------------------+

After bundle_adjustment(global_map, True, cfg, timestamp, 2):

Sat Sep 16 13:55:18 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 ...    On  | 00000000:01:00.0 Off |                  N/A |
| 30%   35C    P8              24W / 250W |   2966MiB /  8192MiB |     10%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1391      G   /usr/lib/xorg/Xorg                           45MiB |
|    0   N/A  N/A      2768      G   /usr/lib/xorg/Xorg                          121MiB |
|    0   N/A  N/A      3132      G   /usr/bin/gnome-shell                         33MiB |
|    0   N/A  N/A   2774370      G   /usr/lib/rustdesk/rustdesk                   14MiB |
|    0   N/A  N/A   2777179    C+G   ...kun/.virtualenvs/theseus/bin/python     2737MiB |
+---------------------------------------------------------------------------------------+

Note: the variables are created and stored elsewhere, and they outlive the Theseus layer.

After bundle_adjustment(global_map, True, cfg, timestamp, 2):

Sat Sep 16 16:59:18 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 ...    On  | 00000000:01:00.0 Off |                  N/A |
| 30%   40C    P2              52W / 250W |    882MiB /  8192MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1391      G   /usr/lib/xorg/Xorg                           45MiB |
|    0   N/A  N/A      2768      G   /usr/lib/xorg/Xorg                          144MiB |
|    0   N/A  N/A      3132      G   /usr/bin/gnome-shell                         49MiB |
|    0   N/A  N/A   2850951      G   /usr/lib/rustdesk/rustdesk                   14MiB |
|    0   N/A  N/A   2851862      G   /usr/lib/firefox/firefox                     10MiB |
|    0   N/A  N/A   3489809    C+G   ...kun/.virtualenvs/theseus/bin/python      603MiB |
+---------------------------------------------------------------------------------------+

with the following code:

with torch.no_grad():
    with global_map.keyframes_mutex:
        print("Start theseus iteration:")
        theseus_outputs, info = theseus_layer.forward(optimizer_kwargs={
            "verbose": cfg.inner_optim.verbose,
            "adaptive_damping": True})
        print("End theseus iteration.")
torch.cuda.empty_cache()


def print_gpu_memory_usage(device=None):
    if device is None:
        device = torch.cuda.current_device()
    allocated_memory = torch.cuda.memory_allocated(device)
    # memory_cached() is a deprecated alias of memory_reserved()
    cached_memory = torch.cuda.memory_reserved(device)
    print(f"GPU Memory Allocated: {allocated_memory / (1024 ** 2):.2f} MB")
    print(f"GPU Memory Cached: {cached_memory / (1024 ** 2):.2f} MB")

There is still a small, unexplained memory leak:

GPU Memory Allocated: 15.21 MB
GPU Memory Cached: 42.00 MB
Start theseus iteration:
Nonlinear optimizer. Iteration: 0. Error: 1787063296.0
Nonlinear optimizer. Iteration: 1. Error: 226572944.0
Nonlinear optimizer. Iteration: 2. Error: 98107632.0
Nonlinear optimizer. Iteration: 3. Error: 37768908.0
Nonlinear optimizer. Iteration: 4. Error: 96250.1640625
Nonlinear optimizer. Iteration: 5. Error: 96250.1640625
End theseus iteration.
GPU Memory Allocated: 29.39 MB
GPU Memory Cached: 1510.00 MB

Hi @EXing. I see you are using LUDenseSolver and DenseLinearization. With these options, all memory management is handled by torch (this path doesn't use any custom C++ code). Are you hitting out-of-memory errors? IIRC, torch's caching allocator may keep CUDA memory reserved for a while after the tensors using it are freed, so that it can be reused, and we don't release any tensors manually. We also don't allocate any tensors in global scope, but I can double-check this, just in case.
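The reply hinges on the difference between memory held by live tensors (what `torch.cuda.memory_allocated()` reports) and memory the caching allocator keeps reserved for reuse (what `torch.cuda.memory_reserved()` reports, and roughly what `nvidia-smi` shows for the process together with the CUDA context). The following is a deliberately simplified toy model in plain Python, not PyTorch's actual allocator; the class and its methods are hypothetical. It shows why freeing tensors lowers "allocated" but leaves "reserved" unchanged until the cache is emptied:

```python
class ToyCachingAllocator:
    """Toy model of a caching allocator (hypothetical, NOT PyTorch's implementation).

    Freed blocks go into a per-size cache instead of being returned to the
    device, so `allocated` drops while `reserved` stays put until empty_cache().
    """

    def __init__(self):
        self.allocated = 0       # bytes currently held by live "tensors"
        self.reserved = 0        # bytes obtained from the "device"
        self._free_blocks = {}   # block size -> number of cached free blocks

    def malloc(self, size):
        # Reuse a cached block of the same size if one is available.
        if self._free_blocks.get(size, 0) > 0:
            self._free_blocks[size] -= 1
        else:
            self.reserved += size  # otherwise request new memory from the "device"
        self.allocated += size
        return size                # in this toy, a block handle is just its size

    def free(self, size):
        # Keep the block cached for later reuse instead of returning it.
        self.allocated -= size
        self._free_blocks[size] = self._free_blocks.get(size, 0) + 1

    def empty_cache(self):
        # Return every cached (unused) block to the "device".
        for size, count in self._free_blocks.items():
            self.reserved -= size * count
        self._free_blocks.clear()


alloc = ToyCachingAllocator()
blocks = [alloc.malloc(1 << 20) for _ in range(3)]  # three 1 MiB "tensors"
for b in blocks:
    alloc.free(b)
print(alloc.allocated, alloc.reserved)  # 0 3145728
alloc.empty_cache()
print(alloc.allocated, alloc.reserved)  # 0 0
```

In this model a growing "reserved" number alone is not a leak; only "allocated" staying high after all references are gone would be.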

As shown above, the allocated memory increased after the optimization, even under torch.no_grad().

It makes sense that memory increases, because both cost function vectorization and vmap allocate new tensors that batch everything together. My question is whether you are seeing that torch never releases this memory automatically, even when it is needed later.

After each BA, none of the Theseus tensors are used anymore, except for the variables stored outside bundle_adjustment().
So, in theory, the allocated memory should not increase.

Hi @EXing. I took a look today and confirmed that there is some sort of leak when using cost function vectorization, though I haven't been able to find its cause yet. When passing vectorize=False to the TheseusLayer constructor, there doesn't seem to be a leak; however, I have to manually call gc.collect() and torch.cuda.empty_cache() for the memory stats to report 0 again.
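A possible reason why gc.collect() is needed before the numbers drop: CPython frees most objects as soon as their reference count reaches zero, but objects caught in a reference cycle are only reclaimed by the cyclic garbage collector, which runs periodically. If an object holding a CUDA tensor ends up in such a cycle (an assumption for illustration, not something confirmed in this thread), that tensor stays allocated until the collector runs. A stdlib-only sketch of this behavior, with a hypothetical `Node` class:

```python
import gc
import weakref


class Node:
    """Hypothetical object that ends up in a reference cycle."""

    def __init__(self):
        self.other = None


gc.disable()  # keep the automatic collector from running mid-demo

a, b = Node(), Node()
a.other, b.other = b, a     # a <-> b form a reference cycle
probe = weakref.ref(a)      # observe `a` without keeping it alive
del a, b                    # no external references remain
print(probe() is not None)  # True: reference counting alone cannot free a cycle

gc.collect()                # the cyclic collector reclaims both objects
print(probe() is None)      # True

gc.enable()
```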

I looked a bit more today and couldn't find a reason for the apparent memory leak on the Theseus side. The following script illustrates what I'm talking about.

import gc
import theseus as th
import torch


def print_mem(msg):
    m1 = torch.cuda.memory_allocated() / 1024
    m2 = torch.cuda.memory_reserved() / 1024
    print(msg, m1, m2)


def test_mem():
    vs = []
    nvars = 1000
    dim = 1000
    for i in range(nvars):
        vs.append(th.Vector(dim, name=f"v{i}"))

    cfs = []
    for i in range(nvars // 2):
        cfs.append(
            th.Difference(
                vs[i], vs[i + nvars // 2], th.ScaleCostWeight(1.0), name=f"cf{i}"
            )
        )

    # Mimic what vectorization does
    x = th.Vector(dim, name="x")
    y = th.Vector(dim, name="y")
    full = th.Difference(x, y, th.ScaleCostWeight(1.0), name="full")
    x_list = []
    y_list = []
    for cf in cfs:
        x_list.append(cf.var.tensor)
        y_list.append(cf.target.tensor)
    x.update(torch.cat(x_list, dim=0).cuda())
    y.update(torch.cat(y_list, dim=0).cuda())
    full.weight.scale.to("cuda:0")
    j, e = full.jacobians()
    print_mem("inside function")


print_mem("before theseus")
with torch.no_grad():
    test_mem()
print_mem("after theseus")
gc.collect()
torch.cuda.empty_cache()
print_mem("after clear cache")

The output (allocated and reserved memory, in KiB) is:

before theseus 0.0 0.0
inside function 1967973.0 5883904.0
after theseus 8320.0 5883904.0
after clear cache 8320.0 20480.0

As you can see, manually running the garbage collector and emptying the CUDA cache releases most of the reserved memory (from about 5.6 GiB down to 20 MiB). However, a few MB of allocated memory (8320 KiB) remain that I'm not able to clear. They seem somehow related to the matrix product in the Jacobian of our local() implementation, but AFAIK this is some obscure torch behavior rather than something we are doing on our side.
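For scale, here is a back-of-envelope estimate of the "inside function" number from the script. It assumes float32 tensors, a batch of nvars // 2 = 500 cost functions with dim = 1000, and that `full.jacobians()` returns a dense (batch, dim, dim) Jacobian plus a (batch, dim) error; these are assumptions, not measurements:

```python
# Assumptions: float32, batch = nvars // 2, dense (batch, dim, dim) Jacobian.
BYTES_PER_FLOAT32 = 4
KIB = 1024

nvars, dim = 1000, 1000
batch = nvars // 2

jacobian_kib = batch * dim * dim * BYTES_PER_FLOAT32 / KIB  # the Jacobian in j
vector_kib = batch * dim * BYTES_PER_FLOAT32 / KIB          # x, y, or the error e

estimate_kib = jacobian_kib + 3 * vector_kib                # j + x + y + e
measured_kib = 1967973.0                                    # "inside function" allocated value

print(jacobian_kib)                           # 1953125.0
print(estimate_kib)                           # 1958984.375
print(round(estimate_kib / measured_kib, 3))  # 0.995
```

Under these assumptions the dense Jacobian alone accounts for nearly all of the peak, and since it scales with dim squared, the peak grows quadratically with the variable dimension even though the batched vectors grow only linearly.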

Unless you are having other issues with memory, @EXing, I'm planning to close this issue soon.