caiyuanhao1998 / MST-plus-plus

"MST++: Multi-stage Spectral-wise Transformer for Efficient Spectral Reconstruction" (CVPRW 2022) & (Winner of NTIRE 2022 Spectral Recovery Challenge) and a toolbox for spectral reconstruction

Home Page: https://arxiv.org/abs/2204.07908


How to get the prediction results

randomNNN opened this issue

Thanks for your open-source code. MST++ is an amazing project for HSI reconstruction. However, your repository only contains training and testing code, not prediction code. I am new to HSI reconstruction and have no idea how to obtain prediction results. Could you release your prediction code in the repository?

Hi, we have uploaded the prediction code and updated the README in our repo. You can reconstruct your own RGB image with the following commands:

(1) Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place it in /MST-plus-plus/predict_code/model_zoo/.

(2) Run one of the following commands to reconstruct your own RGB image:

cd /MST-plus-plus/predict_code/

# reconstruct by MST++
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/  --gpu_id 0

# reconstruct by MST-L
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mst --pretrained_model_path ./model_zoo/mst.pth --outf ./exp/mst/  --gpu_id 0

# reconstruct by MIRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mirnet --pretrained_model_path ./model_zoo/mirnet.pth --outf ./exp/mirnet/  --gpu_id 0

# reconstruct by HINet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hinet --pretrained_model_path ./model_zoo/hinet.pth --outf ./exp/hinet/  --gpu_id 0

# reconstruct by MPRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mprnet --pretrained_model_path ./model_zoo/mprnet.pth --outf ./exp/mprnet/  --gpu_id 0

# reconstruct by Restormer
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method restormer --pretrained_model_path ./model_zoo/restormer.pth --outf ./exp/restormer/  --gpu_id 0

# reconstruct by EDSR
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg --method edsr --pretrained_model_path ./model_zoo/edsr.pth --outf ./exp/edsr/  --gpu_id 0

# reconstruct by HDNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hdnet --pretrained_model_path ./model_zoo/hdnet.pth --outf ./exp/hdnet/  --gpu_id 0

# reconstruct by HRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hrnet --pretrained_model_path ./model_zoo/hrnet.pth --outf ./exp/hrnet/  --gpu_id 0

# reconstruct by HSCNN+
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hscnn_plus --pretrained_model_path ./model_zoo/hscnn_plus.pth --outf ./exp/hscnn_plus/  --gpu_id 0

Please replace './demo/ARAD_1K_0912.jpg' with your RGB image path. The reconstructed results will be saved in /MST-plus-plus/predict_code/exp/.
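If you want to quickly sanity-check the saved result in Python, here is a minimal sketch, assuming (as in the snippet later in this thread) that save_matv73 writes a MATLAB v7.3 file containing a variable named 'cube'; the file name below is just the demo image:

import h5py
import numpy as np

# Hypothetical quick check of the reconstructed cube saved by test.py
with h5py.File('./exp/mst_plus_plus/ARAD_1K_0912.mat', 'r') as mat:
    cube = np.array(mat['cube'])
print(cube.shape, cube.dtype)  # the 31-band axis may come first in v7.3 files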


Thanks for your reply!!!

Sorry to bother you again, but when I run your prediction code in a virtual environment with torch 1.7.1 and torchvision 0.8, I get the following error:
[screenshot of the error: 2022-05-13 234758]
Can you help me solve this error?

Hi, our method is trained and tested on Linux; we have not tested it on Windows.
It seems that there is something wrong with your CUDA. Have you installed CUDA correctly?
You can:

  1. Run our code in the Linux system.
  2. Or run the code on the CPU, which requires the following modifications:
    (1) Replace predict_code/architecture/__init__.py with:
import torch
from .edsr import EDSR
from .HDNet import HDNet
from .hinet import HINet
from .hrnet import SGN
from .HSCNN_Plus import HSCNN_Plus
from .MIRNet import MIRNet
from .MPRNet import MPRNet
from .MST import MST
from .MST_Plus_Plus import MST_Plus_Plus
from .Restormer import Restormer

def model_generator(method, pretrained_model_path=None):
    if method == 'mirnet':
        model = MIRNet(n_RRG=3, n_MSRB=1, height=3, width=1)
    elif method == 'mst_plus_plus':
        model = MST_Plus_Plus()
    elif method == 'mst':
        model = MST(dim=31, stage=2, num_blocks=[4, 7, 5])
    elif method == 'hinet':
        model = HINet(depth=4)
    elif method == 'mprnet':
        model = MPRNet(num_cab=4)
    elif method == 'restormer':
        model = Restormer()
    elif method == 'edsr':
        model = EDSR()
    elif method == 'hdnet':
        model = HDNet()
    elif method == 'hrnet':
        model = SGN()
    elif method == 'hscnn_plus':
        model = HSCNN_Plus()
    else:
        print(f'Method {method} is not defined !!!!')
    if pretrained_model_path is not None:
        print(f'load model from {pretrained_model_path}')
        # map_location loads the weights onto the CPU, so no GPU is required
        checkpoint = torch.load(pretrained_model_path, map_location=lambda storage, loc: storage)
        # strip the 'module.' prefix left over from DataParallel training
        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()},
                              strict=True)
    return model

    (2) Replace predict_code/test.py with the following code:

import torch
import argparse
import torch.backends.cudnn as cudnn
import os
from architecture import *
from utils import save_matv73
import cv2
import numpy as np
parser = argparse.ArgumentParser(description="SSR")
parser.add_argument('--method', type=str, default='mst_plus_plus')
parser.add_argument('--pretrained_model_path', type=str, default='./model_zoo/mst_plus_plus.pth')
parser.add_argument('--rgb_path', type=str, default='./demo/ARAD_1K_0912.jpg')
parser.add_argument('--outf', type=str, default='./exp/mst_plus_plus/')
parser.add_argument('--ensemble_mode', type=str, default='mean')
parser.add_argument("--gpu_id", type=str, default='0')
opt = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
if not os.path.exists(opt.outf):
    os.makedirs(opt.outf)

def main():
    cudnn.benchmark = True
    pretrained_model_path = opt.pretrained_model_path
    method = opt.method
    model = model_generator(method, pretrained_model_path)
    test(model, opt.rgb_path, opt.outf)

def test(model, rgb_path, save_path):
    var_name = 'cube'
    # read the RGB image, convert BGR -> RGB, and normalize to [0, 1]
    bgr = cv2.imread(rgb_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = np.float32(rgb)
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
    # HWC -> NCHW tensor (kept on the CPU)
    rgb = np.expand_dims(np.transpose(rgb, [2, 0, 1]), axis=0).copy()
    rgb = torch.from_numpy(rgb).float()
    print(f'Reconstructing {rgb_path}')
    with torch.no_grad():
        result = model(rgb)
    result = result.cpu().numpy() * 1.0
    result = np.transpose(np.squeeze(result), [1, 2, 0])
    result = np.minimum(result, 1.0)
    result = np.maximum(result, 0)

    mat_name = rgb_path.split('/')[-1][:-4] + '.mat'
    mat_dir = os.path.join(save_path, mat_name)
    save_matv73(mat_dir, var_name, result)
    print(f'The reconstructed hyperspectral image is saved as {mat_dir}.')

if __name__ == '__main__':
    main()
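With these two files replaced, the same commands as above should run on the CPU, for example:

python test.py --rgb_path ./demo/ARAD_1K_0912.jpg --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/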

Thanks for your patient reply. I got the .mat file on the CPU according to your guidance, but I don't know how to visualize each channel of the .mat file. Can you teach me how to visualize the HSI channels in Python?

Hi, for now we do not plan to open-source the visualization code. You can try to visualize the hyperspectral image:

  1. using the 'Hyperspectral Viewer' toolbox in MATLAB
    (1) Install Hyperspectral Viewer in MATLAB
    (2) Run the following code:
clear; clc;
file_path = "ARAD_1K_0912.mat";
pred = load(file_path).cube;
hyperspectralViewer(pred)


  2. or try the following Python code:
import h5py
import cv2
import numpy as np
path = "ARAD_1K_0912.mat"
with h5py.File(path, 'r') as mat:
    hyper = np.float32(np.array(mat['cube']))*255
cv2.imwrite('ARAD_1K_0912.png', hyper[15,:,:])
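If you want to save every band rather than a single one, here is a minimal sketch extending the snippet above, assuming the band axis is the first dimension (as in hyper[15,:,:]); the output file names are placeholders:

import h5py
import cv2
import numpy as np

path = "ARAD_1K_0912.mat"
with h5py.File(path, 'r') as mat:
    hyper = np.float32(np.array(mat['cube']))
# scale each band to 8-bit and save it as a separate grayscale PNG
hyper_8bit = np.clip(hyper * 255, 0, 255).astype(np.uint8)
for band in range(hyper_8bit.shape[0]):
    cv2.imwrite(f'ARAD_1K_0912_band{band:02d}.png', hyper_8bit[band, :, :])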

Hello author, I have got the .mat file. May I ask how I can extract the spectral information at a certain point?

What do you mean by extracting the spectral information of a certain point?


For example, in Figures 4 and 5 of the MST++ paper, there is a comparison of spectral curves in the lower left corner. How can I reproduce this comparison?

(i) Select a small spatial patch.

(ii) Compute the average spectral intensity of this region.

(iii) Compute the average correlation coefficient between the reconstructed HSI and the GT HSI.

How do I implement the second step?

Just compute the mean directly. (Feel free to write in Chinese.)
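For reference, a minimal sketch of steps (i)-(iii), assuming both the ground-truth and reconstructed cubes are stored as a 'cube' variable in MATLAB v7.3 .mat files with the band axis first; the file names and patch location below are placeholders:

import h5py
import numpy as np

def load_cube(path):
    # load the 'cube' variable from a MATLAB v7.3 file; band axis assumed first
    with h5py.File(path, 'r') as mat:
        return np.float32(np.array(mat['cube']))

pred = load_cube('pred_ARAD_1K_0912.mat')  # reconstructed HSI (placeholder name)
gt = load_cube('gt_ARAD_1K_0912.mat')      # ground-truth HSI (placeholder name)

# (i) select a small spatial patch, e.g. 20x20 pixels at (y0, x0)
y0, x0, size = 100, 100, 20
pred_patch = pred[:, y0:y0 + size, x0:x0 + size]
gt_patch = gt[:, y0:y0 + size, x0:x0 + size]

# (ii) average spectral intensity of the patch: one value per band
pred_spectrum = pred_patch.mean(axis=(1, 2))
gt_spectrum = gt_patch.mean(axis=(1, 2))

# (iii) correlation coefficient between the two spectral curves
corr = np.corrcoef(pred_spectrum, gt_spectrum)[0, 1]
print(corr)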

Haha, thanks! I'm just getting started with hyperspectral images, so I don't quite understand how to extract the spectral information from the .mat file. Does your code include this part?

Yes, it is in another repo:

https://github.com/caiyuanhao1998/MST

You can take a look at Section 4.3.

If you find our code helpful, please star, fork, and follow.

OK, thanks!!


Hi, when I run the code I hit an issue where pred_block_ is not recognized. Do you know how to solve it?