Problem with results
chihabhedidi opened this issue · comments
Hedidi Chihab commented
I trained the model on my GPU and wanted to test the results. I used the code that someone posted in #28, but it doesn't give me the expected result; this is what I'm getting instead, and this is the code I use to test it:
```python
from collections import OrderedDict
from datetime import datetime
import os

import cv2
import numpy as np
import torch

from config import cfg
from models.encoder import Encoder
from models.decoder import Decoder
from models.refiner import Refiner
from models.merger import Merger
from utils import binvox_visualization, data_transforms

# Build the networks.
encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)
# The training script wraps the networks in torch.nn.DataParallel, so the saved
# keys carry a 'module.' prefix; strip it before loading.
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
fix_checkpoint = {}
for name in ['encoder_state_dict', 'decoder_state_dict', 'refiner_state_dict', 'merger_state_dict']:
    fix_checkpoint[name] = OrderedDict((k.replace('module.', '', 1), v) for k, v in checkpoint[name].items())

epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
if cfg.NETWORK.USE_REFINER:
    print('Use refiner')
    refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
if cfg.NETWORK.USE_MERGER:
    print('Use merger')
    merger.load_state_dict(fix_checkpoint['merger_state_dict'])

encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()

# Load the test image and scale it to [0, 1].
img1_path = '/home/nz/Pictures/Screenshots/Screenshot from 2024-04-24 12-41-07.png'
img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
sample = np.array([img1_np])

IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
test_transforms = data_transforms.Compose([
    data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    data_transforms.ToTensor(),
])
rendering_images = test_transforms(rendering_images=sample)
rendering_images = rendering_images.unsqueeze(0)  # add the batch dimension: (1, n_views, C, H, W)

with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)

    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)

    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)

generated_volume = generated_volume.squeeze(0)

# Render the predicted voxel grid to an image.
img_dir = './image_outputs'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)
rendering_views = binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir, str(datetime.now())), epoch_idx)
```
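To check whether the network predicts anything at all, it may help to print statistics of the raw volume before plotting. This is just a debugging sketch; it assumes `gv` is the `(32, 32, 32)` occupancy grid produced above and that the visualization code thresholds occupancy at 0.5:

```python
# Quick sanity check on the predicted occupancy grid (values should be in [0, 1]).
print('volume shape:', gv.shape)
print('min/max/mean:', gv.min(), gv.max(), gv.mean())
print('voxels >= 0.5:', int((gv >= 0.5).sum()))
# If no voxel reaches 0.5, the rendered view will look empty even though the
# forward pass ran without errors.
```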
I also tried the input image with both a white and a transparent background.
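For the white-background test, this is roughly how the transparent screenshot can be flattened onto white before going through the transforms. A minimal sketch, assuming the PNG has an alpha channel (i.e. `img1_np` has 4 channels):

```python
# Minimal sketch: composite an RGBA image onto a white background so the
# network sees a 3-channel image (assumes img1_np has shape (H, W, 4), values in [0, 1]).
if img1_np.shape[-1] == 4:
    alpha = img1_np[..., 3:4]                 # alpha in [0, 1]
    white = np.ones_like(img1_np[..., :3])    # white background
    img1_np = img1_np[..., :3] * alpha + white * (1. - alpha)
sample = np.array([img1_np])
```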