lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch


Cuda memory for 3D VIT

JesseZZZZZ opened this issue · comments

[screenshot: CUDA out-of-memory error requesting ~356 GiB]
This 356 GiB figure is a little stunning. I don't think I changed the original code much, so does anyone know whether this is my mistake or whether the original really needs this much CUDA memory? Thanks a lot!

@JesseZZZZZ

try

import torch
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # height / width of each frame
    frames = 16,               # number of frames
    image_patch_size = 16,     # spatial patch size
    frame_patch_size = 2,      # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True      # fused attention; avoids materializing the full attention matrix
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

should help with memory, but you'll still face the compute cost
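For intuition on where the memory goes, here is a rough back-of-the-envelope sketch (my own, not from the thread): naive attention materializes a matrix that grows quadratically with the number of patch tokens, which is what flash attention avoids. The numbers below mirror the SimpleViT config above; the byte estimate assumes float32 and is only for the attention scores of a single layer.

```python
# Estimate the patch-token count and the size of one naive attention
# matrix for the 3D ViT config suggested above.
image_size = 128
image_patch_size = 16
frames = 16
frame_patch_size = 2
heads = 8
batch = 4

spatial_patches = (image_size // image_patch_size) ** 2   # 8 * 8 = 64 patches per frame
temporal_patches = frames // frame_patch_size             # 16 / 2 = 8 frame groups
num_tokens = spatial_patches * temporal_patches           # 512 tokens total

# Naive attention stores a (batch, heads, tokens, tokens) score matrix.
attn_elements = batch * heads * num_tokens ** 2
attn_bytes_fp32 = attn_elements * 4                       # 4 bytes per float32

print(num_tokens)                   # 512
print(attn_bytes_fp32 / 1024 ** 2)  # 32.0 MiB per attention layer
```

At these settings the score matrix itself is modest, but it scales with the fourth power of resolution (doubling height and width quadruples the tokens, so 16x the attention memory), and activations for backprop multiply the total further, which is how allocations can balloon into the hundreds of GiB. Flash attention keeps memory linear in the token count, but the quadratic compute cost remains.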


Thank you so much! It does fix my problem to some extent!