hzxie / Pix2Vox

The official implementation of "Pix2Vox: Context-aware 3D Reconstruction from Single and Multi-view Images" (Xie et al., ICCV 2019).

Home Page: https://haozhexie.com/project/pix2vox

Problem in results!!

chihabhedidi opened this issue

I trained the model on my GPU and wanted to test the results. I used the code that someone posted in #28, but it doesn't give me a usable result; I'm getting the output shown below. This is my code to test it:

# OrderedDict must come from collections; the typing alias cannot be instantiated
from collections import OrderedDict
from datetime import datetime
import os

import cv2
import numpy as np
import torch

from config import cfg
from models.encoder import Encoder
from models.decoder import Decoder
from models.refiner import Refiner
from models.merger import Merger
from utils import binvox_visualization, data_transforms
 
encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)
 
# load on CPU; map_location avoids needing the GPU the checkpoint was saved on
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
 
# checkpoints trained with torch.nn.DataParallel prefix every key with 'module.';
# strip that prefix so the weights load into the bare (CPU/single-GPU) models
def strip_module_prefix(state_dict):
    return OrderedDict((k.replace('module.', '', 1), v) for k, v in state_dict.items())

fix_checkpoint = {name: strip_module_prefix(checkpoint[name])
                  for name in ('encoder_state_dict', 'decoder_state_dict',
                               'refiner_state_dict', 'merger_state_dict')}
 
epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
 
if cfg.NETWORK.USE_REFINER:
    print('Use refiner')
    refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
if cfg.NETWORK.USE_MERGER:
    print('Use merger')
    merger.load_state_dict(fix_checkpoint['merger_state_dict'])
 
 
# eval() switches batch norm to running statistics and disables dropout;
# forgetting it is a common cause of garbage single-image predictions
encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()
 
img1_path = '/home/nz/Pictures/Screenshots/Screenshot from 2024-04-24 12-41-07.png'
# IMREAD_UNCHANGED keeps the alpha channel; RandomBackground below composites a
# background onto 4-channel images, matching the training pipeline
img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

# shape [n_views, H, W, C] -- a single view here
sample = np.array([img1_np])
 
IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
 
test_transforms = data_transforms.Compose([
    data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    data_transforms.ToTensor(),
])
 
rendering_images = test_transforms(rendering_images=sample)
# add the batch dimension: [1, n_views, C, H, W]
rendering_images = rendering_images.unsqueeze(0)
 
with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)
 
    # mirror the training-time gating: the merger and refiner are only enabled
    # after their warm-up epochs, so check epoch_idx the same way test.py does
    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)

    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)
 
generated_volume = generated_volume.squeeze(0)
 
img_dir = './image_outputs'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)
rendering_views = binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir, str(datetime.now())), epoch_idx)
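
As a debugging aid, here is a minimal sketch (my own addition, not from #28) that prints the occupancy statistics and saves a binarized copy of the volume; the 0.4 threshold is an assumption, picked from the range the repo evaluates IoU at. If max and mean are both near zero, the network is predicting an empty grid and the problem is upstream of the visualization:

# debugging sketch -- assumes generated_volume holds per-voxel occupancy
# probabilities in [0, 1] (shape [32, 32, 32] after squeeze)
gv_probs = generated_volume.cpu().numpy().squeeze()
print('occupancy: min=%.3f max=%.3f mean=%.3f' % (gv_probs.min(), gv_probs.max(), gv_probs.mean()))

binary_volume = gv_probs >= 0.4  # assumed threshold; the repo tests several values
np.save('generated_volume.npy', binary_volume)  # reload later with np.load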

[output voxel visualization: voxels-000230]

and this is the input image:

[input image: test]

I also tried the input image with both a white and a transparent background.

Try using the pretrained model (if you like Hugging Face, get it from here) and see if your input produces a good output. If it doesn't, chances are the eval script has an issue. If it does, chances are you aren't training the model well or long enough.
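
If you go the Hugging Face route, here is a minimal sketch of fetching a checkpoint with huggingface_hub; the repo id and filename below are placeholders, not the real identifiers (take those from the link above):

from huggingface_hub import hf_hub_download

# placeholders -- substitute the actual repo id and filename from the link above
weights_path = hf_hub_download(repo_id='<user>/<pix2vox-checkpoints>',
                               filename='Pix2Vox-A-ShapeNet.pth')
cfg.CONST.WEIGHTS = weights_path  # then run the test script above

hf_hub_download returns a local file path from the cache, so the rest of the script can load it with torch.load as before.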