cafeal / NormalizingFlows

tf.distribute.MirroredStrategy()

yangyijune opened this issue

I tried to use `tf.distribute.MirroredStrategy` to speed up training, but I get the following error:
```
ValueError:'colocate_vars_with' must only be passed a variable created in this tf.distribute.Strategy.scope(), not<tf.Variable 'batch_normlization/gamma:0' shape(2, 0), dtype=float32>
```
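The error says the variable `batch_normlization/gamma:0` was not created inside this `tf.distribute.Strategy.scope()`. The minimal sketch below shows that distinction with plain `tf.Variable`s instead of the flow model. The names `strategy`, `outside_var` and `inside_var` are illustrative. It uses `StrategyExtended.variable_created_in_scope()`, which reports whether a variable was created while the strategy's scope was active.

```python
import tensorflow as tf

# With no arguments, MirroredStrategy mirrors across all visible GPUs
# and falls back to the CPU when there are none.
strategy = tf.distribute.MirroredStrategy()

# Created OUTSIDE strategy.scope(): an ordinary variable the strategy does not own.
outside_var = tf.Variable(1.0, name="outside")

with strategy.scope():
    # Created INSIDE strategy.scope(): the strategy owns (mirrors) this variable.
    inside_var = tf.Variable(1.0, name="inside")

# Check which of the two variables belongs to this strategy.
print(strategy.extended.variable_created_in_scope(outside_var))
print(strategy.extended.variable_created_in_scope(inside_var))
```

In the error above, `batch_normlization/gamma:0` is a variable of the first kind.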

This is how I set up `MirroredStrategy`:
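```python
import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():  # build and compile the model under the strategy
    model = LogProb(distribution, bijector)
    loss_fn = ...  # elided
    model.compile(...)  # arguments elided
```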

Have you used `tf.distribute.Strategy` to accelerate training? If so, could you help me fix this error?
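The following is inferred only from the error message and has not been confirmed against this repository. One likely cause is that `distribution` or `bijector`, which holds the batch-normalization variable, was constructed before entering `mirrored_strategy.scope()`. If so, the strategy does not own its variables. The minimal sketch below assumes that is the cause and shows the usual pattern: build every variable-owning object inside the scope. The small Keras model with a `BatchNormalization` layer is a hypothetical stand-in for `LogProb(distribution, bijector)`, not the repository's API. `build_model` is an illustrative name.

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def build_model():
    # Hypothetical stand-in for LogProb(distribution, bijector): a tiny Keras
    # model whose BatchNormalization layer owns gamma/beta variables.
    inputs = tf.keras.Input(shape=(2,))
    h = tf.keras.layers.BatchNormalization()(inputs)
    outputs = tf.keras.layers.Dense(1)(h)
    return tf.keras.Model(inputs, outputs)

with strategy.scope():
    # Everything that creates variables (layers or bijectors, the model,
    # and the optimizer created by compile) is constructed inside the scope.
    model = build_model()
    model.compile(optimizer="adam", loss="mse")

# fit() can be called outside the scope once the model was built inside it.
x = np.random.normal(size=(64, 2)).astype("float32")
y = np.random.normal(size=(64, 1)).astype("float32")
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
```

Applied to the snippet above, this would mean moving the construction of `distribution` and `bijector` into the `with mirrored_strategy.scope():` block, next to `LogProb(...)` and `model.compile(...)`.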