VirtualRoyalty / gan-plus-nlp

A generative adversarial approach to the most popular NLP tasks


🦍 Semi-supervised learning for NLP via GAN


Semi-supervised learning for NLP tasks via GANs. This approach can be used to improve models when only a small number of labeled examples is available.
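In the GAN-BERT formulation this work builds on, the discriminator predicts k real classes plus one extra "fake" class, so unlabeled real examples and generated examples still contribute to training. The following is a minimal, framework-free sketch of such a discriminator objective, assuming logits are plain Python lists with the fake-class score in the last position and unlabeled rows marked by `None`. The function names are hypothetical and this is not the repository's implementation.

```python
import math
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discriminator_loss(
    real_logits: List[List[float]],
    real_labels: List[Optional[int]],
    fake_logits: List[List[float]],
    eps: float = 1e-8,
) -> float:
    """(k+1)-class discriminator loss in the spirit of GAN-BERT (sketch).

    Each logits row holds k real-class scores followed by one "fake" score.
    real_labels[i] is a class index in [0, k), or None for an unlabeled row.
    """
    # Supervised term: cross-entropy over the k real classes (fake logit
    # dropped), averaged over labeled rows only.
    sup = [
        -math.log(softmax(row[:-1])[label] + eps)
        for row, label in zip(real_logits, real_labels)
        if label is not None
    ]
    sup_loss = sum(sup) / len(sup) if sup else 0.0

    # Unsupervised terms: every real row (labeled or not) should avoid the
    # fake class, and every generated row should be assigned to it.
    real_unsup = sum(-math.log(1.0 - softmax(row)[-1] + eps) for row in real_logits)
    fake_unsup = sum(-math.log(softmax(row)[-1] + eps) for row in fake_logits)

    return sup_loss + real_unsup / len(real_logits) + fake_unsup / len(fake_logits)
```

Even when only a few rows carry labels, the two unsupervised terms still provide a training signal from every real and generated example.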

Example usage

See the examples directory for detailed usage.

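```python
import model
from trainer import gan_trainer as gan_trainer_module

...  # setup elided: `config`, `loaders` and `device` used below are defined here (see examples)

config["encoder_name"] = "distilbert-base-uncased"

# Discriminator built around the pretrained encoder named by config["encoder_name"]
discriminator = model.DiscriminatorForSequenceClassification(**config)

# Generator maps noise vectors of size noise_size to vectors of the encoder's hidden size
generator = model.SimpleSequenceGenerator(
    input_size=config["noise_size"],
    output_size=discriminator.encoder.config.hidden_size,
)

# Trainer that runs adversarial training of the discriminator and generator
gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
    config=config,
    discriminator=discriminator,
    generator=generator,
    train_dataloader=loaders["train"],
    valid_dataloader=loaders["valid"],
    device=device,
    save_path=config["save_path"],
)

for epoch_i in range(config["num_train_epochs"]):
    print(
        f"======== Epoch {epoch_i + 1} / {config['num_train_epochs']} ========"
    )
    # One training epoch, then evaluation on the validation loader
    train_info = gan_trainer.train_epoch(log_env=None)
    result = gan_trainer.validation(log_env=None)

...

# Evaluate the trained discriminator on the test set
predict_info = gan_trainer.predict(
    discriminator, loaders["test"], label_names=config["label_names"]
)
print(predict_info["overall_f1"])
```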
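In the usage example, the generator receives `input_size=config["noise_size"]` and `output_size=discriminator.encoder.config.hidden_size`, which suggests that it produces fake encoder-level representations rather than fake text, as in GAN-BERT. The toy class below illustrates that shape contract with a randomly initialised one-hidden-layer MLP in pure Python. `ToyNoiseGenerator`, its hidden width, the LeakyReLU activation and the noise size of 100 in the usage line are illustrative assumptions, not the behavior of `model.SimpleSequenceGenerator`.

```python
import random
from typing import List


class ToyNoiseGenerator:
    """Hypothetical stand-in for a noise-to-representation generator.

    Maps a noise vector of length input_size to a fake representation of
    length output_size (e.g. the encoder's hidden size). Weights are random
    and untrained; only the input/output shapes matter here.
    """

    def __init__(self, input_size: int, output_size: int, hidden: int = 32, seed: int = 0):
        self.input_size = input_size
        self.output_size = output_size
        self.rng = random.Random(seed)
        # Weight matrices of a one-hidden-layer MLP, stored as lists of rows.
        self.w1 = [[self.rng.gauss(0.0, 0.1) for _ in range(input_size)] for _ in range(hidden)]
        self.w2 = [[self.rng.gauss(0.0, 0.1) for _ in range(hidden)] for _ in range(output_size)]

    @staticmethod
    def _leaky_relu(x: float, slope: float = 0.2) -> float:
        return x if x > 0 else slope * x

    def sample_noise(self) -> List[float]:
        # z ~ N(0, I) with input_size components.
        return [self.rng.gauss(0.0, 1.0) for _ in range(self.input_size)]

    def __call__(self, z: List[float]) -> List[float]:
        if len(z) != self.input_size:
            raise ValueError(f"expected noise of size {self.input_size}, got {len(z)}")
        # Hidden layer with LeakyReLU, then a linear projection to output_size.
        h = [self._leaky_relu(sum(w * x for w, x in zip(row, z))) for row in self.w1]
        return [sum(w * x for w, x in zip(row, h)) for row in self.w2]


gen = ToyNoiseGenerator(input_size=100, output_size=768)
fake_representation = gen(gen.sample_noise())
```

Here `fake_representation` has 768 components, matching the hidden size of `distilbert-base-uncased`, so it can be fed to the same classification head as real encoder outputs.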

The following tasks are supported:

  • ✅ text classification (see DiscriminatorForSequenceClassification)
    • + multiple-input text classification (e.g. NLI, paraphrase detection)
  • ✅ multi-label text classification (see DiscriminatorForMultiLabelClassification)
  • ✅ token classification (e.g. NER, see DiscriminatorForTokenClassification)
  • ✅ multiple-choice tasks (see DiscriminatorForMultipleChoice)

Repo structure:

```
.
├── base
│   ├── __init__.py
│   ├── base_model.py
│   └── base_trainer.py
├── model
│   ├── __init__.py
│   ├── discriminator.py
│   ├── generator.py
│   └── utils.py
├── trainer
│   ├── __init__.py
│   ├── gan_trainer.py
│   └── trainer.py
├── tests
│   ├── __init__.py
│   ├── base_tests.py
│   ├── model_tests.py
│   └── trainer_tests.py
├── examples
├── LICENSE
├── README.md
└── requirements.txt
```

This work is based on GAN-BERT: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples (Croce et al., 2020).
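In GAN-BERT, the generator is trained to make the discriminator assign its outputs to a real class, with an additional feature-matching term that pulls the mean discriminator features of a generated batch toward those of a real batch. A minimal pure-Python sketch of that combined objective follows. The function names, the list-based inputs and the `eps` constant are assumptions for illustration, and the repository's trainer may compute the loss differently.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_vector(rows: List[List[float]]) -> List[float]:
    """Column-wise mean of a batch of feature vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def generator_loss(
    fake_logits: List[List[float]],
    real_features: List[List[float]],
    fake_features: List[List[float]],
    eps: float = 1e-8,
) -> float:
    """Adversarial term plus feature matching, in the spirit of GAN-BERT (sketch).

    fake_logits rows end with the discriminator's "fake" logit; the feature
    batches are the discriminator's hidden activations for real and
    generated inputs.
    """
    # Adversarial term: reward generated rows that the discriminator does
    # NOT put in the fake class.
    adv = sum(-math.log(1.0 - softmax(row)[-1] + eps) for row in fake_logits) / len(fake_logits)

    # Feature matching: squared difference of batch-mean features,
    # averaged over feature dimensions.
    real_mean = mean_vector(real_features)
    fake_mean = mean_vector(fake_features)
    feat = sum((r - f) ** 2 for r, f in zip(real_mean, fake_mean)) / len(real_mean)

    return adv + feat
```

The feature-matching term gives the generator a useful gradient even when the discriminator confidently separates real from generated inputs.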


```bibtex
@article{VirtualRoyalty,
  title   = "Semi-supervised learning for natural language processing via GAN.",
  author  = "Alperovich, Vadim",
  year    = "2023",
  url     = "https://github.com/VirtualRoyalty/gan-plus-nlp",
}
```


License: MIT License

