vgsatorras / few-shot-gnn

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

How can I train without the rotation trick?

Daniel123jia opened this issue · comments

def load_dataset(self, train, size, rotate=False):
    """Load an Omniglot split from its compacted pickle and preprocess it.

    Every image is resized with PIL and rescaled by ``x * 255 / 127.5 - 1``.
    It then gets a leading axis via ``np.expand_dims(..., axis=0)``, so the
    returned images have the layout the rest of the pipeline expects.
    Simply commenting out the rotation loop skips this preprocessing and
    returns the raw pickle contents. That leads to dimension-mismatch
    errors downstream.

    Args:
        train: If truthy, load ``omniglot_train.pickle``; otherwise load
            ``omniglot_test.pickle``. Both are read from
            ``<self.root>/compacted_datasets``.
        size: ``(height, width)`` target size. PIL's ``Image.resize`` takes
            ``(width, height)``, hence the swapped indices below.
        rotate: If True, apply the 4x rotation augmentation. Each class
            ``c`` becomes classes ``4 * c + r`` for ``r`` in ``0..3``, and
            every image is passed through ``self.rotate_image(image, r)``.
            If False (the default), class keys are kept unchanged and no
            rotation is applied.

    Returns:
        dict mapping class key -> list of preprocessed images. The dict is
        also passed to ``self.sanity_check`` before being returned.
    """
    print("Loading dataset")
    filename = 'omniglot_train.pickle' if train else 'omniglot_test.pickle'
    path = os.path.join(self.root, 'compacted_datasets', filename)
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # dataset files that come from a trusted source.
    with open(path, 'rb') as handle:
        data = pickle.load(handle)
    print("Num classes before rotations: " + str(len(data)))

    def _resize_and_normalize(raw):
        # Presumably ``raw`` is a 2-D array with values in [0, 1]. It is
        # scaled by 255 before the uint8 cast, so the result lands in
        # [-1, 1]. TODO confirm against the script that wrote the pickle.
        pil_img = pil_image.fromarray(np.uint8(raw * 255))
        resized = pil_img.resize((size[1], size[0]))
        return np.array(resized, dtype='float32') / 127.5 - 1

    processed = {}
    for class_ in data:
        # Resize each image once, even when several rotated copies are
        # derived from it.
        images = [_resize_and_normalize(raw) for raw in data[class_]]
        if not rotate:
            processed[class_] = [np.expand_dims(img, axis=0) for img in images]
            continue
        for rot in range(4):
            processed[class_ * 4 + rot] = [
                np.expand_dims(self.rotate_image(img, rot), axis=0)
                for img in images
            ]
    print("Dataset Loaded")
    print("Num classes after preprocessing: " + str(len(processed)))
    self.sanity_check(processed)
    return processed

I commented out the data augmentation (rotation) part directly, but now I get dimension mismatch errors.
I would be very grateful if you could help me solve it. Thank you!

I have already solved it now!