sail-sg / jax_xc

Exchange correlation functionals translated from libxc to jax

Home Page: https://jax-xc.readthedocs.io/

Enable vxc

DyeKuu opened this issue · comments

The current implementation only exposes exc; we still need vxc in some cases.

+1, we need that too.

From libxc 6.0.0, I checked the Maple files built for lda_exc/lda_c_rpa and lda_vxc/lda_xc_tih:


Polarization := "unpol":
Digits := 20:             (* constants will have 20 digits *)
interface(warnlevel=0):   (* suppress all warnings         *)
with(CodeGeneration):

$include <lda_c_rpa.mpl>


dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:

dmfd10 := (v0, v1) ->  eval(diff(mf(v0, v1), v0)):

dmfd01 := (v0, v1) ->  eval(diff(mf(v0, v1), v1)):

dmfd20 := (v0, v1) ->  eval(diff(dmfd10(v0, v1), v0)):

dmfd11 := (v0, v1) ->  eval(diff(dmfd01(v0, v1), v0)):

dmfd02 := (v0, v1) ->  eval(diff(dmfd01(v0, v1), v1)):

dmfd30 := (v0, v1) ->  eval(diff(dmfd20(v0, v1), v0)):

dmfd21 := (v0, v1) ->  eval(diff(dmfd11(v0, v1), v0)):

dmfd12 := (v0, v1) ->  eval(diff(dmfd02(v0, v1), v0)):

dmfd03 := (v0, v1) ->  eval(diff(dmfd02(v0, v1), v1)):




# zk is energy per unit particle
mzk  := (r0, r1) -> \
   + \
    f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
   :

  (* mf is energy per unit volume *)
  mf   := (r0, r1) -> eval(dens(r0, r1)*mzk(r0, r1)):

$include <util.mpl>

C([ zk_0_ = mzk(rho_0_, rho_1_), vrho_0_ = dmfd10(rho_0_, rho_1_), vrho_1_ = dmfd01(rho_0_, rho_1_), v2rho2_0_ = dmfd20(rho_0_, rho_1_), v2rho2_1_ = dmfd11(rho_0_, rho_1_), v2rho2_2_ = dmfd02(rho_0_, rho_1_), v3rho3_0_ = dmfd30(rho_0_, rho_1_), v3rho3_1_ = dmfd21(rho_0_, rho_1_), v3rho3_2_ = dmfd12(rho_0_, rho_1_), v3rho3_3_ = dmfd03(rho_0_, rho_1_)], optimize, deducetypes=false):
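If I read the driver correctly, the quantities requested in the C([...]) call are just successive partial derivatives of the energy density mf = dens * mzk; a sketch of the correspondence for the unpolarized case (my reading, not authoritative):

```latex
% Outputs of the Maple driver, written against the energy density
% f(\rho) = \rho\,\varepsilon_{xc}(\rho) (unpolarized: dens = r0, zeta = 0)
\begin{align*}
  \texttt{zk\_0\_}     &= \varepsilon_{xc}(\rho) \\
  \texttt{vrho\_0\_}   &= \frac{\partial\,[\rho\,\varepsilon_{xc}]}{\partial \rho} \\
  \texttt{v2rho2\_0\_} &= \frac{\partial^{2}\,[\rho\,\varepsilon_{xc}]}{\partial \rho^{2}} \\
  \texttt{v3rho3\_0\_} &= \frac{\partial^{3}\,[\rho\,\varepsilon_{xc}]}{\partial \rho^{3}}
\end{align*}
```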

and


Polarization := "unpol":
Digits := 20:             (* constants will have 20 digits *)
interface(warnlevel=0):   (* suppress all warnings         *)
with(CodeGeneration):

$include <lda_xc_tih.mpl>


dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:

dmf0d10 := (v0, v1) ->  eval(diff(mf0(v0, v1), v0)):

dmf0d01 := (v0, v1) ->  eval(diff(mf0(v0, v1), v1)):

dmf0d20 := (v0, v1) ->  eval(diff(dmf0d10(v0, v1), v0)):

dmf0d11 := (v0, v1) ->  eval(diff(dmf0d01(v0, v1), v0)):

dmf0d02 := (v0, v1) ->  eval(diff(dmf0d01(v0, v1), v1)):




mzk  := (r0, r1) -> \
   + \
    f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
   :

(* mf is the up potential *)
mf0   := (r0, r1) -> eval(mzk(r0, r1)):
mf1   := (r0, r1) -> eval(mzk(r1, r0)):

$include <util.mpl>

C([ vrho_0_ = mf0(rho_0_, rho_1_), v2rho2_0_ = dmf0d10(rho_0_, rho_1_), v2rho2_1_ = dmf0d01(rho_0_, rho_1_), v3rho3_0_ = dmf0d20(rho_0_, rho_1_), v3rho3_1_ = dmf0d11(rho_0_, rho_1_), v3rho3_2_ = dmf0d02(rho_0_, rho_1_)], optimize, deducetypes=false):

It looks like we can harvest vxc from the vrho_0_ output, and we can derive vxc from exc but not vice versa. It would be better to start from the Maple file generated for vxc and map it back to the one for exc, but I'm not sure I understand the math here. The code shown above does nothing fancy (no functional derivative); it simply takes vxc to be the partial derivative of exc * rho with respect to the rho input arguments (the same goes for mGGA, even though tau is required).
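To make that relation concrete, here is a small sketch (my transcription of the unpolarized lda_c_rpa expression quoted in this thread, not the jax_xc API) showing vxc obtained as the plain derivative of exc * rho:

```python
import jax
import jax.numpy as jnp

def exc_rpa_unpol(rho):
  # energy per particle for unpolarized lda_c_rpa, transcribed from the
  # generated code quoted in this thread
  t = jnp.cbrt(3.0) * jnp.cbrt(1.0 / jnp.pi) * jnp.cbrt(4.0) ** 2 / jnp.cbrt(rho)
  lg = jnp.log(t / 4.0)
  return 0.311e-1 * lg - 0.48e-1 + 0.225e-2 * t * lg - 0.425e-2 * t

# vxc as the partial derivative of the energy density exc * rho w.r.t. rho
vxc = jax.grad(lambda rho: rho * exc_rpa_unpol(rho))
```

By the product rule this equals exc(rho) + rho * d(exc)/d(rho), which is exactly the shape of the tvrho0 expression in the generated C code below.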

  • Why are there both vrho_0_ and vrho_1_? It looks like we need to add them together and make the sum the final output of vxc; for example, in lda_c_rpa's codegen:
func_vxc_pol(const xc_func_type *p, size_t ip, const double *rho, xc_lda_out_params *out)
{
  double t1, t3, t4, t5, t6, t7, t8, t10;
  double t11, t13, t14, t17, t18, tzk0;

  double t19, t23, t25, t27, tvrho0, tvrho1;


  t1 = M_CBRT3;
  t3 = POW_1_3(0.1e1 / M_PI);
  t4 = t1 * t3;
  t5 = M_CBRT4;
  t6 = t5 * t5;
  t7 = rho[0] + rho[1];
  t8 = POW_1_3(t7);
  t10 = t6 / t8;
  t11 = t4 * t10;
  t13 = log(t11 / 0.4e1);
  t14 = 0.311e-1 * t13;
  t17 = 0.225e-2 * t4 * t10 * t13;
  t18 = 0.425e-2 * t11;
  tzk0 = t14 - 0.48e-1 + t17 - t18;

  if(out->zk != NULL && (p->info->flags & XC_FLAGS_HAVE_EXC))
    out->zk[ip*p->dim.zk + 0] += tzk0;

  t19 = 0.1e1 / t7;
  t23 = t6 / t8 / t7;
  t25 = t4 * t23 * t13;
  t27 = t4 * t23;
  tvrho0 = t14 - 0.48e-1 + t17 - t18 + t7 * (-0.10366666666666666667e-1 * t19 - 0.75e-3 * t25 + 0.6666666666666666667e-3 * t27);

  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 0] += tvrho0;

  tvrho1 = tvrho0;

  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 1] += tvrho1;

}
  • Do we want to generate the derivative code from Maple, or leave it to JAX's autodiff?

If the latter, basically all we need is a jax.grad, provided we can construct a function like lambda rho: exc(rho) * rho that doesn't involve the position r. But that may require breaking the current design, since every callable so far depends either purely on r (before 0.0.4) or on both rho and r.

Taking the Python codegen as an example:

def pol(p, r, s=(None, None, None), l=(None, None), tau=(None, None)):
  params = p.params
  (r0, r1), (s0, s1, s2), (l0, l1), (tau0, tau1) = r, s, l, tau
  t1 = jnp.cbrt(3)
  t3 = jnp.cbrt(0.1e1 / jnp.pi)
  t4 = t1 * t3
  t5 = jnp.cbrt(4)
  t6 = t5 ** 2
  t8 = jnp.cbrt(r0 + r1)
  t10 = t6 / t8
  t11 = t4 * t10
  t13 = jnp.log(t11 / 0.4e1)
  res = 0.311e-1 * t13 - 0.48e-1 + 0.225e-2 * t4 * t10 * t13 - 0.425e-2 * t11
  return res


def unpol(p, r, s=None, l=None, tau=None):
  params = p.params
  r0, s0, l0, tau0 = r, s, l, tau
  t1 = jnp.cbrt(3)
  t3 = jnp.cbrt(0.1e1 / jnp.pi)
  t4 = t1 * t3
  t5 = jnp.cbrt(4)
  t6 = t5 ** 2
  t7 = jnp.cbrt(r0)
  t9 = t6 / t7
  t10 = t4 * t9
  t12 = jnp.log(t10 / 0.4e1)
  res = 0.311e-1 * t12 - 0.48e-1 + 0.225e-2 * t4 * t9 * t12 - 0.425e-2 * t10
  return res


def invoke(
  p: NamedTuple, rho: Callable, r: jnp.ndarray, mo: Optional[Callable] = None,
  deorbitalize: Optional[float] = None,
):
  args = rho_to_arguments(p, rho, r, mo, deorbitalize)
  ret = pol(p, *args) if p.nspin == 2 else unpol(p, *args)
  dens = args[0] if p.nspin == 1 else sum(args[0])
  ret = float(dens >= p.dens_threshold) * ret
  return ret

We don't need to propagate the gradient through the position, so we just need a partial function of pol or unpol where everything is fixed except rho (r in the code here). My guess is that we could do something like vxc = jax.grad(lambda rho: rho * pol(p, rho, *args)).
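A sketch of that idea (the unpol here is a stand-in carrying the lda_c_rpa expression from above; p, s, l, tau are ignored by the stand-in, unlike the real generated function):

```python
import jax
import jax.numpy as jnp

def unpol(p, r, s=None, l=None, tau=None):
  # stand-in for the generated unpol above (lda_c_rpa); p/s/l/tau unused here
  t = jnp.cbrt(3.0) * jnp.cbrt(1.0 / jnp.pi) * jnp.cbrt(4.0) ** 2 / jnp.cbrt(r)
  lg = jnp.log(t / 4.0)
  return 0.311e-1 * lg - 0.48e-1 + 0.225e-2 * t * lg - 0.425e-2 * t

p = None
# fix everything except rho, then differentiate the energy density rho * exc
vxc = jax.grad(lambda rho: rho * unpol(p, rho))
# evaluate pointwise over a batch of grid densities
vxc_on_grid = jax.vmap(vxc)(jnp.array([0.5, 1.0, 2.0]))
```

Since the LDA potential is purely local in rho, jax.vmap over the grid values is all that is needed to evaluate vxc on a quadrature grid.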

cc @mavenlin


The JAX solution looks nice here: basically we make use of the semi-local property of the functional and derive vxc manually.

I think this is the way to go for LDA, but we may want to consider GGA and mGGA before proceeding, since the functional derivative goes through nabla rho.
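For reference, this is the point being made: for a GGA the energy density depends on the gradient of the density, so the potential is a genuine functional derivative with an extra divergence term, and a plain jax.grad with respect to rho is no longer sufficient:

```latex
% GGA: E_{xc}[\rho] = \int e\big(\rho, \nabla\rho\big)\, d\mathbf{r}
% Functional derivative (Euler--Lagrange form):
v_{xc}(\mathbf{r})
  \;=\; \frac{\delta E_{xc}}{\delta \rho(\mathbf{r})}
  \;=\; \frac{\partial e}{\partial \rho}
  \;-\; \nabla \cdot \frac{\partial e}{\partial (\nabla \rho)}
```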

Using automatic functional derivatives to support this.

#31