aesara-devs / aehmc

An HMC/NUTS implementation in Aesara

Initialize `RaveledParamsMap` with dictionaries

rlouf opened this issue

Currently one has to pass an iterable (which is then converted to a tuple) to initialize `RaveledParamsMap`:

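```python
import aesara.tensor as at
from aehmc.utils import RaveledParamsMap

# Value variables for the two parameters
tau_vv = at.vector("tau")
lambda_vv = at.vector("lambda")

# The map is built from an iterable of variables (stored as a tuple)
rp_map = RaveledParamsMap((tau_vv, lambda_vv))

# Concatenate the parameters into one flat vector, then split it back;
# the unraveled pieces are indexed by the variables passed to the constructor
q = rp_map.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map.unravel_params(q)[tau_vv]
lambda_part = rp_map.unravel_params(q)[lambda_vv]
```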
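To make the round trip concrete, here is a NumPy stand-in for what `ravel_params` and `unravel_params` do conceptually: concatenate the flattened parameters into one vector, then split it back and restore each shape and dtype. The names mirror the methods above, but these are hypothetical free functions, and string names stand in for the Aesara variables used as keys. It is a sketch of the idea, not aehmc's implementation.

```python
import numpy as np


def ravel_params(params):
    """Concatenate the flattened values of each parameter into one vector."""
    return np.concatenate([np.atleast_1d(p).ravel() for p in params])


def unravel_params(q, names, shapes, dtypes):
    """Split q back into arrays of the given shapes/dtypes, keyed by name."""
    out = {}
    start = 0
    for name, shape, dtype in zip(names, shapes, dtypes):
        size = int(np.prod(shape))
        out[name] = q[start:start + size].reshape(shape).astype(dtype)
        start += size
    return out


tau = np.array([0.5, 1.5])
lam = np.array([2.0, 3.0, 4.0])

q = ravel_params((tau, lam))  # flat vector of length 5
parts = unravel_params(
    q, ("tau", "lambda"), (tau.shape, lam.shape), (tau.dtype, lam.dtype)
)
```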

In some circumstances we need the map to be indexed by variables other than the ones that define its shapes and dtypes. For instance, when we work with transformed variables, the map needs to link the original value variables to the transformed variables (which may have different shapes/dtypes). In this case we need to overwrite `ref_params` after constructing the map:

```python
from aeppl.transforms import LogTransform

# Transformed (unconstrained) counterpart of lambda
lambda_vv_trans = LogTransform().forward(lambda_vv)

# Build the map from the transformed variable...
rp_map_trans = RaveledParamsMap((tau_vv, lambda_vv_trans))
# ...then re-key it by the original value variables
rp_map_trans.ref_params = (tau_vv, lambda_vv)

q = rp_map_trans.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map_trans.unravel_params(q)[tau_vv]
lambda_trans_part = rp_map_trans.unravel_params(q)[lambda_vv]
```
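A sketch of what this workaround amounts to, again with NumPy arrays and string keys standing in for Aesara variables: the pieces are shaped like the variables the map was built from, but labelled with the overwritten `ref_params`. Since the log transform preserves shape, I use the additive log-ratio transform to show a case where the shapes really differ: it maps a length-3 simplex to an unconstrained length-2 vector. `alr_forward` and `unravel_params` here are hypothetical helpers, not aeppl or aehmc functions.

```python
import numpy as np


def alr_forward(x):
    """Additive log-ratio transform: maps a K-simplex to R^(K-1)."""
    return np.log(x[:-1] / x[-1])


def unravel_params(q, keys, templates):
    """Split q into pieces shaped like `templates`, labelled by `keys`.

    `templates` plays the role of the variables the map was built from
    (they fix shapes and dtypes); `keys` plays the role of the overwritten
    `ref_params` (they only label the pieces).
    """
    out, start = {}, 0
    for key, tmpl in zip(keys, templates):
        size = tmpl.size
        out[key] = q[start:start + size].reshape(tmpl.shape).astype(tmpl.dtype)
        start += size
    return out


tau = np.array([0.5, 1.5])
weights = np.array([0.2, 0.3, 0.5])   # on the simplex, shape (3,)
weights_trans = alr_forward(weights)  # unconstrained, shape (2,)

q = np.concatenate([tau.ravel(), weights_trans.ravel()])
# The "weights" piece has the transformed shape (2,) but the original key
parts = unravel_params(q, keys=("tau", "weights"), templates=(tau, weights_trans))
```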

I suggest simplifying this by allowing `RaveledParamsMap` to be initialized with a dictionary:

```python
rp_map_trans = RaveledParamsMap({tau_vv: tau_vv, lambda_vv: lambda_vv_trans})
```

Shapes and dtypes would be inferred from the dictionary's values, and the map would be indexed by the dictionary's keys.
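A minimal sketch of the proposed constructor, assuming NumPy arrays as the values and strings as the keys (`DictRaveledParamsMap` is a hypothetical name, not part of aehmc): the keys label the unraveled pieces, and the values fix their shapes and dtypes.

```python
from collections.abc import Mapping

import numpy as np


class DictRaveledParamsMap:
    """Hypothetical NumPy sketch of a dictionary-based constructor."""

    def __init__(self, ref_params: Mapping):
        # Keys index the unraveled pieces; values fix their shapes and dtypes
        self.ref_params = tuple(ref_params.keys())
        self.ref_shapes = tuple(np.shape(v) for v in ref_params.values())
        self.ref_dtypes = tuple(np.asarray(v).dtype for v in ref_params.values())
        self.ref_sizes = tuple(int(np.prod(s)) for s in self.ref_shapes)

    def ravel_params(self, params):
        """Concatenate the flattened values of each parameter into one vector."""
        return np.concatenate([np.ravel(p) for p in params])

    def unravel_params(self, q):
        """Split q into pieces, reshaped and cast, keyed by the map's keys."""
        out, start = {}, 0
        for key, shape, dtype, size in zip(
            self.ref_params, self.ref_shapes, self.ref_dtypes, self.ref_sizes
        ):
            out[key] = q[start:start + size].reshape(shape).astype(dtype)
            start += size
        return out


tau = np.array([0.5, 1.5])
lam = np.array([2.0, 3.0, 4.0])
lam_trans = np.log(lam)  # stands in for LogTransform().forward(lambda_vv)

rp_map = DictRaveledParamsMap({"tau": tau, "lambda": lam_trans})
q = rp_map.ravel_params((tau, lam_trans))
parts = rp_map.unravel_params(q)  # keyed by "tau" and "lambda"
```

With Aesara variables, which are hashable, the current tuple form would correspond to the special case `{v: v for v in ref_params}`, so accepting a mapping could keep the existing behaviour while removing the need to overwrite `ref_params`.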