google-research / maxim

[CVPR 2022 Oral] Official repository for "MAXIM: Multi-Axis MLP for Image Processing". SOTA for denoising, deblurring, deraining, dehazing, and enhancement.

Home Page: https://arxiv.org/abs/2201.02973



flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2".

zarmondo11 opened this issue

Hello.
I tried to run this repo on Google Colab. It works fine with the Enhancement pre-trained model, but when I load the Deblurring model and call predict() to get the output, this error appears:

MODEL_PATH = "Deblurring/GoPro/checkpoint.npz"
FLAGS = DummyFlags(ckpt_path = MODEL_PATH, task = "Enhancement") 
params = get_params(FLAGS.ckpt_path)
model = build_model()

import requests
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

url = 'https://replicate.com/api/models/google-research/maxim/files/6707a57f-4957-4047-b020-2160aed1d27a/1fromGOPR0950.png'
image_bytes = BytesIO(requests.get(url).content)

result = predict(image_bytes)

f, ax = plt.subplots(1,2, figsize = (35,20))

ax[0].imshow(np.array(Image.open(image_bytes)))
ax[1].imshow(result) 

ax[0].set_title("Original Image")
ax[1].set_title("Enhanced Image")

plt.show()

UnfilteredStackTrace Traceback (most recent call last)
in ()
      8
----> 9 result = predict(image_bytes)
     10

UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

ScopeParamNotFoundError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    356
    357     kernel = self.param('kernel', self.kernel_init, kernel_shape,
--> 358                         self.param_dtype)
    359     kernel = jnp.asarray(kernel, self.dtype)
    360
360

ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)

What's the problem? Is it with the pre-trained models? If so, how can I fix it or build my own model?
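`ScopeParamNotFoundError` means the Flax module asked for a parameter path (here `/stage_1_output_conv_2/kernel`) that the params tree passed to it does not contain, which usually points to the model and the checkpoint having different architectures. Below is a minimal pure-Python sketch for listing such mismatches. It assumes both trees are nested mappings of the kind Flax uses, for example an expected tree from `model.init(...)["params"]` and a loaded tree from the notebook's `get_params`; the helper names `flatten_paths` and `missing_param_paths` are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Set


def flatten_paths(tree: Mapping, prefix: str = "") -> Set[str]:
    """Hypothetical helper: flatten a nested param mapping into '/a/b/kernel'-style paths."""
    paths: Set[str] = set()
    for name, value in tree.items():
        path = f"{prefix}/{name}"
        # Recurse into sub-mappings (dicts, or Flax FrozenDicts, which are Mappings);
        # anything else (e.g. an array) is treated as a leaf parameter.
        if isinstance(value, Mapping):
            paths |= flatten_paths(value, path)
        else:
            paths.add(path)
    return paths


def missing_param_paths(expected: Mapping, loaded: Mapping) -> Set[str]:
    """Hypothetical helper: paths the model expects but the loaded checkpoint lacks."""
    return flatten_paths(expected) - flatten_paths(loaded)
```

Any path it returns is one the model will look up at apply time and fail to find, which is exactly the condition that raises this error.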

Hi, please change the following line to:

model = build_model(task="Deblurring")

because the default model is for Dehazing:

def build_model(task="Dehazing"):  # the default task is Dehazing
  model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')
  model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)

  # The architecture variant is chosen per task, so it must match the checkpoint.
  model_configs.variant = _MODEL_VARIANT_DICT[task]

  model = model_mod.Model(**model_configs)
  return model
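Since the checkpoint path already names the task, one way to keep `build_model` (and `FLAGS.task`) in sync with it is to derive the task from the path. The sketch below uses a hypothetical helper `task_from_ckpt_path`; its list of task names is an assumption taken from the repo description, and the notebook's `_MODEL_VARIANT_DICT` keys remain the real source of truth.

```python
from pathlib import PurePosixPath
from typing import Collection

# ASSUMPTION: task names taken from the repo description. The notebook's
# _MODEL_VARIANT_DICT keys are authoritative and may differ.
ASSUMED_TASKS = ("Denoising", "Deblurring", "Deraining", "Dehazing", "Enhancement")


def task_from_ckpt_path(ckpt_path: str, tasks: Collection[str] = ASSUMED_TASKS) -> str:
    """Hypothetical helper: return the first path component that names a task.

    E.g. 'Deblurring/GoPro/checkpoint.npz' -> 'Deblurring'.
    """
    # Scan every component so absolute paths (e.g. under /content) also work.
    for part in PurePosixPath(ckpt_path).parts:
        if part in tasks:
            return part
    raise ValueError(f"No known task name in checkpoint path {ckpt_path!r}")
```

With `MODEL_PATH = "Deblurring/GoPro/checkpoint.npz"` this returns `"Deblurring"`, which can then be passed to both `DummyFlags(ckpt_path=MODEL_PATH, task=task)` and `build_model(task=task)`, so the two can no longer disagree with each other or with the checkpoint.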

Thanks a lot! <3