grf-labs / policytree

Policy learning via doubly robust empirical welfare maximization over trees

Home page: https://grf-labs.github.io/policytree/



Evaluating a policy tree in an observational setting

njawadekar opened this issue

This tutorial describes how to evaluate a policy tree learned in an experimental setting:
# Only valid for an experimental setting!
# Predict pi(X[i]) for each unit in the test set.
# policytree labels the treatments 1, 2, 3, ...; subtracting one makes zero represent control, as is usual with a binary treatment.
w.opt <- predict(policy, X[-train, ]) - 1
A <- w.opt == 1
# Outcomes and treatments must come from the same test rows as A.
Y.test <- Y[-train]
W.test <- W[-train]

# Copied and pasted from the Policy Evaluation section
value.estimate <- mean(Y.test[A & (W.test == 1)]) * mean(A) + mean(Y.test[!A & (W.test == 0)]) * mean(!A)
value.stderr <- sqrt(var(Y.test[A & (W.test == 1)]) / sum(A & (W.test == 1)) * mean(A)^2 +
  var(Y.test[!A & (W.test == 0)]) / sum(!A & (W.test == 0)) * mean(!A)^2)
print(paste("Value estimate:", value.estimate, "Std. Error:", value.stderr))

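As shown, one can estimate the expected outcome under the learned policy (e.g., the overall risk, for a binary outcome) in a particular subsample such as the test set: take the mean outcome among units whose observed treatment matches the policy's recommendation, separately for each arm, and weight each mean by the share of units the policy assigns to that arm.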
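In symbols, writing $A_i = 1$ if the tree recommends treatment for test unit $i$ and $\bar A$ for the share of test units recommended treatment, the snippet computes

$$\hat V(\hat\pi) = \bar A \, \frac{\sum_i A_i W_i Y_i}{\sum_i A_i W_i} + (1 - \bar A) \, \frac{\sum_i (1 - A_i)(1 - W_i) Y_i}{\sum_i (1 - A_i)(1 - W_i)}$$

Each ratio averages Y only over units whose received treatment happens to match the recommendation. That is a fair stand-in for the mean outcome under the recommended arm only when W is randomized independently of X, which is why the code is flagged as experimental-only.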

I have two questions related to this:
(1) What if we wanted to apply the same idea in an observational setting? Do you have any code, or a modified doubly robust estimator, that could evaluate the policy tree in the presence of confounding?
(2) Furthermore, what if we wanted to run this evaluation on the entire sample rather than only the test sample? Do you have a recommended approach, or code, that would do this?
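For question (1), a common approach in the observational setting is to replace the arm-wise sample means with augmented inverse-propensity weighted (AIPW, i.e. doubly robust) scores and average the score of the recommended arm. The base-R sketch below is illustrative only: the function name `dr_policy_value` and its arguments are hypothetical, and it assumes you already have cross-fitted estimates of the two conditional means and of the propensity score, with propensities bounded away from 0 and 1. The estimate stays consistent if either the outcome models or the propensity model is correct.

```r
# Doubly robust (AIPW) estimate of the value of a binary treatment policy.
# Assumption: mu0.hat, mu1.hat and e.hat are cross-fitted (out-of-fold)
# estimates of E[Y | X, W = 0], E[Y | X, W = 1] and P(W = 1 | X), and
# pi.hat is the 0/1 recommendation of a policy not fit on these units.
dr_policy_value <- function(Y, W, pi.hat, mu0.hat, mu1.hat, e.hat) {
  # AIPW score for each arm: outcome-model prediction plus an
  # inverse-propensity-weighted residual that corrects its bias.
  gamma.1 <- mu1.hat + W / e.hat * (Y - mu1.hat)
  gamma.0 <- mu0.hat + (1 - W) / (1 - e.hat) * (Y - mu0.hat)
  # Keep the score of the arm the policy recommends for each unit.
  gamma.pi <- pi.hat * gamma.1 + (1 - pi.hat) * gamma.0
  list(estimate = mean(gamma.pi),
       std.err = sd(gamma.pi) / sqrt(length(gamma.pi)))
}
```

With grf, these inputs could plausibly come from a `causal_forest` fit `cf`: `cf$W.hat` holds out-of-bag propensities, and `cf$Y.hat` together with out-of-bag CATE predictions `tau` gives `mu0 = Y.hat - W.hat * tau` and `mu1 = Y.hat + (1 - W.hat) * tau`; policytree's `double_robust_scores(cf)` should return the two score columns directly. If every input, including `pi.hat`, is out-of-fold for each unit, the same average can be taken over the entire sample, which bears on question (2).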

Thanks, I can try to fit the policy tree with k-fold cross-validation. But with 10 folds (and therefore 10 different policy trees), how would you suggest evaluating this "policy"? Would it make sense to evaluate each of the 10 trees on its own held-out fold and then average the performance across all 10?
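One way to make the 10-fold proposal concrete is to give every unit the recommendation of the tree that did not see it, then score those recommendations with doubly robust scores. The base-R sketch below assumes `Gamma` is an n x 2 matrix of cross-fitted doubly robust scores (column 1 control, column 2 treated); `crossfit_policy_value`, `fit_policy` and `predict_policy` are hypothetical names, and the "best single arm" learner is a toy stand-in where a call to `policy_tree()` and `predict()` would go, not a policy tree.

```r
# Cross-fitted evaluation of a k-fold "policy" (a sketch under assumptions).
# Gamma: n x 2 matrix of doubly robust scores, column 1 = control and
#   column 2 = treated, assumed to be cross-fitted already.
# fit_policy(X, Gamma) and predict_policy(policy, X): hypothetical hooks for
#   a policy learner; predict_policy must return actions 1 or 2.
crossfit_policy_value <- function(X, Gamma, fit_policy, predict_policy,
                                  num.folds = 10,
                                  folds = sample(rep(seq_len(num.folds),
                                                     length.out = nrow(Gamma)))) {
  n <- nrow(Gamma)
  action <- integer(n)
  for (k in unique(folds)) {
    held.out <- folds == k
    # Learn a policy without fold k, then use it only on fold k.
    policy <- fit_policy(X[!held.out, , drop = FALSE],
                         Gamma[!held.out, , drop = FALSE])
    action[held.out] <- predict_policy(policy, X[held.out, , drop = FALSE])
  }
  # Score of the arm recommended by the tree that did not see each unit.
  gamma.pi <- Gamma[cbind(seq_len(n), action)]
  per.fold <- tapply(gamma.pi, folds, mean)
  list(pooled = mean(gamma.pi),
       pooled.std.err = sd(gamma.pi) / sqrt(n),
       per.fold = per.fold,
       average.of.folds = mean(per.fold),
       action = action)
}

# Toy learner (stand-in only): recommend the arm with the higher mean score.
fit_best_arm <- function(X, Gamma) unname(which.max(colMeans(Gamma)))
predict_best_arm <- function(policy, X) rep(policy, nrow(X))
```

When the folds have equal size, the plain average of the 10 per-fold estimates equals the pooled mean, and a size-weighted average always does, so "evaluate each tree on its held-out fold and average" and "pool all out-of-fold scores" give the same point estimate. The pooled standard error treats the recommendations as fixed, so it is only an approximation here, and neither number by itself tells you which of the 10 trees to deploy.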