Enable vxc
DyeKuu opened this issue · comments
The current implementation only enables `exc`; we still need `vxc` in some cases.
+1, we need that too.
From libxc 6.0.0, I checked the Maple files built for `lda_exc/lda_c_rpa` and `lda_vxc/lda_xc_tih`:
# Unpolarized variant: dens/zeta below make the first argument the total
# density and set the spin polarization to zero.
Polarization := "unpol":
Digits := 20: (* constants will have 20 digits *)
interface(warnlevel=0): (* suppress all warnings *)
with(CodeGeneration):
# Functional definition; presumably supplies f and the helpers it uses
# (e.g. r_ws) -- confirm against lda_c_rpa.mpl / util.mpl.
$include <lda_c_rpa.mpl>
dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:
# dmfdIJ: I-fold partial derivative of mf w.r.t. its first argument and
# J-fold w.r.t. its second; each order is built by differentiating a lower
# one. Because dens and zeta ignore r1 here, every dmfdIJ with J > 0
# (vrho_1_, v2rho2_1_, ...) is identically zero in this unpolarized script.
dmfd10 := (v0, v1) -> eval(diff(mf(v0, v1), v0)):
dmfd01 := (v0, v1) -> eval(diff(mf(v0, v1), v1)):
dmfd20 := (v0, v1) -> eval(diff(dmfd10(v0, v1), v0)):
dmfd11 := (v0, v1) -> eval(diff(dmfd01(v0, v1), v0)):
dmfd02 := (v0, v1) -> eval(diff(dmfd01(v0, v1), v1)):
dmfd30 := (v0, v1) -> eval(diff(dmfd20(v0, v1), v0)):
dmfd21 := (v0, v1) -> eval(diff(dmfd11(v0, v1), v0)):
dmfd12 := (v0, v1) -> eval(diff(dmfd02(v0, v1), v0)):
dmfd03 := (v0, v1) -> eval(diff(dmfd02(v0, v1), v1)):
# zk is energy per unit particle
mzk := (r0, r1) -> \
+ \
f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
:
(* mf is energy per unit volume *)
mf := (r0, r1) -> eval(dens(r0, r1)*mzk(r0, r1)):
$include <util.mpl>
# Emit C code (with the optimize option) for zk and for the first (vrho_*),
# second (v2rho2_*) and third (v3rho3_*) partial derivatives of mf.
C([ zk_0_ = mzk(rho_0_, rho_1_), vrho_0_ = dmfd10(rho_0_, rho_1_), vrho_1_ = dmfd01(rho_0_, rho_1_), v2rho2_0_ = dmfd20(rho_0_, rho_1_), v2rho2_1_ = dmfd11(rho_0_, rho_1_), v2rho2_2_ = dmfd02(rho_0_, rho_1_), v3rho3_0_ = dmfd30(rho_0_, rho_1_), v3rho3_1_ = dmfd21(rho_0_, rho_1_), v3rho3_2_ = dmfd12(rho_0_, rho_1_), v3rho3_3_ = dmfd03(rho_0_, rho_1_)], optimize, deducetypes=false):
and
# Unpolarized variant of a potential-only (vxc) functional: dens/zeta below
# make the first argument the total density with zero spin polarization.
Polarization := "unpol":
Digits := 20: (* constants will have 20 digits *)
interface(warnlevel=0): (* suppress all warnings *)
with(CodeGeneration):
$include <lda_xc_tih.mpl>
dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:
# dmf0dIJ: I-fold partial derivative of mf0 w.r.t. its first argument and
# J-fold w.r.t. its second. Only up to second order is built, because mf0
# is already the potential (one derivative order above the energy density).
dmf0d10 := (v0, v1) -> eval(diff(mf0(v0, v1), v0)):
dmf0d01 := (v0, v1) -> eval(diff(mf0(v0, v1), v1)):
dmf0d20 := (v0, v1) -> eval(diff(dmf0d10(v0, v1), v0)):
dmf0d11 := (v0, v1) -> eval(diff(dmf0d01(v0, v1), v0)):
dmf0d02 := (v0, v1) -> eval(diff(dmf0d01(v0, v1), v1)):
# Unlike the exc script, mzk is not multiplied by the density below: it is
# used directly as the potential.
mzk := (r0, r1) -> \
+ \
f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
:
(* mf is the up potential *)
mf0 := (r0, r1) -> eval(mzk(r0, r1)):
# Down-spin potential obtained by swapping the arguments; it is not
# referenced in the C() call below.
mf1 := (r0, r1) -> eval(mzk(r1, r0)):
$include <util.mpl>
# Emit C code for vrho_0_ = mf0 itself and for its first (v2rho2_*) and
# second (v3rho3_*) partial derivatives; no zk output is generated.
C([ vrho_0_ = mf0(rho_0_, rho_1_), v2rho2_0_ = dmf0d10(rho_0_, rho_1_), v2rho2_1_ = dmf0d01(rho_0_, rho_1_), v3rho3_0_ = dmf0d20(rho_0_, rho_1_), v3rho3_1_ = dmf0d11(rho_0_, rho_1_), v3rho3_2_ = dmf0d02(rho_0_, rho_1_)], optimize, deducetypes=false):
It looks like we can harvest `vxc` from the `vrho_0_` output, and we can get `vxc` out of `exc` but not vice versa. It would be better to start from the Maple file generated for `vxc` and work out its correspondence to the one for `exc`. But I'm not sure that I understand the math here. The code shown here does nothing fancy (like a functional derivative); it just says that `vxc` is the partial derivative of `exc * rho` with respect to its `rho` input arguments (the same goes for mGGA, even with `tau` required).
- Why are there both `vrho_0_` and `vrho_1_`? It looks like we need to add them together and make that the final output of `vxc`, for example in `lda_c_rpa`'s codegen:
/* Spin-polarized evaluation at grid point ip (generated lda_c_rpa code).
 * Adds the energy per particle to out->zk and one potential per spin
 * channel to out->vrho, each only when the output buffer is non-NULL and the
 * functional's flags allow it. Every quantity depends only on the total
 * density t7 = rho[0] + rho[1], so both channels receive the same potential.
 * Statement order and operation grouping fix the floating-point results;
 * do not reorder when editing by hand. */
func_vxc_pol(const xc_func_type *p, size_t ip, const double *rho, xc_lda_out_params *out)
{
  double t1, t3, t4, t5, t6, t7, t8, t10;
  double t11, t13, t14, t17, t18, tzk0;
  double t19, t23, t25, t27, tvrho0, tvrho1;
  /* Presumably M_CBRT3 = 3^(1/3), M_CBRT4 = 4^(1/3), POW_1_3(x) = x^(1/3). */
  t1 = M_CBRT3;
  t3 = POW_1_3(0.1e1 / M_PI);
  t4 = t1 * t3;
  t5 = M_CBRT4;
  t6 = t5 * t5;
  t7 = rho[0] + rho[1]; /* total density n */
  t8 = POW_1_3(t7);
  t10 = t6 / t8;
  /* With the constants above, t11 = 4 (3/(4 pi n))^(1/3) = 4 r_s
   * (Wigner-Seitz radius), hence t13 = ln(r_s). */
  t11 = t4 * t10;
  t13 = log(t11 / 0.4e1);
  t14 = 0.311e-1 * t13;
  t17 = 0.225e-2 * t4 * t10 * t13;
  t18 = 0.425e-2 * t11;
  /* Energy per particle eps(n). */
  tzk0 = t14 - 0.48e-1 + t17 - t18;
  if(out->zk != NULL && (p->info->flags & XC_FLAGS_HAVE_EXC))
    out->zk[ip*p->dim.zk + 0] += tzk0;
  t19 = 0.1e1 / t7;
  t23 = t6 / t8 / t7;
  t25 = t4 * t23 * t13;
  t27 = t4 * t23;
  /* d(n eps)/dn = eps + n d(eps)/dn: the leading terms repeat tzk0 and the
   * t7 * (...) term is n d(eps)/dn. */
  tvrho0 = t14 - 0.48e-1 + t17 - t18 + t7 * (-0.10366666666666666667e-1 * t19 - 0.75e-3 * t25 + 0.6666666666666666667e-3 * t27);
  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 0] += tvrho0;
  /* Same value for the second channel: nothing depends on how n is split. */
  tvrho1 = tvrho0;
  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 1] += tvrho1;
}
- Do we want to get the derivative code out of Maple, or leave it to JAX's autodiff?
If the latter, basically all we need to do is apply `jax.grad`, provided we can construct a function like `lambda rho: exc(rho) * rho` that doesn't involve the position `r`. But that may require breaking the current design, since all the callables currently depend either purely on `r` (before 0.0.4) or on both `rho` and `r`.
Take the generated Python code, for example:
def pol(p, r, s=(None, None, None), l=(None, None), tau=(None, None)):
    """Spin-polarized energy per particle (generated code).

    The expression and constants match ``tzk0`` in the ``lda_c_rpa`` C code
    quoted above; only the total density ``r0 + r1`` enters.

    Args:
      p: parameter object; ``p.params`` is read but not used by this LDA.
      r: pair ``(r0, r1)`` of spin densities.
      s, l, tau: remaining arguments (presumably sigma, Laplacian and kinetic
        energy density -- confirm), unpacked for a uniform signature but
        unused here.

    Returns:
      The energy per particle at this point.
    """
    params = p.params  # unused below
    (r0, r1), (s0, s1, s2), (l0, l1), (tau0, tau1) = r, s, l, tau
    t1 = jnp.cbrt(3)
    t3 = jnp.cbrt(0.1e1 / jnp.pi)
    t4 = t1 * t3
    t5 = jnp.cbrt(4)
    t6 = t5 ** 2
    t8 = jnp.cbrt(r0 + r1)
    t10 = t6 / t8
    # t11 = 4 * (3 / (4 pi n))^(1/3) = 4 r_s with n = r0 + r1, so t13 = ln(r_s)
    t11 = t4 * t10
    t13 = jnp.log(t11 / 0.4e1)
    # Generated expression: keep the operation order, float results depend on it.
    res = 0.311e-1 * t13 - 0.48e-1 + 0.225e-2 * t4 * t10 * t13 - 0.425e-2 * t11
    return res
def unpol(p, r, s=None, l=None, tau=None):
    """Spin-unpolarized energy per particle (generated code).

    Same expression as ``pol`` with the total density ``r`` passed directly.

    Args:
      p: parameter object; ``p.params`` is read but not used by this LDA.
      r: total density.
      s, l, tau: remaining arguments (presumably sigma, Laplacian and kinetic
        energy density -- confirm), unused here.

    Returns:
      The energy per particle at this point.
    """
    params = p.params  # unused below
    r0, s0, l0, tau0 = r, s, l, tau
    t1 = jnp.cbrt(3)
    t3 = jnp.cbrt(0.1e1 / jnp.pi)
    t4 = t1 * t3
    t5 = jnp.cbrt(4)
    t6 = t5 ** 2
    t7 = jnp.cbrt(r0)
    t9 = t6 / t7
    # t10 = 4 * (3 / (4 pi r0))^(1/3) = 4 r_s, so t12 = ln(r_s)
    t10 = t4 * t9
    t12 = jnp.log(t10 / 0.4e1)
    # Generated expression: keep the operation order, float results depend on it.
    res = 0.311e-1 * t12 - 0.48e-1 + 0.225e-2 * t4 * t9 * t12 - 0.425e-2 * t10
    return res
def invoke(
    p: NamedTuple, rho: Callable, r: jnp.ndarray, mo: Optional[Callable] = None,
    deorbitalize: Optional[float] = None,
):
    """Evaluate the generated functional at position ``r``.

    ``rho`` (and optionally ``mo`` / ``deorbitalize``) are converted into the
    functional's arguments by ``rho_to_arguments``. ``pol`` is used when
    ``p.nspin == 2`` and ``unpol`` otherwise.

    Returns:
      The per-point result of ``pol``/``unpol``, set to zero where the total
      density is below ``p.dens_threshold``.
    """
    args = rho_to_arguments(p, rho, r, mo, deorbitalize)
    polarized = p.nspin == 2
    ret = pol(p, *args) if polarized else unpol(p, *args)
    # args[0] is the density argument: an (r0, r1) pair when polarized.
    dens = sum(args[0]) if polarized else args[0]
    # jnp.where instead of float(dens >= threshold) * ret: float() needs a
    # concrete value and raises on traced arrays (jit / vmap), which would
    # block compiling or differentiating through this function.
    return jnp.where(dens >= p.dens_threshold, ret, 0.0)
We don't need to propagate the gradient through the position, so we just need a partial function of `pol` or `unpol` in which everything is fixed except `rho` (`r` in the code here). My guess is that we could do something like `vxc = jax.grad(lambda rho: rho * pol(p, rho, *args))`.
cc @mavenlin
From libxc 6.0.0, I checked the Maple files built for `lda_exc/lda_c_rpa` and `lda_vxc/lda_xc_tih`:
# Unpolarized variant: dens/zeta below make the first argument the total
# density and set the spin polarization to zero.
Polarization := "unpol":
Digits := 20: (* constants will have 20 digits *)
interface(warnlevel=0): (* suppress all warnings *)
with(CodeGeneration):
# Functional definition; presumably supplies f and the helpers it uses
# (e.g. r_ws) -- confirm against lda_c_rpa.mpl / util.mpl.
$include <lda_c_rpa.mpl>
dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:
# dmfdIJ: I-fold partial derivative of mf w.r.t. its first argument and
# J-fold w.r.t. its second; each order is built by differentiating a lower
# one. Because dens and zeta ignore r1 here, every dmfdIJ with J > 0
# (vrho_1_, v2rho2_1_, ...) is identically zero in this unpolarized script.
dmfd10 := (v0, v1) -> eval(diff(mf(v0, v1), v0)):
dmfd01 := (v0, v1) -> eval(diff(mf(v0, v1), v1)):
dmfd20 := (v0, v1) -> eval(diff(dmfd10(v0, v1), v0)):
dmfd11 := (v0, v1) -> eval(diff(dmfd01(v0, v1), v0)):
dmfd02 := (v0, v1) -> eval(diff(dmfd01(v0, v1), v1)):
dmfd30 := (v0, v1) -> eval(diff(dmfd20(v0, v1), v0)):
dmfd21 := (v0, v1) -> eval(diff(dmfd11(v0, v1), v0)):
dmfd12 := (v0, v1) -> eval(diff(dmfd02(v0, v1), v0)):
dmfd03 := (v0, v1) -> eval(diff(dmfd02(v0, v1), v1)):
# zk is energy per unit particle
mzk := (r0, r1) -> \
+ \
f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
:
(* mf is energy per unit volume *)
mf := (r0, r1) -> eval(dens(r0, r1)*mzk(r0, r1)):
$include <util.mpl>
# Emit C code (with the optimize option) for zk and for the first (vrho_*),
# second (v2rho2_*) and third (v3rho3_*) partial derivatives of mf.
C([ zk_0_ = mzk(rho_0_, rho_1_), vrho_0_ = dmfd10(rho_0_, rho_1_), vrho_1_ = dmfd01(rho_0_, rho_1_), v2rho2_0_ = dmfd20(rho_0_, rho_1_), v2rho2_1_ = dmfd11(rho_0_, rho_1_), v2rho2_2_ = dmfd02(rho_0_, rho_1_), v3rho3_0_ = dmfd30(rho_0_, rho_1_), v3rho3_1_ = dmfd21(rho_0_, rho_1_), v3rho3_2_ = dmfd12(rho_0_, rho_1_), v3rho3_3_ = dmfd03(rho_0_, rho_1_)], optimize, deducetypes=false):
and
# Unpolarized variant of a potential-only (vxc) functional: dens/zeta below
# make the first argument the total density with zero spin polarization.
Polarization := "unpol":
Digits := 20: (* constants will have 20 digits *)
interface(warnlevel=0): (* suppress all warnings *)
with(CodeGeneration):
$include <lda_xc_tih.mpl>
dens := (r0, r1) -> r0:
zeta := (r0, r1) -> 0:
# dmf0dIJ: I-fold partial derivative of mf0 w.r.t. its first argument and
# J-fold w.r.t. its second. Only up to second order is built, because mf0
# is already the potential (one derivative order above the energy density).
dmf0d10 := (v0, v1) -> eval(diff(mf0(v0, v1), v0)):
dmf0d01 := (v0, v1) -> eval(diff(mf0(v0, v1), v1)):
dmf0d20 := (v0, v1) -> eval(diff(dmf0d10(v0, v1), v0)):
dmf0d11 := (v0, v1) -> eval(diff(dmf0d01(v0, v1), v0)):
dmf0d02 := (v0, v1) -> eval(diff(dmf0d01(v0, v1), v1)):
# Unlike the exc script, mzk is not multiplied by the density below: it is
# used directly as the potential.
mzk := (r0, r1) -> \
+ \
f(r_ws(dens(r0, r1)), zeta(r0, r1)) \
:
(* mf is the up potential *)
mf0 := (r0, r1) -> eval(mzk(r0, r1)):
# Down-spin potential obtained by swapping the arguments; it is not
# referenced in the C() call below.
mf1 := (r0, r1) -> eval(mzk(r1, r0)):
$include <util.mpl>
# Emit C code for vrho_0_ = mf0 itself and for its first (v2rho2_*) and
# second (v3rho3_*) partial derivatives; no zk output is generated.
C([ vrho_0_ = mf0(rho_0_, rho_1_), v2rho2_0_ = dmf0d10(rho_0_, rho_1_), v2rho2_1_ = dmf0d01(rho_0_, rho_1_), v3rho3_0_ = dmf0d20(rho_0_, rho_1_), v3rho3_1_ = dmf0d11(rho_0_, rho_1_), v3rho3_2_ = dmf0d02(rho_0_, rho_1_)], optimize, deducetypes=false):
It looks like we can harvest `vxc` from the `vrho_0_` output, and we can get `vxc` out of `exc` but not vice versa. It would be better to start from the Maple file generated for `vxc` and work out its correspondence to the one for `exc`. But I'm not sure that I understand the math here. The code shown here does nothing fancy (like a functional derivative); it just says that `vxc` is the partial derivative of `exc * rho` with respect to its `rho` input arguments (the same goes for mGGA, even with `tau` required).
- Why are there both `vrho_0_` and `vrho_1_`? It looks like we need to add them together and make that the final output of `vxc`, for example in `lda_c_rpa`'s codegen:
func_vxc_pol(const xc_func_type *p, size_t ip, const double *rho, xc_lda_out_params *out)
{
  double t1, t3, t4, t5, t6, t7, t8, t10;
  double t11, t13, t14, t17, t18, tzk0;
  double t19, t23, t25, t27, tvrho0, tvrho1;
  t1 = M_CBRT3;
  t3 = POW_1_3(0.1e1 / M_PI);
  t4 = t1 * t3;
  t5 = M_CBRT4;
  t6 = t5 * t5;
  t7 = rho[0] + rho[1];
  t8 = POW_1_3(t7);
  t10 = t6 / t8;
  t11 = t4 * t10;
  t13 = log(t11 / 0.4e1);
  t14 = 0.311e-1 * t13;
  t17 = 0.225e-2 * t4 * t10 * t13;
  t18 = 0.425e-2 * t11;
  tzk0 = t14 - 0.48e-1 + t17 - t18;
  if(out->zk != NULL && (p->info->flags & XC_FLAGS_HAVE_EXC))
    out->zk[ip*p->dim.zk + 0] += tzk0;
  t19 = 0.1e1 / t7;
  t23 = t6 / t8 / t7;
  t25 = t4 * t23 * t13;
  t27 = t4 * t23;
  tvrho0 = t14 - 0.48e-1 + t17 - t18 + t7 * (-0.10366666666666666667e-1 * t19 - 0.75e-3 * t25 + 0.6666666666666666667e-3 * t27);
  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 0] += tvrho0;
  tvrho1 = tvrho0;
  if(out->vrho != NULL && (p->info->flags & XC_FLAGS_HAVE_VXC))
    out->vrho[ip*p->dim.vrho + 1] += tvrho1;
}
- Do we want to get the derivative code out of Maple, or leave it to JAX's autodiff?
If the latter, basically all we need to do is apply `jax.grad`, provided we can construct a function like `lambda rho: exc(rho) * rho` that doesn't involve the position `r`. But that may require breaking the current design, since all the callables currently depend either purely on `r` (before 0.0.4) or on both `rho` and `r`.
Take the generated Python code, for example:
def pol(p, r, s=(None, None, None), l=(None, None), tau=(None, None)):
    params = p.params
    (r0, r1), (s0, s1, s2), (l0, l1), (tau0, tau1) = r, s, l, tau
    t1 = jnp.cbrt(3)
    t3 = jnp.cbrt(0.1e1 / jnp.pi)
    t4 = t1 * t3
    t5 = jnp.cbrt(4)
    t6 = t5 ** 2
    t8 = jnp.cbrt(r0 + r1)
    t10 = t6 / t8
    t11 = t4 * t10
    t13 = jnp.log(t11 / 0.4e1)
    res = 0.311e-1 * t13 - 0.48e-1 + 0.225e-2 * t4 * t10 * t13 - 0.425e-2 * t11
    return res

def unpol(p, r, s=None, l=None, tau=None):
    params = p.params
    r0, s0, l0, tau0 = r, s, l, tau
    t1 = jnp.cbrt(3)
    t3 = jnp.cbrt(0.1e1 / jnp.pi)
    t4 = t1 * t3
    t5 = jnp.cbrt(4)
    t6 = t5 ** 2
    t7 = jnp.cbrt(r0)
    t9 = t6 / t7
    t10 = t4 * t9
    t12 = jnp.log(t10 / 0.4e1)
    res = 0.311e-1 * t12 - 0.48e-1 + 0.225e-2 * t4 * t9 * t12 - 0.425e-2 * t10
    return res

def invoke(
    p: NamedTuple, rho: Callable, r: jnp.ndarray, mo: Optional[Callable] = None,
    deorbitalize: Optional[float] = None,
):
    args = rho_to_arguments(p, rho, r, mo, deorbitalize)
    ret = pol(p, *args) if p.nspin == 2 else unpol(p, *args)
    dens = args[0] if p.nspin == 1 else sum(args[0])
    ret = float(dens >= p.dens_threshold) * ret
    return ret
We don't need to propagate the gradient through the position, so we just need a partial function of `pol` or `unpol` in which everything is fixed except `rho` (`r` in the code here). My guess is that we could do something like `vxc = jax.grad(lambda rho: rho * pol(p, rho, *args))`.
cc @mavenlin
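The JAX solution looks nice here: basically, we're making use of the semi-local nature of the functional and deriving `vxc` manually.
I think this is the way to go for LDA, but we may want to consider GGA and mGGA before proceeding, since there the functional derivative also goes through $\nabla\rho$: for an energy density $e = \rho\,\epsilon_{xc}(\rho, \nabla\rho)$, $v_{xc} = \partial e/\partial\rho - \nabla\cdot\left(\partial e/\partial(\nabla\rho)\right)$.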