maxim - translate to keras
innat opened this issue
I tried to translate MAXIM from JAX to Keras. Everything looks fine, but the number of trainable parameters looks abnormally large. For the following config, I got a total of 674,238,099 params.
H, W = 224, 224
INS, MODEL = MAXIM(
    features=32,
    depth=3,
    num_stages=1,
    num_groups=2,
    num_bottleneck_blocks=2,
    block_gmlp_factor=2,
    grid_gmlp_factor=2,
    input_proj_factor=2,
    channels_reduction=4,
    num_supervision_scales=3,
    use_bias=True,
    lrelu_slope=0.1,
    use_global_mlp=10,
    use_cross_gating=False,
    high_res_stages=1,
    block_size_hr=[2, 2],
    block_size_lr=[2, 2],
    grid_size_hr=[2, 2],
    grid_size_lr=[2, 2],
    num_outputs=3,
    dropout_rate=0.5,
)
Could you please check the plot diagram, in case you notice any misconnection? If you click on the image below, it will open in a new tab, which should make it easier to inspect.
Hi, thanks for the enormous effort, @innat! Could you please share a pointer to the Keras code?
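One way to sanity-check a total like the 674,238,099 reported above is to work out by hand what individual layers should contribute and compare that against what Keras reports per layer (for example via `model.summary()` or `layer.count_params()`). The sketch below is only a rough aid, not a diagnosis of this model: the helper names `dense_params` and `conv2d_params` are hypothetical, and the example sizes are illustrative values derived from `features=32`, the `[2, 2]` block/grid sizes, and the 224x224 input in the config.

```python
def dense_params(in_features, out_features, use_bias=True):
    """Parameter count of a fully connected layer: kernel plus optional bias."""
    return in_features * out_features + (out_features if use_bias else 0)


def conv2d_params(kernel_h, kernel_w, in_channels, out_channels, use_bias=True):
    """Parameter count of a standard 2D convolution: kernel plus optional bias."""
    kernel = kernel_h * kernel_w * in_channels * out_channels
    return kernel + (out_channels if use_bias else 0)


# Dense layer mixing channels at features=32 (illustrative).
channel_mix = dense_params(32, 32)

# Dense layer acting along an axis of length 2 * 2 = 4,
# i.e. the number of positions in a [2, 2] block or grid (illustrative).
token_mix_small = dense_params(2 * 2, 2 * 2)

# Hypothetical comparison: the same kind of layer acting along an axis that
# holds every pixel of a 224x224 input. The count grows quadratically with
# the axis length, so one such layer would dwarf everything else.
token_mix_full = dense_params(224 * 224, 224 * 224)

# 3x3 convolution with 32 input and 32 output channels (illustrative).
conv3x3 = conv2d_params(3, 3, 32, 32)

for name, count in [
    ("channel_mix", channel_mix),
    ("token_mix_small", token_mix_small),
    ("token_mix_full", token_mix_full),
    ("conv3x3", conv3x3),
]:
    print(f"{name:>16}: {count:,}")
```

If a layer's reported count is orders of magnitude above the hand-computed figure for its intended shape, that layer is a reasonable place to start checking which axis it is applied along.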