alexk101 / torch_migraphx

Libraries integrating migraphx with pytorch

Repository from GitHub: https://github.com/alexk101/torch_migraphx

Torch-MIGraphX

Torch-MIGraphX integrates AMD's graph inference engine, MIGraphX, with the PyTorch ecosystem. It provides an mgx_module object that can be invoked like any other torch module but uses the MIGraphX inference engine internally.

This library currently supports two paths for lowering:

  1. FX Tracing: Uses the tracing API provided by the torch.fx library.
  2. Dynamo Backend: Importing torch_migraphx automatically registers the "migraphx" backend, which can be used with the torch.compile API.
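The Dynamo path relies on registration happening as a side effect of importing the package. The sketch below illustrates that general pattern with a self-contained registry; `register_backend`, `lookup_backend`, and `_BACKENDS` are hypothetical names, not torch_migraphx or PyTorch internals. A torch.compile backend is a callable that receives a traced graph plus example inputs and returns a callable to run in its place; the stand-in here returns the graph unchanged instead of lowering it to MIGraphX.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry mapping backend names to compiler functions.
# Illustrative only; this is not PyTorch's or torch_migraphx's code.
_BACKENDS: Dict[str, Callable[..., Callable]] = {}


def register_backend(name: str) -> Callable[[Callable], Callable]:
    """Return a decorator that records `fn` under `name` when it is defined."""
    def decorator(fn: Callable) -> Callable:
        _BACKENDS[name] = fn
        return fn
    return decorator


def lookup_backend(name: str) -> Callable:
    """Resolve a backend name, the way a backend="..." string is looked up."""
    try:
        return _BACKENDS[name]
    except KeyError:
        raise ValueError(
            f"unknown backend {name!r}; registered: {sorted(_BACKENDS)}"
        ) from None


# In a real plugin this decorator runs when the package is imported,
# which is why importing it is enough to make the name available.
@register_backend("migraphx")
def migraphx_backend(graph: Callable, example_inputs: List[Any]) -> Callable:
    # A real backend would lower `graph` to an inference engine here;
    # this stand-in returns the original callable unchanged.
    return graph


if __name__ == "__main__":
    compiled = lookup_backend("migraphx")(abs, [-3])
    print(compiled(-3))  # prints 3
```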

Getting Started

Docker

The simplest and recommended way to get started is using the provided Dockerfile. Build using:

./build_image.sh

Start the container using:

sudo docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined torch_migraphx
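The flags in the `docker run` command above are the ones commonly used to expose an AMD GPU to a ROCm container. The sketch below rebuilds the same command as an argument list with a comment on each flag; the helper name `docker_run_args` is hypothetical, and the comments describe the general purpose of each Docker option rather than anything specific to this image.

```python
import shlex
from typing import List


def docker_run_args(image: str = "torch_migraphx") -> List[str]:
    """Argument list equivalent to the `docker run` command above (without sudo)."""
    return [
        "docker", "run",
        "-it",                                   # interactive session with a TTY
        "--network=host",                        # share the host's network stack
        "--device=/dev/kfd",                     # ROCm compute interface (Kernel Fusion Driver)
        "--device=/dev/dri",                     # GPU render/display nodes
        "--group-add=video",                     # group that typically owns the GPU device files
        "--ipc=host",                            # share host IPC, avoiding small /dev/shm limits
        "--cap-add=SYS_PTRACE",                  # allow ptrace, e.g. for debuggers
        "--security-opt", "seccomp=unconfined",  # disable the default seccomp profile
        image,
    ]


if __name__ == "__main__":
    # Prints the same command as above; drop "sudo" if your user is in the docker group.
    print("sudo " + shlex.join(docker_run_args()))
```

Passing the list directly, e.g. `subprocess.run(["sudo", *docker_run_args()], check=True)`, avoids shell quoting issues.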

The default Dockerfile builds on the nightly PyTorch container and installs the latest source versions of MIGraphX and torch_migraphx. For other build configurations, refer to the docker directory.

Install From Source

Install Pre-reqs:

Build and install from source

If you wish to set a custom path for your CMake install:

export CMAKE_BIN=/path/to/my/cmake

If you wish to set a custom number of processes for building (otherwise, all available cores will be used):

export CMAKE_BIN=/path/to/my/cmake
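Build scripts typically resolve settings like these from the environment with a fallback. The sketch below shows that pattern under stated assumptions: `CMAKE_BIN` comes from the text above, while `resolve_cmake` and `resolve_jobs` are hypothetical helpers. Since this sketch does not assume a name for the process-count variable, `resolve_jobs` takes its raw string value as an argument. torch_migraphx's own setup script may handle these settings differently.

```python
import os
import shutil
from typing import Mapping, Optional


def resolve_cmake(env: Mapping[str, str]) -> str:
    """Prefer CMAKE_BIN when set; otherwise fall back to `cmake` on PATH."""
    custom = env.get("CMAKE_BIN")
    if custom:
        return custom
    return shutil.which("cmake") or "cmake"


def resolve_jobs(raw: Optional[str]) -> int:
    """Parse a requested process count, defaulting to all available CPUs."""
    if not raw:
        return os.cpu_count() or 1  # os.cpu_count() may return None
    jobs = int(raw)
    if jobs < 1:
        raise ValueError(f"process count must be >= 1, got {jobs}")
    return jobs


if __name__ == "__main__":
    print(resolve_cmake(os.environ), resolve_jobs(None))
```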

Then, to install:

git clone https://github.com/ROCmSoftwarePlatform/torch_migraphx.git
cd ./torch_migraphx/py
export TORCH_CMAKE_PATH=$(python -c "import torch; print(torch.utils.cmake_prefix_path)")
pip install .

Example Usage

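import torch
import torch_migraphx

# FX Tracing: returns a module that runs through MIGraphX
mgx_model = torch_migraphx.fx.lower_to_mgx(torch_model, sample_inputs)

# Dynamo Backend: "migraphx" is registered when torch_migraphx is imported
compiled_model = torch.compile(torch_model, backend="migraphx")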

Lower resnet50 using FX Tracing

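import torch
import torchvision
import torch_migraphx

# Put the model in inference mode (e.g. BatchNorm uses running statistics)
resnet = torchvision.models.resnet50().eval()
sample_input = torch.randn(2, 3, 64, 64)

# Trace with torch.fx and lower to a MIGraphX-backed module
resnet_mgx = torch_migraphx.fx.lower_to_mgx(resnet, [sample_input])

# Call it like any other torch module
result = resnet_mgx(sample_input)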

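Lower densenet161 using torch.compile

import torch
import torchvision
import torch_migraphx  # registers the "migraphx" backend

# On ROCm builds of PyTorch, .cuda() places tensors on the AMD GPU
densenet = torchvision.models.densenet161().eval().cuda()
sample_input = torch.randn(2, 3, 512, 512).cuda()

# torch.compile is lazy: compilation happens on the first call
densenet_mgx = torch.compile(densenet, backend="migraphx")
result = densenet_mgx(sample_input)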
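Because torch.compile compiles lazily, the first call to the compiled model includes compilation time. A minimal standard-library timing sketch, assuming you want to report that first call separately from steady-state latency, is shown below; `benchmark` is a hypothetical helper, not part of torch_migraphx.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(fn: Callable[..., Any], *args: Any, warmup: int = 1, iters: int = 10) -> Dict[str, float]:
    """Time the first call on its own, then the median of `iters` later calls.

    `warmup` extra calls run after the first call and are not timed.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    start = time.perf_counter()
    fn(*args)  # for a torch.compile'd model, this call includes compilation
    first = time.perf_counter() - start
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return {"first_call_s": first, "median_s": statistics.median(samples)}


if __name__ == "__main__":
    print(benchmark(sum, range(100_000), iters=5))
```

GPU kernel launches are asynchronous, so for a GPU model the timed function should synchronize before returning, for example `benchmark(lambda x: (densenet_mgx(x), torch.cuda.synchronize()), sample_input)`.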

For more examples please refer to the examples directory.

About


License: BSD 3-Clause "New" or "Revised" License


Languages

Python 96.6%, C++ 1.9%, Dockerfile 0.9%, Shell 0.5%, CMake 0.1%