lightning-transformers

Flexible components pairing 🤗 Transformers with PyTorch Lightning

Home Page: https://lightning-transformers.readthedocs.io



Docs • Community


Installation

pip install lightning-transformers
From Source
git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
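
If you plan to modify the source, an editable install (a standard pip feature, not specific to this project) lets local changes take effect without reinstalling:

pip install -e .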

What is Lightning-Transformers

Lightning Transformers provides LightningModules, LightningDataModules and Strategies to use πŸ€— Transformers with the PyTorch Lightning Trainer.
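
Because training runs through the stock pl.Trainer, any Lightning feature composes with these components. As a sketch, multi-GPU training with the built-in DeepSpeed strategy uses only standard PyTorch Lightning flags (nothing specific to this library; assumes deepspeed is installed):

import pytorch_lightning as pl

# Shard optimizer state and gradients across two GPUs (DeepSpeed ZeRO stage 2).
trainer = pl.Trainer(
    accelerator="gpu", devices=2, strategy="deepspeed_stage_2", precision=16
)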

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
    TextClassificationDataConfig,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
# The DataModule downloads and tokenizes the CARER "emotion" dataset.
dm = TextClassificationDataModule(
    cfg=TextClassificationDataConfig(
        batch_size=1,
        dataset_name="emotion",
        max_length=512,
    ),
    tokenizer=tokenizer,
)
# LightningModule wrapping the 🤗 backbone; the label count comes from the data.
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

# "auto" selects whatever accelerator and device count are available.
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
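
Since TextClassificationTransformer is a regular LightningModule, saving and restoring the fine-tuned model uses the standard Lightning checkpoint API. A minimal sketch (the checkpoint path is illustrative):

# Save the fine-tuned weights, then restore them later.
trainer.save_checkpoint("emotion-bert.ckpt")  # example path
model = TextClassificationTransformer.load_from_checkpoint(
    "emotion-bert.ckpt",
    pretrained_model_name_or_path="bert-base-cased",
    num_labels=dm.num_classes,  # constructor args are re-passed here
)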

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
    TranslationConfig,
    TranslationDataConfig,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    cfg=TranslationConfig(
        n_gram=4,  # n-gram order for the BLEU metric
        smooth=False,
        val_target_max_length=142,
        num_beams=None,  # fall back to the model's default decoding
        compute_generate_metrics=True,  # generate during validation to score BLEU
    ),
)
# The DataModule loads the WMT16 Romanian-English pair, translating en -> ro.
dm = WMT16TranslationDataModule(
    cfg=TranslationDataConfig(
        dataset_name="wmt16",
        # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
        dataset_config_name="ro-en",
        source_language="en",
        target_language="ro",
        max_source_length=128,
        max_target_length=128,
    ),
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
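
Because the model and data module follow the standard Lightning interfaces, you can also score the model without further training, e.g. a validation pass that reports the generation metrics configured above:

# Runs the validation loop only; no weights are updated.
trainer.validate(model, datamodule=dm)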

Lightning Transformers supports many 🤗 tasks and datasets. See the documentation for the full list.

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our community on Slack!

About

License: Apache License 2.0


Languages

Python 99.6%, Makefile 0.4%