A Python repository for fair sampling, as applied in the submitted paper (@todo).
Download this repository with `git clone` or equivalent:
git clone https://github.com/lsha49/FairBERT_deploy.git
- Python 3.8
- TensorFlow > 1.5
- tensorflow-estimator 2.7.0
- tensorflow-macos 2.7.0
- tensorflow-metal 0.3.0
- scikit-learn > 0.19.0
Below we detail how to apply the hardness constraint (H-bias) to the seed dataset. See the example code in `Util.py`.
The H-bias can be calculated with the `calKDN` function.
After generating samples, we evaluate the kDN distribution by first calculating kDN with `kdn_score()`:
kdnResult = kdn_score(features, labels, number_of_neighbors)
We then calculate the Jensen-Shannon (JS) distance with `distance.jensenshannon` and select the samples that lower the H-bias:
distance.jensenshannon()
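The two steps above can be sketched as follows. This is a minimal illustration, not the code in `Util.py`: the names `kdn_sketch` and `kdn_js_distance` are hypothetical, kDN is taken to be the fraction of a point's k nearest neighbours that carry a different label, and the JS distance is assumed to be computed between histograms of kDN scores for two groups of samples.

```python
import numpy as np
from scipy.spatial import distance
from sklearn.neighbors import NearestNeighbors


def kdn_sketch(features, labels, k):
    """Hypothetical stand-in for kdn_score(): per-sample fraction of the
    k nearest neighbours whose label disagrees with the sample's label."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    nn = NearestNeighbors(n_neighbors=k).fit(features)
    # Calling kneighbors() without arguments excludes each point from
    # its own neighbour list.
    _, idx = nn.kneighbors()
    return (labels[idx] != labels[:, None]).mean(axis=1)


def kdn_js_distance(kdn_a, kdn_b, bins=10):
    """JS distance between the kDN histograms of two sample groups
    (assumption: this is how the distributions are compared)."""
    edges = np.linspace(0.0, 1.0, bins + 1)
    hist_a, _ = np.histogram(kdn_a, bins=edges)
    hist_b, _ = np.histogram(kdn_b, bins=edges)
    # jensenshannon normalises both inputs to sum to 1; base=2 bounds
    # the result to [0, 1].
    return float(distance.jensenshannon(hist_a, hist_b, base=2))


# Toy data: the third point sits in the class-0 cluster but is labelled 1.
features = [[0.0], [0.1], [0.2], [10.0], [10.1], [10.2]]
labels = [0, 0, 1, 1, 1, 1]
kdn = kdn_sketch(features, labels, k=2)
print(kdn)
print(kdn_js_distance(kdn[:3], kdn[3:]))
```

A candidate batch of generated samples could then be kept only if it yields a smaller JS distance than the seed data, which is one plausible reading of "lowering H-bias".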
We used the `abroca` package to compute ABROCA. A sample calculation:
slice = compute_abroca(abrocaDf,
pred_col = 'prob_1' ,
label_col = 'label',
protected_attr_col = 'gender',
majority_protected_attr_val = '2',
compare_type = 'binary', # binary, overall, etc...
n_grid = 10000,
plot_slices = False)
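For readers unfamiliar with the metric, ABROCA (Absolute Between-ROC Area) is the area between the ROC curves of two groups. The sketch below illustrates that idea with scikit-learn and NumPy only. It is not the `abroca` package's implementation, and the function name `abroca_sketch` and its arguments are hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_curve


def abroca_sketch(y_true, y_score, group, majority_val, n_grid=10000):
    """Area between the ROC curves of the majority group and everyone else.
    Each group must contain both classes for its ROC curve to be defined."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    group = np.asarray(group)
    grid = np.linspace(0.0, 1.0, n_grid)  # common false-positive-rate grid
    curves = []
    for mask in (group == majority_val, group != majority_val):
        fpr, tpr, _ = roc_curve(y_true[mask], y_score[mask])
        # Interpolate the true-positive rate onto the shared grid.
        curves.append(np.interp(grid, fpr, tpr))
    gap = np.abs(curves[0] - curves[1])
    # Trapezoidal rule over the grid.
    return float(np.sum((gap[1:] + gap[:-1]) / 2.0 * np.diff(grid)))


# Toy example: the model ranks group '2' perfectly and group '1' in reverse.
y_true = [0, 0, 1, 1, 0, 0, 1, 1]
group = ['2', '2', '2', '2', '1', '1', '1', '1']
y_score = [0.1, 0.2, 0.8, 0.9, 0.9, 0.8, 0.2, 0.1]
print(abroca_sketch(y_true, y_score, group, majority_val='2'))
```

A value near 0 means the classifier ranks both groups equally well; a value near 1 is the extreme case shown in the toy data.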
A logistic regression model is implemented in `Util.py` by the `logisticRegression` function.
A sample grid-searched model:
lrc = LogisticRegression(C=4.281332398719396, class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
n_jobs=1, penalty='l1', random_state=None,
solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
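A model like the one above could be obtained with scikit-learn's `GridSearchCV`. The sketch below is illustrative only: the synthetic data, the parameter grid, the `roc_auc` scoring, and the 5-fold setting are assumptions, not the settings used in `Util.py`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for the real feature matrix and labels.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# Hypothetical grid: regularisation strength and penalty type.
# liblinear supports both 'l1' and 'l2'.
param_grid = {
    "C": list(np.logspace(-2, 2, 5)),
    "penalty": ["l1", "l2"],
}

search = GridSearchCV(
    LogisticRegression(solver="liblinear", max_iter=100),
    param_grid,
    scoring="roc_auc",
    cv=5,
)
search.fit(X, y)

lrc = search.best_estimator_  # refit on all data with the best parameters
print(search.best_params_)
```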
A sample extraction of BERT embeddings is implemented in `MEmb.py`:
hidden_states = model(torch.tensor(tokenizer.encode(entry,truncation=True)).unsqueeze(0))[1]
We followed the same pretraining procedure as shown in Hugging Face.
See a sample implementation in `MTrain.py`:
BertForMaskedLM.from_pretrained("bert-base-uncased")
BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
Active learning (AL) sampling is implemented with `alipy`.
See a sample implementation of query-by-committee (QBC) in `MALSample.py`.
See the comprehensive documentation of all the query selection functions here.
alibox.get_query_strategy(strategy_name='QueryInstanceQBC').select(labelledSet, unLabelledSet, model=xxx, batch_size=xxx)
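To make the idea behind QBC concrete, the sketch below selects the unlabelled points on which a small committee disagrees most, measured by vote entropy. It is a simplified illustration built on scikit-learn, not alipy's `QueryInstanceQBC`. The function `qbc_select`, the bootstrap committee, and the logistic regression members are all assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample


def qbc_select(X_lab, y_lab, X_unlab, batch_size=1, n_members=5, seed=0):
    """Return indices (into X_unlab) of the batch_size points with the
    highest vote entropy across a bootstrap-trained committee."""
    X_lab, y_lab = np.asarray(X_lab), np.asarray(y_lab)
    X_unlab = np.asarray(X_unlab)
    rng = np.random.RandomState(seed)
    classes = np.unique(y_lab)
    votes = np.zeros((len(X_unlab), len(classes)))
    for _ in range(n_members):
        # Stratified bootstrap so every member sees all classes.
        Xb, yb = resample(X_lab, y_lab, random_state=rng, stratify=y_lab)
        pred = LogisticRegression(max_iter=1000).fit(Xb, yb).predict(X_unlab)
        for j, c in enumerate(classes):
            votes[:, j] += (pred == c)
    p = votes / n_members
    # Vote entropy; log(1) = 0 is substituted where p == 0.
    entropy = -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=1)
    return np.argsort(-entropy, kind="stable")[:batch_size]


X_lab = [[-2.0], [-1.5], [-1.0], [1.0], [1.5], [2.0]]
y_lab = [0, 0, 0, 1, 1, 1]
X_unlab = [[-5.0], [0.0], [0.1], [5.0]]
picked = qbc_select(X_lab, y_lab, X_unlab, batch_size=2)
print(picked)
```

The selected indices would then be labelled and moved from the unlabelled pool into the labelled set before the next query round.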