CPU memory leak when training any model
pujaltes opened this issue
Describe the bug
I find that CPU memory grows linearly during training until it results in an OOM error when training on Data Center GPU Max 1550 GPUs. This is likely the same bug as #476 and #462.
The minimal reproducible example below causes a steady CPU memory increase of roughly 10 MB/s on XPUs, while CPU memory stays completely stable on Ampere GPUs. I'm using oneDNN from oneAPI 2024.0.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import psutil


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False)
        return self.dropout(x)


class SyntheticDataset(Dataset):
    def __init__(self, num_sequences, seq_length, ntoken):
        self.num_sequences = num_sequences
        self.seq_length = seq_length
        self.sequences = torch.randint(0, ntoken, size=(num_sequences, seq_length), dtype=torch.long)

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return sequence
if __name__ == "__main__":
    device = "xpu"
    if device == "xpu":
        import intel_extension_for_pytorch as ipex

    # Define model parameters
    ntoken = 1000  # size of vocabulary
    ninp = 512  # embedding dimension
    nhead = 32  # number of heads in multi-head attention
    nhid = 2048  # hidden dimension of feedforward network
    nlayers = 60  # number of TransformerEncoderLayers in TransformerEncoder
    dropout = 0.2  # dropout probability

    # Create synthetic dataset and dataloader
    num_sequences = 2000
    seq_length = 600
    dataset = SyntheticDataset(num_sequences, seq_length, ntoken)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Initialize model
    model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout).to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of Model Parameters: {num_params:,}")

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    if device == "xpu":
        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    # Training loop
    for epoch in range(500):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output.view(-1, ntoken), batch.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # print(f"Batch Loss: {loss.item()}")
        # Print memory usage once per epoch
        print(f"Epoch {epoch + 1}, Memory Usage: {psutil.Process().memory_info().rss / 1024 ** 3:.2f} GB")
        # print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")
Versions
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.0.120+xpu
[pip3] numpy==1.24.3
[pip3] torch==2.0.1a0+cxx11.abi
[pip3] torchvision==0.15.2a0+cxx11.abi
[conda] intel-extension-for-pytorch 2.0.120 py39_xpu_0 intel
[conda] mkl 2024.0.0 intel_49630 intel
[conda] mkl-dpcpp 2024.0.0 intel_49630 intel
[conda] mkl-service 2.4.0 py39h3539a15_40 intel
[conda] mkl_fft 1.3.6 py39h1d81ff8_61 intel
[conda] mkl_random 1.2.2 py39h5a378b4_81 intel
[conda] mkl_umath 0.1.1 py39h2b1685c_91 intel
[conda] numpy 1.24.3 py39ha320b8e_5 intel
[conda] numpy-base 1.24.3 py39hbac2b65_5 intel
[conda] onemkl-sycl-blas 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-datafitting 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-dft 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-lapack 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-rng 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-sparse 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-stats 2024.0.0 intel_49630 intel
[conda] onemkl-sycl-vm 2024.0.0 intel_49630 intel
[conda] pytorch 2.0.1 py39_xpu_1 intel
[conda] torchvision 0.15.2 py39_xpu_0 intel
Thanks for reporting; we are working on this issue and will update here when it is fixed.
Hey @huiyan2021,
Thank you for your prompt response! The issue disappears when using the latest version of IPEX with oneAPI 2024.1 (results included below).
Do we know where this bug originated from? Does this mean oneAPI 2024.0 (and by extension IPEX < 2.1.2) is deprecated? Given the severity of the issue, it would be a good idea to explicitly warn users about known issues in previous versions.
Output with oneAPI 2024.1:
Number of Model Parameters: 190,168,040
Epoch 1, Memory Usage: 1.44 GB
Epoch 2, Memory Usage: 1.89 GB
Epoch 3, Memory Usage: 1.89 GB
Epoch 4, Memory Usage: 1.89 GB
Epoch 5, Memory Usage: 1.89 GB
Epoch 6, Memory Usage: 1.89 GB
Epoch 7, Memory Usage: 1.89 GB
Epoch 8, Memory Usage: 1.89 GB
Epoch 9, Memory Usage: 1.89 GB
Epoch 10, Memory Usage: 1.89 GB
Output with oneAPI 2024.0:
Number of Model Parameters: 190,168,040
Epoch 1, Memory Usage: 1.68 GB
Epoch 2, Memory Usage: 2.48 GB
Epoch 3, Memory Usage: 3.00 GB
Epoch 4, Memory Usage: 3.79 GB
Epoch 5, Memory Usage: 4.58 GB
Epoch 6, Memory Usage: 5.37 GB
Epoch 7, Memory Usage: 6.16 GB
Epoch 8, Memory Usage: 6.95 GB
Epoch 9, Memory Usage: 7.74 GB
Epoch 10, Memory Usage: 8.53 GB
Hi @pujaltes,
Do we know where this bug originated from?
FYI: 89b1a92#diff-bbb9194721e5e2097b5a098c6b3a135790b107c6333e45d8e08f40ed4e97553bR279
Does this mean oneAPI 2024.0 (and by extension IPEX < 2.1.2) is deprecated? Given the severity of the issue, it would be a good idea to explicitly warn users about known issues in previous versions.
Good suggestion, but it is difficult to maintain all the old versions; we encourage users to use the latest version for bug fixes and better performance.
@huiyan2021, thank you for your help and for pointing me to the bug!
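For anyone landing on this thread later, a quick way to confirm which stack is in use before re-testing (a minimal sketch; ipex.__version__ is a plain version string, though its exact format may differ across releases):

import torch
import intel_extension_for_pytorch as ipex

# The leak reproduces with IPEX 2.0.120+xpu built against oneAPI 2024.0 and is
# gone with the latest IPEX built against oneAPI 2024.1, per the discussion above.
print("torch:", torch.__version__)
print("intel-extension-for-pytorch:", ipex.__version__)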