imodelsX

Interpretable prompting and models for NLP (using large language models).

Home page: https://csinva.io/imodelsX/


Library to explain a dataset in natural language.

📖 Demo notebooks

| Model | Reference | Description |
| --- | --- | --- |
| iPrompt | 📖, 🗂️, 🔗, 📄 | Generates a human-interpretable prompt that explains patterns in data (Official) |
| Emb-GAM | 📖, 🗂️, 🔗, 📄 | Fits a better linear model by using an LLM to extract embeddings (Official) |
| D3 | 📖, 🗂️, 🔗, 📄 | Explains the difference between two distributions |
| Linear Finetune | 🗂️ | Scikit-learn interface to finetune a single linear layer on top of LLM embeddings for classification/regression (see the sketch below) |
| AutoPrompt | 🗂️, 🔗, 📄 | Finds a natural-language prompt using input gradients (⌛ in progress) |
| Coming soon! ⌛ | | We plan to support other interpretable models like RLPrompt, concept bottleneck models, NAMs, and NBDT |

Demo notebooks 📖, Doc 🗂️, Reference code implementation 🔗, Research paper 📄
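
Linear Finetune has no quickstart snippet below, so here is a minimal sketch of what its scikit-learn-style interface could look like. The class name LinearFinetuneClassifier and the checkpoint argument are assumptions inferred from the table, not confirmed by this README; check the docs 🗂️ for the actual API.

from imodelsx import LinearFinetuneClassifier  # class name is an assumption, not confirmed here

# finetune a single linear layer on top of LLM embeddings (per the table description)
clf = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',  # assumed argument: which LLM to embed with
)
clf.fit(['soup is tasty', 'this movie was bad'], [1, 0])
print(clf.predict(['the broth was delicious']))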

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Demos: see the demo notebooks

iPrompt

from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=3, # how long of a prompt to learn
    n_shots=3, # shots per example
    n_epochs=15, # how many epochs to search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16
)
Here, prompts is a list of the natural-language prompt strings that iPrompt found.
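
As a quick sanity check, you can print the top candidates; this minimal sketch assumes only that prompts is a list of strings, as stated above:

# inspect the first few recovered prompts
for i, prompt in enumerate(prompts[:3]):
    print(i, repr(prompt))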

D3 (DescribeDistributionalDifferences)

import imodelsx

# toy placeholder corpora; in practice, pass the two text distributions to compare
positive_samples = ['I loved this movie', 'what a great film']
negative_samples = ['I hated this movie', 'what a boring film']

hypotheses, hypothesis_scores = imodelsx.explain_datasets_d3(
    pos=positive_samples, # List[str] of positive examples
    neg=negative_samples, # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
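
To see what D3 found, a short follow-up sketch; it assumes only that the two return values are parallel lists, as their names suggest:

# print each natural-language hypothesis alongside its score
for hypothesis, score in zip(hypotheses, hypothesis_scores):
    print(round(score, 3), hypothesis)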

Emb-GAM

from imodelsx import EmbGAMClassifier
import datasets
import numpy as np

# set up data: subsample 300 examples from each split for speed
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
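
Since EmbGAMClassifier exposes a scikit-learn-style interface, the fitted model can presumably be persisted like any other estimator; a sketch using plain pickle (an assumption about this class, not an API documented here):

import pickle

# save the fitted model, then reload it and predict again
with open('embgam.pkl', 'wb') as f:
    pickle.dump(m, f)
with open('embgam.pkl', 'rb') as f:
    m_loaded = pickle.load(f)
print(m_loaded.predict(dset_val['text'])[:5])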

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
