flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2".
zarmondo11 opened this issue
Hello.
I am trying to run this repo on Google Colab. It works fine with the Enhancement pre-trained model, but when I load the Deblurring model and call predict() to get the output, this error appears:
```python
# DummyFlags, get_params, build_model and predict come from the Colab notebook being run.
MODEL_PATH = "Deblurring/GoPro/checkpoint.npz"
FLAGS = DummyFlags(ckpt_path = MODEL_PATH, task = "Enhancement")
params = get_params(FLAGS.ckpt_path)
model = build_model()

import requests
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

url = 'https://replicate.com/api/models/google-research/maxim/files/6707a57f-4957-4047-b020-2160aed1d27a/1fromGOPR0950.png'
image_bytes = BytesIO(requests.get(url).content)
result = predict(image_bytes)

f, ax = plt.subplots(1, 2, figsize = (35, 20))
ax[0].imshow(np.array(Image.open(image_bytes)))
ax[1].imshow(result)
ax[0].set_title("Original Image")
ax[1].set_title("Enhanced Image")
plt.show()
```
```
UnfilteredStackTrace                      Traceback (most recent call last)
in <module>()
      8
----> 9 result = predict(image_bytes)

UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

ScopeParamNotFoundError                   Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    356
    357     kernel = self.param('kernel', self.kernel_init, kernel_shape,
--> 358                         self.param_dtype)
    359     kernel = jnp.asarray(kernel, self.dtype)
    360

ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)
```
What's the problem? Is it with the pre-trained models? If so, how can I fix it, or how can I make my own model?
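Hi, please change the following line:
```python
model = build_model(task="Deblurring")
```
because build_model() defaults to the Dehazing task, so without the argument the model is built with the Dehazing variant, whose layers do not match the Deblurring checkpoint: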
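```python
import importlib

import ml_collections

# _MODEL_FILENAME, _MODEL_CONFIGS and _MODEL_VARIANT_DICT are defined elsewhere in the notebook.
def build_model(task = "Dehazing"):
    model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')
    model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
    # The task selects the architecture variant; it must match the checkpoint being loaded.
    model_configs.variant = _MODEL_VARIANT_DICT[task]
    model = model_mod.Model(**model_configs)
    return model
```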
Thanks a lot! <3