Initialize `RaveledParamsMap` with dictionaries
rlouf opened this issue
Rémi Louf commented
Currently, one has to pass an iterable (which is then converted to a tuple) to initialize `RaveledParamsMap`:
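```python
import aesara.tensor as at
from aehmc.utils import RaveledParamsMap

tau_vv = at.vector("tau")
lambda_vv = at.vector("lambda")

# The map is built from, and indexed by, the same variables
rp_map = RaveledParamsMap((tau_vv, lambda_vv))

# Flatten both parameters into a single vector, then recover each part
q = rp_map.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map.unravel_params(q)[tau_vv]
lambda_part = rp_map.unravel_params(q)[lambda_vv]
```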
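As a rough mental model only, the sketch below mimics this round trip in plain NumPy. It assumes that `ravel_params` flattens and concatenates the parameters and that `unravel_params` splits the vector back using the reference shapes and dtypes. Both functions here are hypothetical stand-ins, not aehmc's implementation (which works on symbolic Aesara variables and returns a dictionary keyed by the reference parameters); they return a list in reference order instead, since NumPy arrays are not hashable.

```python
import numpy as np


def ravel_params(params):
    # Flatten each parameter and concatenate them into one 1-D vector
    return np.concatenate([np.ravel(p) for p in params])


def unravel_params(q, ref_shapes, ref_dtypes):
    # Split the flat vector at the cumulative parameter sizes, then restore
    # each chunk's reference shape and dtype (returned in reference order)
    sizes = [int(np.prod(shape)) for shape in ref_shapes]
    chunks = np.split(q, np.cumsum(sizes)[:-1])
    return [
        chunk.reshape(shape).astype(dtype)
        for chunk, shape, dtype in zip(chunks, ref_shapes, ref_dtypes)
    ]


# Usage: two parameters with different shapes
tau = np.array([1.0, 2.0, 3.0])
lam = np.array([[0.5, 1.5]])
q = ravel_params((tau, lam))  # one vector of length 5
tau_part, lam_part = unravel_params(q, [tau.shape, lam.shape], [tau.dtype, lam.dtype])
```

The point to notice is that unraveling only needs the reference shapes and dtypes, not the reference values themselves.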
In some circumstances we need the map to be indexed by other variables. For instance, when we work with transformed variables, we need the map to link the original value variables to the transformed variables (which may have different shapes/dtypes). In this case, we need to overwrite the `ref_params` property:
```python
from aeppl.transforms import LogTransform

lambda_vv_trans = LogTransform().forward(lambda_vv)

# Build the map from the transformed variable...
rp_map_trans = RaveledParamsMap((tau_vv, lambda_vv_trans))
# ...then re-index it by the original value variables
rp_map_trans.ref_params = (tau_vv, lambda_vv)

q = rp_map_trans.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map_trans.unravel_params(q)[tau_vv]
lambda_trans_part = rp_map_trans.unravel_params(q)[lambda_vv]
```
I suggest simplifying this by allowing `RaveledParamsMap` to be initialized with a dictionary:
```python
rp_map_trans = RaveledParamsMap({tau_vv: tau_vv, lambda_vv: lambda_vv_trans})
```
Shapes and dtypes would be inferred from the dictionary's values, while the map would be indexed by the dictionary's keys.
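One possible way to support both call styles, sketched under the assumption that the constructor should keep accepting plain iterables: normalize the argument into a tuple of index keys and a tuple of reference values, then compute shapes and dtypes from the values only. `split_ref_params` is a hypothetical helper, not part of aehmc, and strings stand in for Aesara variables in the usage lines.

```python
from collections.abc import Mapping


def split_ref_params(ref_params):
    """Hypothetical helper: return (index_keys, reference_values).

    A Mapping is indexed by its keys, and its values would supply the
    shapes/dtypes. Any other iterable is used for both, which keeps the
    current tuple-based behaviour.
    """
    if isinstance(ref_params, Mapping):
        keys = tuple(ref_params.keys())
        values = tuple(ref_params[k] for k in keys)
        return keys, values
    params = tuple(ref_params)
    return params, params


# Usage with strings standing in for value variables
keys, values = split_ref_params({"tau": "tau", "lambda": "lambda_trans"})
legacy_keys, legacy_values = split_ref_params(["tau", "lambda"])
```

Inside `__init__`, something like `self.ref_params, ref_values = split_ref_params(ref_params)` would then let the shapes and dtypes be computed from `ref_values`, so overwriting `ref_params` after construction would no longer be necessary.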