Use ML Models for Variance Reduction in A/B experiments
TL;DR
ML models can significantly reduce variance and hence increase the power of your A/B experiments.
Introduction
Machine Learning (ML) models are good at making predictions, but in general they are not widely used for recovering causality. I have always wondered whether they can do us any good for causal inference, so I will explore this topic by running some simulations and report back to the community in a series of articles.
In this series, I plan to examine the following cases:
- (This article) The treatment is assigned completely at random (i.e., in an A/B experiment).
- We have observational data and the ignorability assumption holds.
- We have observational data and the exclusion restriction holds.
Note that I plan to work with binary treatments throughout the whole series. If you have never seen the term “ignorability” or “exclusion restriction,” do not worry about them now. I will explain them when I write about them.
When the treatment is assigned completely at random, ML models not only can recover causality, they can do it better: you can achieve higher statistical power with an ML model than with vanilla t-tests or linear regressions.
Next, I will examine several simulation cases as evidence, ranging from easy to difficult for a model to learn.
Homogeneous Treatment Effect
What is a homogeneous treatment effect? It means each and every sample in your dataset experiences exactly the same effect from the treatment. This will probably never be true in reality, but it is a nice baseline case: if our model fails here, we do not need to go any further.
Let’s generate some simulation data to work with. We will stick to simulated data so that we can control the true treatment effect; otherwise, because of the fundamental problem of causal inference, we could never verify whether our model has recovered the truth.
In this dataset, treatment is independently drawn from a Bernoulli distribution with p=0.5. We have 6 features other than the treatment. Features a, b and c influence the outcome in a complex and non-linear way. Features e and f do not influence the outcome in any way. The treatment effect is 0.1 across the board. Our task is to recover this 0.1.
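For concreteness, here is a minimal sketch of what such a data-generating process could look like. The specific non-linear functions of a, b and c below are my own placeholders, not the exact ones behind the numbers reported later in the article.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 20_000

# Six features; only a, b and c will drive the outcome.
df = pd.DataFrame(rng.normal(size=(n, 6)), columns=list("abcdef"))

# Treatment is assigned completely at random with p = 0.5.
df["treatment"] = rng.binomial(1, 0.5, size=n)

# An illustrative complex, non-linear baseline driven by a, b and c,
# plus a constant treatment effect of 0.1 and some residual noise.
baseline = np.sin(df["a"] * df["b"]) + np.exp(df["c"] / 2) + df["a"] * df["c"] ** 2
df["outcome"] = baseline + 0.1 * df["treatment"] + rng.normal(scale=0.1, size=n)
```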
Let’s take a look at the distribution of the outcome, grouped by the treatment.
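A kernel density plot split by the treatment indicator does the job; here is a sketch with seaborn (any plotting library works):

```python
import matplotlib.pyplot as plt
import seaborn as sns

# Overlay the outcome distributions of the control and treated groups.
sns.kdeplot(data=df, x="outcome", hue="treatment", common_norm=False)
plt.title("Outcome distribution by treatment group")
plt.show()
```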
From the plot, we can see the outcome is continuously distributed. The distribution of the control group is not that different from the treated group. If we just eyeball it, we might conclude that there is little to no effect.
Let’s build an ML model that uses all we have (features a to f and the treatment) to predict outcome.
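I use LightGBM through its scikit-learn interface; the hyperparameters below are arbitrary but reasonable choices for this sketch:

```python
from lightgbm import LGBMRegressor

features = list("abcdef") + ["treatment"]

# Fit a gradient-boosted tree model on the six features plus the treatment.
model = LGBMRegressor(n_estimators=500, learning_rate=0.05)
model.fit(df[features], df["outcome"])
```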
To estimate the treatment effect, we need to run the prediction twice: we set treatment to 0 and then to 1 for all samples and predict the outcomes respectively. Then we take the difference of the two sets of predictions and average those differences. If the model can recover the average treatment effect, we should get a number close to 0.1.
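In code, that means scoring two copies of the data, one with treatment forced to 0 and one with it forced to 1:

```python
# Predict counterfactual outcomes under control and under treatment.
pred_control = model.predict(df[features].assign(treatment=0))
pred_treated = model.predict(df[features].assign(treatment=1))

ate_hat = (pred_treated - pred_control).mean()
print(f"Predicted Average Treatment Effect: {ate_hat:.4f}")
```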
We get:
Predicted Average Treatment Effect: 0.1028
0.1028 is not exactly 0.1, but it is quite close. Does this mean we have failed? Or is the gap just noise? To answer this question, we need the standard error and/or the confidence interval of the estimated effect. Unfortunately, we do not have a nice closed-form formula for the standard error because the model is a black box, so we have to obtain the standard error and the confidence interval through a non-parametric bootstrap.
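A minimal version of the bootstrap resamples the rows with replacement, refits the model, and re-estimates the effect in every run (the 200 replications below are an arbitrary choice):

```python
import numpy as np

def estimate_ate(data):
    """Fit the model on one resample and return the estimated average effect."""
    m = LGBMRegressor(n_estimators=500, learning_rate=0.05)
    m.fit(data[features], data["outcome"])
    diff = m.predict(data[features].assign(treatment=1)) - m.predict(
        data[features].assign(treatment=0)
    )
    return diff.mean()

# Non-parametric bootstrap: resample rows with replacement and re-estimate.
boot_estimates = np.array(
    [estimate_ate(df.sample(frac=1, replace=True)) for _ in range(200)]
)
print(f"Bootstrap SE: {boot_estimates.std():.4f}")
print(f"95% CI: {np.percentile(boot_estimates, [2.5, 97.5]).round(4)}")
```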
The result from the bootstrap is as follows:
We can see the estimated effect is indeed very close to 0.1 and the 95% confidence interval covers the true value. Let’s also plot the results from all bootstrap runs to check if the distribution looks reasonable.
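A quick histogram of the bootstrap estimates is enough for this sanity check:

```python
import matplotlib.pyplot as plt

plt.hist(boot_estimates, bins=30)
plt.axvline(0.1, color="red", linestyle="--", label="true effect")
plt.xlabel("Bootstrap estimate of the average treatment effect")
plt.legend()
plt.show()
```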
Good! Now I am confident to say ML models can recover the causal effect in this case.
But you probably want to argue that this is just a standard A/B test and we don’t need LightGBM to recover the causal effect. You are right. Let’s see how other approaches perform compared to the ML-based one, starting with the standard t-test. Because it is so standard, I will omit the code and only show the result here:
The vanilla t-test obviously also recovers the effect. The most striking difference, however, is in the standard error and the confidence interval: here the standard error is 280% of the one given by the LightGBM bootstrap.
Is this because the standard error given by the ML-based approach is wrong? To answer that, let me show you two more standard errors from two other approaches.
The first approach is a linear regression with the treatment indicator and all the linear terms of the features (i.e., treatment + a + b + c + d + e + f + constant). The coefficient on treatment is the estimated effect, and packages like statsmodels will give us the associated standard error and confidence interval. The result is as follows.
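A sketch with statsmodels’ formula interface:

```python
import statsmodels.formula.api as smf

# OLS with the treatment indicator and all linear terms of the features.
ols_linear = smf.ols("outcome ~ treatment + a + b + c + d + e + f", data=df).fit()

# Coefficient, standard error and 95% CI for every term, including treatment.
print(ols_linear.summary().tables[1])
```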
The estimated effect is still correct, and the standard error is smaller than the t-test one.
Before I explain why, let’s look at another approach. We will stick with linear regression, but this time add some polynomial terms to better capture the non-linear effects we baked in earlier. I will add quadratic and interaction terms on top of the existing linear terms (that is, a**2 + b**2 + … + f**2 + a*b + a*c + … + e*f). And here is the result:
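One way to build the expanded design matrix is scikit-learn’s PolynomialFeatures, fed into a plain statsmodels OLS:

```python
import pandas as pd
import statsmodels.api as sm
from sklearn.preprocessing import PolynomialFeatures

# Expand a..f into linear, quadratic and pairwise-interaction terms.
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = pd.DataFrame(
    poly.fit_transform(df[list("abcdef")]),
    columns=poly.get_feature_names_out(list("abcdef")),
    index=df.index,
)

# Add the treatment indicator and a constant, then fit OLS.
X = sm.add_constant(pd.concat([df[["treatment"]], X_poly], axis=1))
ols_poly = sm.OLS(df["outcome"], X).fit()

print(ols_poly.params["treatment"], ols_poly.bse["treatment"])
print(ols_poly.conf_int().loc["treatment"])
```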
The standard error goes down further, but it is still nowhere near the ML-based one.
What is going on here? Why do we get different standard errors, and why is the one from the ML model the lowest?
Let’s answer the former first. All these approaches can give us unbiased and consistent estimates of the causal effect, but the noise level contained in each estimate is different. In the t-test, we only utilize information on the treatment (in practice, this means we only utilize the cohort assignment information). All the contributions from other features to the outcome are considered “noise” under this approach. So the noise level is the highest.
If we want to reduce the standard error, we need more data to overcome the noise. We can keep using the t-test and simply gather more data points. This is a perfectly fine strategy if your traffic is cheap.
But more data does not necessarily mean more data points. We can also add data horizontally, by adding more features. Features that help predict the outcome help us reduce the noise. With their help, the treatment column no longer has to fight alone; in some sense, less noise rests on the treatment’s shoulders.
With this in mind, we can understand why the standard errors from these approaches rank the way they do.
In the first linear regression, the linear terms take some of the burden off the treatment, but they cannot fully approximate the non-linear functional forms we cooked up earlier. The non-linear parts are therefore still “noise” from the treatment’s point of view, so we see a reduced standard error, but not one comparable to the other approaches.
In the second linear regression, the polynomial terms help fit the non-linear contributions of the features, so we see a further reduction relative to the first linear regression. However, these polynomial terms still cannot perfectly approximate the non-linearity we cooked up, and the treatment still has quite a bit of leftover noise to deal with.
ML models are good at fitting unknown and complex functions. With a sufficiently flexible model, we can fit the true relationship between the features (i.e., a to f) and the outcome as closely as possible, leaving very little noise for the treatment. This is why the ML-based standard error is much lower.
As some not-100%-scientific evidence, I fitted all the models without the treatment column (for the t-test, this just means a model with a constant term only). Let’s take a look at the standard deviations of their residuals.
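Roughly, the comparison refits each model on the features alone and prints the standard deviation of its residuals (reusing df, X, smf, sm and LGBMRegressor from the earlier sketches):

```python
feats = list("abcdef")

# "t-test" counterpart: a constant-only model, so residuals are deviations from the mean.
resid_ttest = df["outcome"] - df["outcome"].mean()

# Linear regression on the linear terms only.
resid_linear = smf.ols("outcome ~ a + b + c + d + e + f", data=df).fit().resid

# Linear regression on the polynomial expansion built earlier (X without the treatment).
resid_poly = sm.OLS(df["outcome"], X.drop(columns="treatment")).fit().resid

# LightGBM on the features only.
gbm = LGBMRegressor(n_estimators=500, learning_rate=0.05)
gbm.fit(df[feats], df["outcome"])
resid_gbm = df["outcome"] - gbm.predict(df[feats])

for name, resid in [("t-test", resid_ttest), ("linear", resid_linear),
                    ("polynomial", resid_poly), ("LightGBM", resid_gbm)]:
    print(f"{name}: {resid.std():.4f}")
```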
Again not 100% scientific, but you can think of these standard deviations as the noise levels that the treatment column needs to fight against. Clearly the one given by the LightGBM model is the lowest.
Heterogeneous Causal Effects
In the previous example, you might think I’m cheating because the treatment is only linearly added to the outcome. What if we make it more complex so that each sample’s causal effect is different? Let’s give it a try.
I will examine two cases. The first is mathematically a bit simpler: each sample’s treatment effect depends only on feature c. In the second case, features a, b and c all influence the treatment effect in different ways. If the ML-based approach can handle both, then I feel comfortable relying on ML (in this case, LightGBM) models in practice.
The outcome for the two cases is generated along the following lines.
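The exact functional forms are not essential; the sketch below is an illustrative placeholder of my own (reusing df, baseline, rng and n from the first sketch), so it will not reproduce the exact averages quoted next:

```python
# Case 1: each sample's treatment effect depends only on feature c (placeholder form).
tau_1 = -0.5 + 0.3 * np.tanh(df["c"])
df["outcome_case1"] = baseline + tau_1 * df["treatment"] + rng.normal(scale=0.1, size=n)

# Case 2: features a, b and c all shape the treatment effect in different ways (placeholder form).
tau_2 = -3.0 + np.sin(df["a"]) + 0.5 * df["b"] ** 2 * np.sign(df["c"])
df["outcome_case2"] = baseline + tau_2 * df["treatment"] + rng.normal(scale=0.1, size=n)
```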
Case 1’s true average causal effect is about -0.4432 and case 2’s true average causal effect is about -2.9390 (both are computed by numerical integration).
For case 1, here is the result comparison among all four approaches:
And here is the result for case 2:
They all manage to recover the true effect. As in the previous case, the standard errors go down as the non-treatment features capture more of the noise. For the highest power, you still want to choose the ML-based approach.
What have we learned from this exercise?
In an experimental setting, the flip side of the noise level is statistical power: the higher the noise level, the lower the statistical power, and vice versa.
When you do a t-test in your A/B test workflow, everything else being equal, you need the largest number of data points to fight the noise. If you have enough traffic, this is perfectly okay.
But when you do not have enough traffic and you are not willing to lower the statistical power or raise the minimum detectable effect size, you can include features that help predict the outcome variable. One important caveat is that those features must not be influenced by your treatment, otherwise you have label leakage. A practical way to implement this in your pipeline is to log the features of each subject at the moment they enter your intervention page.
In the analysis phase, you can use either linear regressions or ML models to achieve the noise reduction, and I suggest you implement both. Linear regression, although not as effective, is much faster to compute: with the ML-based approach you have to bootstrap, and you need to train a new model in each bootstrap run. In this mock example with only 20k data points, the ML-based approach takes about 30 minutes on the free tier of Google’s Colab, while the linear regression is almost instantaneous.