junyongyou / triq

TRIQ implementation

Same output for every input image

sulakshgupta988 opened this issue

```python
import tensorflow as tf
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Input,
                                     MaxPooling2D, Rescaling, ZeroPadding2D)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import InverseTimeDecay

# TriQImageQualityTransformer is defined in the TRIQ repo; import it from your checkout.
# Rescaling is importable from tensorflow.keras.layers in TF >= 2.6
# (older versions: tensorflow.keras.layers.experimental.preprocessing).


def create_triq_model(n_quality_levels,
                      input_shape=(None, None, 3),
                      backbone='resnet50',  # NOTE: ignored; the custom CNN below replaces the backbone
                      transformer_params=(2, 32, 8, 64),
                      maximum_position_encoding=193,
                      vis=False):
    chanDim = -1  # channels-last
    # define the model input and scale pixel values by 1/255
    inputs = Input(shape=input_shape)
    x = Rescaling(1./255)(inputs)

    # three CONV => RELU => BN => POOL blocks
    for f in (32, 64, 128):
        x = Conv2D(f, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(256, (3, 3), padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(2048, (3, 3), padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    # for a 564x504 input the feature map is now 18x16x2048

    dropout_rate = 0.1
    transformer = TriQImageQualityTransformer(
        num_layers=transformer_params[0],
        d_model=transformer_params[1],
        num_heads=transformer_params[2],
        mlp_dim=transformer_params[3],
        dropout=dropout_rate,
        n_quality_levels=n_quality_levels,
        maximum_position_encoding=maximum_position_encoding,
        vis=vis
    )
    outputs = transformer(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model


# use only the first GPU, if there is one
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_visible_devices(gpus[0], 'GPU')

input_shape = (564, 504, 3)
# model = create_triq_model(n_quality_levels=5, input_shape=input_shape, backbone='vgg16')
model = create_triq_model(n_quality_levels=1, input_shape=input_shape, backbone='resnet50')

# Same schedule as the legacy Adam(decay=1e-3 / 200) argument, which recent TF releases
# reject: lr_t = 0.001 / (1 + (1e-3 / 200) * step)
lr_schedule = InverseTimeDecay(initial_learning_rate=0.001, decay_steps=1,
                               decay_rate=1e-3 / 200)
opt = Adam(learning_rate=lr_schedule)
model.compile(loss="mean_squared_error", optimizer=opt)
# trainImagesX, trainY, valImagesX, valY are the training/validation arrays (not shown)
model.fit(trainImagesX, trainY, validation_data=(valImagesX, valY),
          epochs=108, batch_size=16)
```

In the code above, I modified the create_triq_model function so that it uses a custom CNN instead of ResNet or VGGNet as the backbone. The custom CNN's output shape is (18, 16, 2048), and this output is fed to TriQImageQualityTransformer.
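As a sanity check on the (18, 16, 2048) shape, the spatial size can be traced by hand: `Conv2D(padding="same")` keeps height and width, each `MaxPooling2D(pool_size=(2, 2))` with the default "valid" padding halves them (rounding down), and `ZeroPadding2D(padding=(1, 1))` adds 2 to each. The sketch below reproduces that arithmetic for the 564 × 504 input; the helper names are hypothetical and only mirror the layer sequence in the snippet above.

```python
def pool2x2_valid(n):
    # MaxPooling2D(pool_size=(2, 2)), stride 2, padding="valid":
    # output = floor((n - 2) / 2) + 1
    return (n - 2) // 2 + 1


def custom_cnn_output_hw(h, w):
    """Trace (height, width) through the custom CNN in create_triq_model.

    Conv2D(..., padding="same") keeps the spatial size, so only the
    pooling and zero-padding layers change it.
    """
    for _ in range(4):  # three loop blocks + the 256-filter block
        h, w = pool2x2_valid(h), pool2x2_valid(w)
    h, w = h + 2, w + 2  # ZeroPadding2D(padding=(1, 1))
    h, w = pool2x2_valid(h), pool2x2_valid(w)  # pooling after the 2048-filter conv
    return h, w


print(custom_cnn_output_hw(564, 504))
```

Note that 18 × 16 = 288 spatial positions, which is more than the `maximum_position_encoding=193` passed to the transformer; it may be worth checking how `TriQImageQualityTransformer` handles feature maps of that size.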

The issue is that after training, the model predicts the same value for every input. I have experimented with various hyperparameters: different settings may produce different output values, but for any particular setting the model outputs the same value for every input image. Note also that if I replace the transformer with a simple fully connected network (ANN), the network trains well.

Can you please suggest what I am doing wrong here?
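One way to quantify the symptom described above is to compare the spread of the predictions with the spread of the labels. Under mean squared error, the best constant predictor is the label mean, and its MSE equals the label variance. If the predictions have near-zero spread and the model's MSE is close to that baseline, the network has effectively learned only the mean score. This is a minimal NumPy sketch; `prediction_collapse_report` is a hypothetical helper, not part of TRIQ.

```python
import numpy as np


def prediction_collapse_report(preds, labels):
    """Compare a regressor's outputs against the best constant predictor."""
    preds = np.asarray(preds, dtype=np.float64).ravel()
    labels = np.asarray(labels, dtype=np.float64).ravel()
    return {
        "pred_std": float(preds.std()),  # ~0 means the same output for every image
        "label_std": float(labels.std()),
        "mse_model": float(np.mean((preds - labels) ** 2)),
        # MSE of always predicting labels.mean() equals the label variance
        "mse_mean_baseline": float(labels.var()),
    }


# Toy example: a "model" that always predicts the mean score of 3.0
report = prediction_collapse_report([3.0] * 5, [1.0, 2.0, 3.0, 4.0, 5.0])
print(report)
```

On real data this would be called as, e.g., `prediction_collapse_report(model.predict(valImagesX), valY)`.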

I suspect this happens because the output of one of the intermediate layers becomes a fixed value, e.g., 0. You could check the outputs of the intermediate layers. When the hyperparameters change, the model architecture changes too, and so do the outputs of the intermediate layers.
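A minimal sketch of that check, assuming a functional Keras model such as the one returned by `create_triq_model`: collect every layer's output for a small batch and flag layers whose output barely varies across the batch. The first flagged layer in model order is a likely place where the signal collapses. The helper names are hypothetical; the NumPy part runs on its own, and TensorFlow is only imported when `keras_layer_activations` is called.

```python
import numpy as np


def batch_spread(act):
    """Average per-unit standard deviation across the batch axis (axis 0)."""
    a = np.asarray(act, dtype=np.float64).reshape(len(act), -1)
    return float(a.std(axis=0).mean())


def find_collapsed_layers(activations, tol=1e-6):
    """Names of layers whose output is (almost) the same for every sample."""
    return [name for name, act in activations.items() if batch_spread(act) < tol]


def keras_layer_activations(model, batch):
    """Collect every layer's output for `batch` from a functional Keras model."""
    import tensorflow as tf  # lazy import: the helpers above only need NumPy
    layers = [l for l in model.layers if not isinstance(l, tf.keras.layers.InputLayer)]
    probe = tf.keras.Model(inputs=model.inputs, outputs=[l.output for l in layers])
    outputs = probe.predict(batch, verbose=0)
    if len(layers) == 1:  # predict() returns a bare array for a single output
        outputs = [outputs]
    return {l.name: np.asarray(o) for l, o in zip(layers, outputs)}
```

Usage would look like `find_collapsed_layers(keras_layer_activations(model, valImagesX[:8]))`. Measuring spread across the batch, rather than the raw values, is what matters here: a layer can output non-zero values and still be useless if those values are identical for every image.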

In my experience, an IQA model depends heavily on a backbone pretrained on a large dataset, e.g., ResNet50 trained on ImageNet. A custom CNN has probably not been pretrained on a large-scale database, which can significantly hurt performance.

Thanks a lot. I will use your suggestions.

Even when I use ResNet50 as the backbone, as you did, the same problem occurs. Do you have any thoughts on this? Could hyperparameter tuning solve it, or might something else be causing this behavior?

I don't fully understand your problem. Do you mean that you are using exactly the same code as mine and still get the same output for all your input images?

Yes, exactly.

Can you try using my trained weights (TRIQ.h5) with image_quality_prediction.py on your images and see how it works? Do you still get the same quality score for all of them?

I think another potential reason is `x = Rescaling(1./255)(inputs)`. It scales the pixel values by 1/255, and if you are using my generator, the images are also normalized there. Applying both can squeeze your images to values close to 0. You can check this as well.
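The effect of normalizing twice can be illustrated with a small NumPy sketch. `generator_normalize` is a hypothetical stand-in that maps [0, 255] pixels to [-1, 1]; the actual TRIQ generator may use different constants, but the point holds for any normalization that already produces small values: multiplying its output by 1/255 (what `Rescaling(1./255)` does) shrinks the input's spread by a factor of 255.

```python
import numpy as np


def generator_normalize(img):
    # Hypothetical stand-in for a data generator's normalization:
    # maps [0, 255] pixels to [-1, 1]. The real TRIQ generator may differ.
    return img / 127.5 - 1.0


def rescaling_layer(img):
    # Same arithmetic as Keras Rescaling(1./255): multiply by 1/255.
    return img * (1.0 / 255)


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 8, 8, 3)).astype(np.float64)

once = generator_normalize(batch)  # the intended model input
twice = rescaling_layer(once)      # what the model sees with both steps

print("range after generator only:", once.min(), once.max())
print("range after both steps:    ", twice.min(), twice.max())
print("std ratio:", twice.std() / once.std())
```

With both steps, every value lies within ±1/255 (about ±0.004), so the differences between images are tiny compared with the scale the network's weights were designed for. Keeping only one of the two steps, either the generator's normalization or the `Rescaling` layer, avoids this.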