pinglmlcv / DiLM

Implementation of "DiLM: Distilling Dataset into Language Model for Text-level Dataset Distillation" (accepted to NAACL 2024 Findings).


DiLM: Distilling Dataset into Language Model

Abstract: Dataset distillation aims to compress a training dataset by creating a small number of informative synthetic samples such that neural networks trained on them perform as well as those trained on the original training dataset. Current text dataset distillation methods create each synthetic sample as a sequence of word embeddings instead of a text to apply gradient-based optimization; however, such embedding-level distilled datasets cannot be used for training other models whose word embedding weights are different from the model used for distillation. To address this issue, we propose a novel text dataset distillation approach, called Distilling dataset into Language Model (DiLM), which trains a language model to generate informative synthetic training samples as text data, instead of directly optimizing synthetic samples. We evaluated DiLM on various text classification datasets and showed that distilled synthetic datasets from DiLM outperform those from current coreset selection methods. DiLM achieved remarkable generalization performance in training different types of models and in-context learning of large language models. Our code will be available at https://github.com/arumaekawa/DiLM.
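The abstract compares DiLM against coreset selection, which keeps a small subset of real training examples instead of generating new ones. The simplest such baseline, random per-class selection, can be sketched in plain Python as follows. The repository has a `src/coreset/random.py`, but its interface is not shown here, so the function name `random_coreset` and its per-class budget parameter are hypothetical.

```python
import random
from collections import defaultdict


def random_coreset(texts, labels, per_class, seed=0):
    """Pick up to `per_class` random (text, label) pairs for every label.

    Hypothetical helper for illustration; not the repository's API.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)

    coreset = []
    for label in sorted(by_label):
        pool = by_label[label]
        # Sample without replacement; keep the whole pool if it is too small.
        chosen = rng.sample(pool, min(per_class, len(pool)))
        coreset.extend((text, label) for text in chosen)
    return coreset


texts = ["great film", "awful plot", "loved it", "boring", "a delight", "terrible"]
labels = [1, 0, 1, 0, 1, 0]
subset = random_coreset(texts, labels, per_class=2)
print(len(subset))  # 4
```

A fixed `seed` keeps the selection reproducible across runs, which matters when comparing a baseline against a distilled dataset of the same size.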

Paper: [arXiv], [NAACL2024 Findings]

Contents

This repository uses PyTorch together with two experiment-management tools: Hydra for configuration and MLflow for experiment tracking.

Datasets and pre-trained models are downloaded and loaded with Hugging Face tools.
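The commands under Run Scripts configure experiments through Hydra command-line overrides such as `data.task_name=sst2`, where a dotted key addresses a nested config field. The plain-Python sketch below only illustrates that mapping onto a nested dictionary; it is not Hydra's implementation, and the helper `apply_override` and the base config are hypothetical. It roughly mirrors Hydra's rule that a plain override must target an existing key while a `+` prefix adds a new one. In real Hydra, `+generator=pretrained_sst2` loads the `pretrained_sst2.yaml` option from the `generator` config group (see `configs/train/generator/`) rather than storing a string.

```python
import copy


def apply_override(config, override):
    """Apply one Hydra-style 'a.b=value' override to a nested dict.

    Illustrative only: values stay strings and config groups are not composed.
    """
    config = copy.deepcopy(config)  # leave the caller's config untouched
    append = override.startswith("+")
    key, value = (override[1:] if append else override).split("=", 1)
    *parents, leaf = key.split(".")

    node = config
    for name in parents:
        # Create intermediate sections only when appending a new key.
        if name not in node:
            if not append:
                raise KeyError(f"unknown config section: {name}")
            node[name] = {}
        node = node[name]

    if append and leaf in node:
        raise KeyError(f"'+{key}' given but the key already exists")
    if not append and leaf not in node:
        raise KeyError(f"unknown key '{key}'; prefix with '+' to add it")
    node[leaf] = value
    return config


base = {"data": {"task_name": "mnli"}}  # hypothetical default config
cfg = apply_override(base, "data.task_name=sst2")
cfg = apply_override(cfg, "+generator=pretrained_sst2")
print(cfg)  # {'data': {'task_name': 'sst2'}, 'generator': 'pretrained_sst2'}
```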

Directory structure

.
├── configs
│  ├── test
│  │  ├── coreset.yaml
│  │  ├── dc.yaml
│  │  └── lm.yaml
│  └── train
│     ├── generator
│     │  ├── pretrained_mnli.yaml
│     │  ├── pretrained_qqp.yaml
│     │  └── pretrained_sst2.yaml
│     ├── dc.yaml
│     └── lm.yaml
├── src
│  ├── coreset
│  │  ├── __init__.py
│  │  ├── coreset_base.py
│  │  ├── coreset_utils.py
│  │  ├── herding.py
│  │  ├── k_centers.py
│  │  ├── random.py
│  │  └── rank_text_gtn.py
│  ├── distillation
│  │  ├── __init__.py
│  │  ├── distilled_data.py
│  │  ├── trainer_base.py
│  │  ├── trainer_dc.py
│  │  └── trainer_lm.py
│  ├── data.py
│  ├── dataset_attrs.py
│  ├── evaluator.py
│  ├── generator.py
│  ├── learner.py
│  ├── test.py
│  ├── train.py
│  └── utils.py
├── README.md
└── requirements.txt
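The tree lists several coreset baselines under `src/coreset/`, including `k_centers.py`. A common form of k-center selection is the greedy farthest-point rule: repeatedly add the example whose distance to its nearest already-selected example is largest. The sketch below assumes each example is already embedded as a fixed-length vector (how the repository embeds text is not shown here) and uses a hypothetical `k_center_greedy` function; it is not taken from the repository.

```python
import math


def k_center_greedy(points, k, first=0):
    """Greedy k-center selection over a list of equal-length vectors.

    Returns indices of the chosen centers. Hypothetical helper, not the repo's API.
    """
    chosen = [first]
    # Distance from every point to its nearest chosen center so far.
    nearest = [math.dist(p, points[first]) for p in points]
    while len(chosen) < min(k, len(points)):
        candidates = [i for i in range(len(points)) if i not in chosen]
        # Add the candidate that is currently worst covered.
        idx = max(candidates, key=lambda i: nearest[i])
        chosen.append(idx)
        for i, p in enumerate(points):
            nearest[i] = min(nearest[i], math.dist(p, points[idx]))
    return chosen


embeddings = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 4.9), (10.0, 0.0)]
print(k_center_greedy(embeddings, k=3))  # [0, 4, 2]
```

Restricting the search to unchosen candidates keeps the result free of repeated indices even when the data contains duplicate vectors.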

Run Scripts

  1. Install packages (Python 3.10)

     $ pip install -r requirements.txt
  2. Pre-train the generator language model (LM)

     $ python src/train.py --config-name=lm data.task_name=sst2
  3. Fine-tune the generator with gradient matching

     $ python src/train.py --config-name=dc data.task_name=sst2 +generator=pretrained_sst2
  4. Run evaluation

     $ python src/test.py --config-name=dc data.task_name=sst2 generator.pretrained_model_dir=path/to/pretrained_model_dir
  5. Start the MLflow server and check the results at http://localhost:5000

     $ mlflow server --backend-store-uri ./mlruns --host 0.0.0.0 --port 5000
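Step 3 performs fine-tuning with gradient matching. In gradient-matching dataset distillation generally, the training signal is a distance between the gradients a learner model computes on real data and on synthetic data; a frequently used choice sums, over layers, one minus the cosine similarity of the two gradients. The pure-Python sketch below computes a simplified per-layer version of that distance for gradients given as flat lists. It does not show how DiLM turns this distance into an update for the generator (the repository's own code for that presumably lives in `src/distillation/trainer_dc.py`), and the function names and gradient layout are hypothetical.

```python
import math


def cosine_distance(a, b, eps=1e-8):
    """1 - cos(a, b) for two equal-length flat gradient vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b + eps)  # eps guards against zero norms


def gradient_matching_distance(real_grads, syn_grads):
    """Sum of per-layer cosine distances between real and synthetic gradients.

    Each argument holds one flattened gradient per layer (hypothetical layout).
    """
    if len(real_grads) != len(syn_grads):
        raise ValueError("expected one gradient per layer for both inputs")
    return sum(cosine_distance(r, s) for r, s in zip(real_grads, syn_grads))


real = [[1.0, 0.0], [0.5, 0.5, 0.0]]
syn_aligned = [[2.0, 0.0], [1.0, 1.0, 0.0]]
syn_orthogonal = [[0.0, 1.0], [0.0, 0.0, 1.0]]
print(round(gradient_matching_distance(real, syn_aligned), 6))     # 0.0
print(round(gradient_matching_distance(real, syn_orthogonal), 6))  # 2.0
```

Because cosine distance ignores gradient magnitude, synthetic data that produces gradients pointing the same way as the real data scores near zero even if it yields larger or smaller steps.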

Citation

TBW

About

License: MIT License
