KeKsBoTer / torch-dwt

1D, 2D, and 3D Discrete Wavelet Transform (DWT) implemented in PyTorch

3D Discrete Wavelet Transform (DWT) in PyTorch

This package implements the 1D, 2D, and 3D Discrete Wavelet Transform (DWT) and its inverse (IDWT) in PyTorch. It was heavily inspired by pytorch_wavelets and extends its functionality to the third dimension.

The wavelets are provided by the PyWavelets package.

All operations in this package are fully differentiable.

Installation

git clone https://github.com/KeKsBoTer/torch-dwt
cd torch-dwt
pip install -e .

Example Usage

3D

from torch_dwt.functional import dwt3, idwt3
import torch

# 8 volumes with 3 channels and size of 100x100x100
x = torch.rand(8, 3, 100, 100, 100)
coefs = dwt3(x, "sym2")  # wavelet coefficients of x
# reconstruct signal from coefficients
y = idwt3(coefs, "sym2")

2D

from torch_dwt.functional import dwt2, idwt2
import torch

# 8 images with 3 color channels and size of 100x100
x = torch.rand(8, 3, 100, 100)
coefs = dwt2(x, "db2")  # wavelet coefficients of x
# reconstruct signal from coefficients
y = idwt2(coefs, "db2")
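
For intuition about what a single-level 2D transform produces, the following plain-Python sketch applies a Haar step along the width and then along the height of a small image, yielding four half-size subbands. It is an illustration only, not this package's implementation: the helper names (`haar_pairs`, `transpose`, `haar_dwt2`), the dict layout, and the subband labels are assumptions made for the example, and the package may order or pack its coefficients differently.

```python
import math

S = 1.0 / math.sqrt(2.0)  # Haar normalization factor


def haar_pairs(values):
    """Single-level Haar step on an even-length list: returns (low, high)."""
    low = [(values[i] + values[i + 1]) * S for i in range(0, len(values), 2)]
    high = [(values[i] - values[i + 1]) * S for i in range(0, len(values), 2)]
    return low, high


def transpose(matrix):
    """Swap rows and columns of a list-of-rows."""
    return [list(col) for col in zip(*matrix)]


def haar_dwt2(image):
    """Single-level 2D Haar DWT of a list-of-rows with even height and width.

    Returns four subbands, each half the input size in both dimensions.
    In the labels used here, the first letter is the filter applied along
    the width and the second the filter applied along the height.
    """
    # Step 1: filter along the width (within each row).
    row_low, row_high = [], []
    for row in image:
        low, high = haar_pairs(row)
        row_low.append(low)
        row_high.append(high)

    # Step 2: filter each intermediate result along the height.
    def filter_columns(half):
        lows, highs = [], []
        for col in transpose(half):
            low, high = haar_pairs(col)
            lows.append(low)
            highs.append(high)
        return transpose(lows), transpose(highs)

    ll, lh = filter_columns(row_low)
    hl, hh = filter_columns(row_high)
    return {"LL": ll, "LH": lh, "HL": hl, "HH": hh}


# A 2x2 image gives four 1x1 subbands.
bands = haar_dwt2([[1.0, 2.0], [3.0, 4.0]])
```

The 3D case applies the same step along one more axis, which gives eight subbands instead of four.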

1D

from torch_dwt.functional import dwt, idwt
import torch

# batch of size 8 with 3 channels and signal length 100
x = torch.rand(8, 3, 100)
coefs = dwt(x, "haar")  # wavelet coefficients of x
# reconstruct signal from coefficients
y = idwt(coefs, "haar")
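
As a rough illustration of what a single-level 1D transform and its inverse compute, here is a plain-Python sketch for the Haar wavelet. The functions `haar_dwt` and `haar_idwt` are hypothetical names for this example and are not part of torch-dwt. The sketch ignores batching, channels, and boundary handling, and the sign convention for the detail coefficients can differ between libraries.

```python
import math


def haar_dwt(signal):
    """Single-level Haar DWT of an even-length sequence.

    Returns (approx, detail), each half the length of the input.
    """
    if len(signal) % 2 != 0:
        raise ValueError("this sketch only handles even-length signals")
    s = 1.0 / math.sqrt(2.0)
    # Scaled sums of neighbouring samples form the approximation,
    # scaled differences form the detail coefficients.
    approx = [(signal[i] + signal[i + 1]) * s for i in range(0, len(signal), 2)]
    detail = [(signal[i] - signal[i + 1]) * s for i in range(0, len(signal), 2)]
    return approx, detail


def haar_idwt(approx, detail):
    """Inverse of haar_dwt: rebuild each sample pair and interleave them."""
    s = 1.0 / math.sqrt(2.0)
    signal = []
    for a, d in zip(approx, detail):
        signal.append((a + d) * s)
        signal.append((a - d) * s)
    return signal


x = [1.0, 2.0, 3.0, 4.0]
approx, detail = haar_dwt(x)
y = haar_idwt(approx, detail)  # matches x up to floating-point rounding
```

Because the forward and inverse steps use only additions and multiplications by constants, the same arithmetic written with tensor operations stays differentiable.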

Testing

For testing, we compare our implementation against PyWavelets. The following command runs the tests:

# navigate into torch-dwt directory
pytest .
