This repository contains code for STaR-GATE: Teaching Language Models to Ask Clarifying Questions.
When prompting language models to complete a task, users often leave important aspects unsaid. While asking questions could resolve this ambiguity (GATE; Li et al., 2023), models often struggle to ask good questions. We explore a language model's ability to self-improve (STaR; Zelikman et al., 2022) by rewarding the model for generating useful questions, a simple method we dub STaR-GATE. We generate a synthetic dataset of 25,500 unique persona-task prompts to simulate conversations between a pretrained language model (the Questioner) and a Roleplayer whose preferences are unknown to the Questioner. By asking questions, the Questioner elicits preferences from the Roleplayer. The Questioner is iteratively finetuned on questions that increase the probability of high-quality responses to the task, which are generated by an Oracle with access to the Roleplayer's latent preferences. After two iterations of self-improvement, the Questioner asks better questions, allowing it to generate responses that are preferred over responses from the initial model on 72% of tasks. Our results indicate that teaching a language model to ask better questions leads to better personalized responses.
The final model checkpoint for the main experiment is posted here on the HuggingFace hub. Reach out over email or X (linked in my GitHub profile) if you have any questions.
To set up the project, create and activate a conda environment, then navigate to the root directory and run the following commands:

```bash
pip install -e .
pip install flash-attn --no-build-isolation
```
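If you have not yet created the environment, here is a minimal sketch to run before the commands above (the environment name and Python version are our assumptions, not pinned by the repository):

```bash
# Hypothetical environment name; any recent Python should work
conda create -n star-gate python=3.10 -y
conda activate star-gate
```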
By default, all data and model checkpoints are saved in subdirectories of `/scr/andukuri/assistant-gate-hgx`. Make sure you adjust file paths as necessary. The final contents of the directory, corresponding to the paper's results, can be accessed here.
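To find (and optionally rewrite) every hardcoded reference to the default directory, here is a quick sketch to run from the repository root; `/your/data/dir` is a placeholder for wherever you want the data to live:

```bash
# List every file that mentions the default data directory
grep -rn "/scr/andukuri/assistant-gate-hgx" .

# Rewrite the path in place (GNU sed; review the diff before committing)
grep -rl "/scr/andukuri/assistant-gate-hgx" . \
  | xargs sed -i 's#/scr/andukuri/assistant-gate-hgx#/your/data/dir#g'
```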
To train models (not just use them), set up a Weights & Biases account for experiment logging and create a project called `assistant-gate`. After you create the project and log in to Weights & Biases using the command-line interface as described here, all training runs described below should be logged.
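For reference, logging in looks like this (the `wandb` CLI ships with the `wandb` Python package; the project itself can be created in the web UI):

```bash
pip install wandb
wandb login   # paste your API key when prompted
```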
Assuming you've reconfigured the directory pointers carefully as described above and navigated to `experiments/star-gate`, STaR-GATE can be run end-to-end with a series of shell scripts that point to organized Python files. Note that depending on whether you use a SLURM job scheduler (as in the existing scripts) or interact directly with your machine, you may need to adjust the setup in these shell scripts.
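For example, a SLURM-managed script is submitted to the scheduler, while on a standalone machine you would remove the `#SBATCH` directives (and any `srun` wrappers) and execute it directly; the script name below is just an illustration:

```bash
# With SLURM: submit to the job scheduler
sbatch simulate-conversations/scripts/m0-A.sh

# Without SLURM: run the (edited) script directly
bash simulate-conversations/scripts/m0-A.sh
```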
- Extract initial tasks from the source dataset. Run the shell script `instruct-questions/scripts/extract-all.sh`.
- Generate personas few-shot. Run the shell script `persona-generation/scripts/generate-personas.sh`. Before this step, make sure your OpenAI API key has been set in your environment using `export OPENAI_API_KEY=<your api key>`.
- Construct the oracle responses by giving GPT-4 access to both personas and tasks for each split. Run the shell scripts `build-gold-responses/scripts/generate-all-gold-responses.sh` and `build-gold-responses/scripts/generate-all-gold-responses-test.sh`. You may have to run `build-gold-responses/scripts/check-content-violations.sh` beforehand; in some cases GPT-4 generates personas that another copy of itself might consider offensive, so flagging these can help you track down the offending persona. (A consolidated example invocation follows this list.)
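Assuming the scripts are run directly with `bash` (substitute `sbatch` under SLURM), the one-time setup looks like:

```bash
bash instruct-questions/scripts/extract-all.sh                        # extract initial tasks
export OPENAI_API_KEY=<your api key>                                  # needed for the GPT-4 steps
bash persona-generation/scripts/generate-personas.sh                  # generate personas few-shot
bash build-gold-responses/scripts/check-content-violations.sh         # optional: flag offensive personas
bash build-gold-responses/scripts/generate-all-gold-responses.sh      # oracle responses (train splits)
bash build-gold-responses/scripts/generate-all-gold-responses-test.sh # oracle responses (test split)
```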
**Important:** The oracle-response step above is very expensive. We make one GPT-4 call for each item in (A) training split `A` with 250 tasks and 50 personas, (B) training split `B` with 250 tasks and 50 personas, and (C) the `test` split with 50 tasks and 10 personas: 250 × 50 + 250 × 50 + 50 × 10 = 25,500 calls in total. This step only happens once, so be careful and make sure your directories and output paths are configured correctly so that the oracle responses are saved the first time around.
At each iteration:

- Simulate conversations to generate training data. Run the shell script `simulate-conversations/scripts/m{t}-{split}.sh`, and pool the conversations by running `simulate-conversations/scripts/m{t}-{split}-pool.sh`. The appropriate split is `A` for even $t$ and `B` for odd $t$; we alternate splits to ensure that $m_t$'s high-quality generated conversations are not memorized from the previous iteration.
- Calculate log-probabilities of the oracle responses to filter the best questions. Run the shell script `log-probs/scripts/m{t}-{split}-log-probs.sh`, and filter the conversations by running `log-probs/scripts/m{t}-{split}-filter.sh`. In the paper, we keep the top `k = 1` conversation out of 10 for each persona-task combination for the training set.
- Generate regularizer responses. Run the shell script `sft/preprocess/scripts/m{t}-{split}-model-responses.sh`.
- Preprocess the data for training. Run the shell script `sft/preprocess/scripts/m{t}-{split}-split.sh`.
- Train the initial model $m_0$ to produce the weights $m_{t+1}$. Run the shell script `sft/train/scripts/train-sft-m{t}-{split}.sh`. (A concrete example for $t = 0$ follows this list.)
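For concreteness, here is the full sequence for the first iteration ($t = 0$, which uses split `A`); the filenames assume the `m{t}-{split}` templates above are instantiated literally, and that you run the scripts directly with `bash`:

```bash
bash simulate-conversations/scripts/m0-A.sh          # simulate conversations
bash simulate-conversations/scripts/m0-A-pool.sh     # pool the conversations
bash log-probs/scripts/m0-A-log-probs.sh             # score oracle responses
bash log-probs/scripts/m0-A-filter.sh                # keep the top k = 1 conversation
bash sft/preprocess/scripts/m0-A-model-responses.sh  # regularizer responses
bash sft/preprocess/scripts/m0-A-split.sh            # preprocess for training
bash sft/train/scripts/train-sft-m0-A.sh             # finetune m0 to produce m1
```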
At each iteration, to prepare for evaluation on the `test` split:

- Simulate conversations for the `test` split. Run the shell script `simulate-conversations/scripts/m{t}-test.sh`, and pool the conversations by running `simulate-conversations/scripts/m{t}-test-pool.sh`.
- Calculate log-probabilities of the oracle responses for the `test` split. Run the shell script `log-probs/scripts/m{t}-test-log-probs.sh`. (A concrete example for $t = 0$ follows this list.)
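For example, for $t = 0$ (again assuming direct execution and literal instantiation of the templates):

```bash
bash simulate-conversations/scripts/m0-test.sh       # simulate test-split conversations
bash simulate-conversations/scripts/m0-test-pool.sh  # pool them
bash log-probs/scripts/m0-test-log-probs.sh          # score oracle responses
```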
Finally, to evaluate win rates:

- Generate responses from all $m_t$ for $t \in \{0, 1, 2, 3\}$, conditioned on a randomly sampled conversation for each persona-task combination. Run the shell script `response-win-rates-randomized-zero-shot/scripts/get-responses.sh`.
- Generate win rates by prompting GPT-4 to select the more apt response between $m_t$ and $m_0$. Run the shell script `response-win-rates-randomized-zero-shot/scripts/get-ratings.sh`. (An example invocation follows this list.)
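A minimal invocation sketch, assuming direct execution and that the GPT-4 judge reads the same `OPENAI_API_KEY` environment variable as the earlier steps:

```bash
export OPENAI_API_KEY=<your api key>
bash response-win-rates-randomized-zero-shot/scripts/get-responses.sh  # responses from m0..m3
bash response-win-rates-randomized-zero-shot/scripts/get-ratings.sh    # GPT-4 win-rate judgments
```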