google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io



How to change this torch code to Flax?

lu-ming-lei opened this issue

optimizer = torch.optim.Adam([variable1, variable2], 0.001)

How can I convert this to Flax?

I tried this, but it did not work: optimizer = flax.optim.Adam([variable1, variable2], 0.001)

Hey @lu-ming-lei, can you provide a longer code snippet? In Flax, we don't pass the variables we want to optimize when initialising the optimizer.

optimizer = torch.optim.Adam([
    {'params': variables1, 'lr': 0.1},
    {'params': variables2, 'lr': 0.01},
])

How do I convert this code to Flax?