pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch

Home Page: https://pytorch.org/text

Overview of issues in torchtext and the plan for revamping

zhangguanheng66 opened this issue

Motivation and summary of the current issues in torchtext

Based on feedback from users, several issues exist in torchtext, including:

  • Several components and functionals are unclear and difficult to adopt. For example, the Field class couples tokenization, vocabulary, splitting, batching and sampling, padding, and numericalization together. The current Field class works as a "black box", and users are confused about what happens inside it. Instead, those components should be split into several basic building blocks. This is more consistent with the PyTorch core library, which gives users the freedom to build models and pipelines from orthogonal components.
  • The datasets are incompatible with DataLoader and Sampler in torch.utils.data. The current torchtext datasets are not compatible with the PyTorch core library. Some custom modules/functions in torchtext (e.g. Iterator, Batch, splits) should be replaced by the corresponding modules in torch.utils.data.

New datasets in torchtext.experimental.datasets

We have rewritten several datasets in torchtext.experimental.datasets using the new abstractions. The old versions of the datasets are still available in torchtext.datasets, and the new datasets are opt-in.

  • Sentiment analysis dataset (#651)
    • IMDB
  • Language modeling datasets (#624), including
    • WikiText2
    • WikiText103
    • PennTreebank

Case study for IMDB dataset

API for new datasets

To load the new datasets, simply call the dataset API, as follows:

from torchtext.experimental.datasets import IMDB
train_dataset, test_dataset = IMDB()

To specify a tokenizer:

from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer("spacy")
train_dataset, test_dataset = IMDB(tokenizer=tokenizer)

If you only need the test set, you must pass a Vocab object:

vocab = train_dataset.get_vocab()
test_dataset, = IMDB(tokenizer=tokenizer, vocab=vocab, data_select='test')

Legacy code

The old IMDB dataset is still available in torchtext.datasets. You can use the legacy datasets as follows:

from torchtext import data, datasets

TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)
train, test = datasets.IMDB.splits(TEXT, LABEL)

Difference

With the old pattern, users have to create a Field object that includes a specific tokenizer. In the new dataset API, users can pass a custom tokenizer directly to the dataset constructor. A custom tokenizer defines how to convert a string to a list of tokens.

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.datasets import IMDB

# Old pattern
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"))

# New pattern
train_dataset, test_dataset = IMDB(tokenizer=get_tokenizer("spacy"))
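Since a tokenizer is just a callable that maps a string to a list of token strings, you are not limited to get_tokenizer. A minimal sketch of a hypothetical regex-based tokenizer (the name `simple_tokenizer` and the regex are illustrative assumptions, not part of torchtext):

```python
import re

# Hypothetical pattern: runs of word characters, or any single punctuation mark.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]")

def simple_tokenizer(text):
    """Lowercase the input and split it into word and punctuation tokens."""
    return _TOKEN_RE.findall(text.lower())

tokens = simple_tokenizer("This film is GREAT!")
```

Any such callable could then be passed in the same way as above, e.g. `IMDB(tokenizer=simple_tokenizer)`.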

In the old datasets, the vocab object is tied to the Field class, which is not flexible enough to accept a pre-trained vocab object. In the new datasets, the vocab object can be obtained with

vocab = train_dataset.get_vocab()
new_vocab = torchtext.vocab.Vocab(counter=vocab.freqs, max_size=1000, min_freq=10)

and then used to generate other datasets:

from torchtext.experimental.datasets import WikiText2
train_dataset, test_dataset, valid_dataset = WikiText2(vocab=new_vocab)
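To make the effect of max_size and min_freq concrete, here is a plain-Python sketch that approximates how the legacy Vocab is understood to order and filter a counter: specials first, then tokens by descending frequency (ties broken alphabetically), stopping once a token falls below min_freq or max_size non-special tokens have been added. This is an illustration under those assumptions, not torchtext's code; `build_itos` is a hypothetical name.

```python
from collections import Counter

def build_itos(counter, max_size=None, min_freq=1, specials=("<unk>", "<pad>")):
    """Return an index-to-string list built from a token Counter (sketch)."""
    itos = list(specials)
    # Sort alphabetically, then (stably) by descending frequency, so that
    # alphabetical order breaks frequency ties.
    items = sorted((tok, freq) for tok, freq in counter.items() if tok not in specials)
    items.sort(key=lambda pair: pair[1], reverse=True)
    # max_size counts only the non-special tokens
    limit = None if max_size is None else max_size + len(itos)
    for tok, freq in items:
        if freq < max(min_freq, 1) or len(itos) == limit:
            break
        itos.append(tok)
    return itos

freqs = Counter("the cat sat on the mat the end".split())
small_itos = build_itos(freqs, max_size=3)
```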

The datasets with the new pattern return a tensor of token IDs, instead of the tokens returned by the old pattern. To retrieve the tokens, look up each ID in the vocabulary:

train_vocab = train_dataset.get_vocab()
# Each sample is a (label, text) tuple; text is a tensor of token IDs
tokens = [train_vocab.itos[token_id] for token_id in train_dataset[0][1].tolist()]

Instead of BucketIterator.splits, as in the old pattern, users are encouraged to use torch.utils.data.DataLoader to generate batches of data. You can specify how to batch and pad the samples with a custom function passed to collate_fn. Here is an example that pads sequences of similar lengths and loads the data through DataLoader. To sample randomly, set shuffle=True in DataLoader; otherwise, a sequential sampler is constructed automatically.

import torch
from torch.utils.data import DataLoader

# Build a list of (text length, index, label, text) tuples
data_len = [(len(txt), idx, label, txt) for idx, (label, txt) in enumerate(train_dataset)]
# Sort by length so that neighbouring samples (and hence each batch) have similar lengths
data_len.sort()

# Look up the id of the padding token
pad_id = train_dataset.get_vocab()['<pad>']

def pad_data(data):
    # Unpack the mini-batch into lengths, indices, labels and texts
    lengths, _, label_list, txt_list = zip(*data)
    max_len = max(lengths)
    # Right-pad every sequence to the longest one in the mini-batch
    padded_tensors = torch.stack([
        torch.cat((txt, torch.full((max_len - len(txt),), pad_id, dtype=torch.long)))
        for txt in txt_list])
    return padded_tensors, label_list

# Generate batches of 8 sequences with similar lengths
dataloader = DataLoader(data_len, batch_size=8, collate_fn=pad_data)
for idx, (txt, label) in enumerate(dataloader):
    print(idx, txt.size(), label)
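Sorting the whole dataset once, as above, means every epoch sees the same batches in the same order, while setting shuffle=True would undo the length grouping. A common compromise is a custom batch sampler that groups indices of similar length and shuffles only the order of the batches. A minimal sketch; `BucketBatchSampler` is a hypothetical helper, not a torchtext class. Because it is just an iterable of index lists, it should be usable as `DataLoader(dataset, batch_sampler=..., collate_fn=...)`.

```python
import random

class BucketBatchSampler:
    """Yield lists of dataset indices whose sequences have similar lengths (sketch)."""

    def __init__(self, lengths, batch_size, shuffle=True, seed=0):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def _batches(self):
        # Sort indices by sequence length so neighbours have similar lengths
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        return [order[i:i + self.batch_size] for i in range(0, len(order), self.batch_size)]

    def __iter__(self):
        batches = self._batches()
        if self.shuffle:
            # Shuffle the order of the batches, not their contents
            self.rng.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

lengths = [5, 2, 9, 3, 8, 1]
sampler = BucketBatchSampler(lengths, batch_size=2)
```

The per-sample lengths would be computed once up front, e.g. `[len(txt) for _, txt in train_dataset]`.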

To randomly split a dataset into non-overlapping new datasets of given lengths, use torch.utils.data.random_split:

import torch
from torchtext.experimental.datasets import IMDB

train_dataset, test_dataset = IMDB()
train_subset, valid_subset = torch.utils.data.random_split(train_dataset, [15000, 10000])
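Conceptually, random_split draws a random permutation of the indices and cuts it into consecutive chunks of the requested lengths, which is why the subsets never overlap. A plain-Python sketch of that idea (the function name `split_indices` and the `seed` argument are hypothetical):

```python
import random

def split_indices(n, lengths, seed=0):
    """Randomly partition range(n) into non-overlapping index lists of the given lengths."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, offset = [], 0
    for length in lengths:
        # Each split takes the next consecutive slice of the permutation
        splits.append(indices[offset:offset + length])
        offset += length
    return splits

train_idx, valid_idx = split_indices(10, [6, 4])
```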

Reference:

A few recent issues from OSS users:

  • Sorting sentence within a batch is confusing #641
  • split function is confusing #644
  • Generate vocab object based on a subset of text file #642
  • Pass a pre-trained vocab object to build a dataset #648
  • Load unconstructed text data #649
  • More flexibility to support word vector layers #650
  • More compatible with torch.utils.data.DataLoader #660

The new dataset API looks like a good start - I think it's a good idea to move away from using Field - but there are a few things that still need to be looked at.

I'm currently experimenting on the text_classification datasets in 0.4 using the following:

import torchtext

def create_dataset(dataset, dataset_args, vocab_args):

    # assert the dataset exists in text_classification
    assert dataset in torchtext.datasets.text_classification.DATASETS

    # get dataset
    train_data, _ = getattr(torchtext.datasets.text_classification, dataset)(**dataset_args)

    # get unfiltered vocab (default args)
    old_vocab = train_data.get_vocab()

    # create new filtered vocabulary (with desired args)
    new_vocab = torchtext.vocab.Vocab(counter = old_vocab.freqs,
                                      **vocab_args)

    # return dataset with new vocabulary
    return getattr(torchtext.datasets.text_classification, dataset)(vocab = new_vocab,
                                                                    **dataset_args)

An example of use:

train_data, test_data = create_dataset('AG_NEWS',
                                       dataset_args = {'ngrams': 1},
                                       vocab_args = {'max_size': 25_000})

With 0.5, we'll be able to pass a tokenizer to the dataset - which is much needed - but as we can't pass arguments to the Vocab initialization, we have to tokenize and create a vocabulary, use that vocabulary to construct another Vocab with our desired arguments, and then tokenize again. Some way to pass Vocab arguments to the dataset is needed.

The default pad and unk tokens should also be part of the Vocab object and should be arguments, as we won't have Field to handle them anymore and they're more related to the vocabulary than to the dataset. I see that changing the unk token is currently a TODO, and even though the Vocab object doesn't actually use the pad token anywhere, I think it makes sense for it to be part of the Vocab object - happy to hear if you disagree.

Also, a better collate_fn assuming the data is a list of (label, sequence) tuples is:

import torch

def collator(batch, pad_idx):

    labels, sequences = zip(*batch)

    # FloatTensor for binary classification, LongTensor for multiclass classification
    labels = torch.FloatTensor(labels)

    # pad_sequence returns a [max_len, batch_size] tensor (batch_first=False by default)
    sequences = torch.nn.utils.rnn.pad_sequence(sequences,
                                                padding_value = pad_idx)

    return labels, sequences

torch.nn.utils.rnn.pad_sequence is faster than manipulating lists and doing a concatenation followed by a stack.

Then this can be passed to the DataLoader with:

import functools

pad_idx = train_data.get_vocab()['<pad>']
batch_size = 64

collate_fn = functools.partial(collator, 
                               pad_idx = pad_idx)

train_iterator = torch.utils.data.DataLoader(train_data, 
                                             shuffle = True, 
                                             batch_size = batch_size,
                                             collate_fn = collate_fn)

The collate function above can also be expanded to provide sequence lengths and masks:

def collator(batch, pad_idx):
    
    labels, sequences = zip(*batch)

    labels = torch.FloatTensor(labels)

    lengths = torch.LongTensor([len(sequence) for sequence in sequences])
    
    sequences = torch.nn.utils.rnn.pad_sequence(sequences, 
                                                padding_value = pad_idx)
    
    masks = (sequences != pad_idx)
    
    return labels, sequences, lengths, masks

This gets back the missing include_lengths ability from the Field.

Apart from those small things - I think you've done a good job!

@bentrevett Thanks for your comments and great feedback.
For the vocab part, I think some changes could be made in build_vocab_from_iterator. Users could then build a custom vocab object and pass it to the dataset interface.

For padding, yes, torch.nn.utils.rnn.pad_sequence is doing the same thing (don't know why it's under the "rnn" category). There are also different ways to pad sequences (like a long Tensor with offsets, a tensor with mask), or even NestedTensor in the future. Those are some interesting blocks to think about.
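To make the "long Tensor with offsets" alternative concrete: instead of padding, the sequences of a batch are concatenated into one flat sequence, and an offsets list records where each one starts (the layout that torch.nn.EmbeddingBag accepts through its offsets argument). A plain-Python sketch with hypothetical helper names:

```python
def to_flat_with_offsets(sequences):
    """Concatenate token-id sequences and record the start offset of each one."""
    flat, offsets = [], []
    for seq in sequences:
        offsets.append(len(flat))
        flat.extend(seq)
    return flat, offsets

def from_flat_with_offsets(flat, offsets):
    """Recover the individual sequences from the flat list and its offsets."""
    ends = offsets[1:] + [len(flat)]
    return [flat[start:end] for start, end in zip(offsets, ends)]

flat, offsets = to_flat_with_offsets([[4, 7, 2], [9], [3, 3]])
```

No padding tokens are stored, at the cost of needing the offsets to tell the sequences apart.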

Batch/iterator issue with RawField #266

Translation datasets are too slow with torchtext.data.Iterator (#668). We need to rewrite these datasets and make them compatible with torch.utils.data.DataLoader.

Maybe it's worth using torch.utils.data.IterableDataset?

Yep. We had discussions about this. Since all the other datasets use torch.utils.data.Dataset in this release, we decided to keep them consistent. Thanks for the advice.

Hi everyone!

I am quite new to torchtext (but have been a PyTorch user for 2+ years). My first impression of the new dataset API is great; however, there is one thing I cannot accomplish in a clean manner.

I am using the Penn Treebank dataset from torchtext.experimental, like this:

import torchtext
from torchtext.experimental.datasets import PennTreebank

# query the training set to build the vocab
vocab = PennTreebank(data_select="train")[0].get_vocab()
# create a new vocabulary (with word-level embeddings)
self.vocab = torchtext.vocab.Vocab(counter=vocab.freqs, vectors=torchtext.vocab.GloVe(name='6B', dim=self.embedding_dim))

# create datasets
train_dataset, valid_dataset, test_dataset = PennTreebank(vocab=self.vocab)

The problem is that the returned dataset cannot provide sequences: train_dataset[i] is a single value. (If I write a collate_fn to reshape the batch to have a sequence dimension, then I cannot shuffle the data.)

My solution would be to write a custom dataset and copy the data from the one returned by torchtext. Is there a cleaner way to do that?

In pseudocode, I need something like this:

train_dataset, valid_dataset, test_dataset = PennTreebank(vocab=self.vocab, seq_len=32)
print(train_dataset[0].shape) # should be 32 X 1
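Until an option like seq_len exists, one possible workaround is a thin map-style wrapper that slices the flat token stream into fixed-length (input, target) pairs, with the target shifted by one token as is usual for language modelling. A sketch; `LMChunkDataset` is hypothetical. It only relies on len() and slicing, so it should work on a Python list or a 1-D tensor of token ids, and since each item is an independent chunk, DataLoader(..., shuffle=True) can shuffle the chunks.

```python
class LMChunkDataset:
    """Map-style dataset of fixed-length (input, target) chunks from a flat token stream."""

    def __init__(self, token_ids, seq_len):
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __len__(self):
        # The last token can only ever be a target, hence the "- 1"
        return (len(self.token_ids) - 1) // self.seq_len

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        start = idx * self.seq_len
        inputs = self.token_ids[start:start + self.seq_len]
        targets = self.token_ids[start + 1:start + 1 + self.seq_len]
        return inputs, targets

stream = list(range(10))
chunks = LMChunkDataset(stream, seq_len=3)
```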

Yes. That's the problem for all the three WLM datasets. One way to wrap up the dataset is here

@zhangguanheng66: Thanks for the help!

Is it possible to use this new-style datasets when loading from csv using a TabularDataset?

Is this something you are looking for? #701

@zhangguanheng66 I think that could work, moving the conversation to #701 for clarification.

Does anyone know how to define batch_size with maximum number of tokens (instead of a fixed batch size) in this setup? I have been playing with BatchSampler but did not succeed so far. Any help is greatly appreciated :-)

Could you open a separate issue for this and attach a code snippet? This issue is to introduce the new dataset abstraction.
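For reference, one common approach to batching by a maximum number of tokens is a batch sampler that accumulates indices until adding the next sequence would push the padded batch size (longest sequence × number of sequences) over a token budget. A sketch under those assumptions; `TokenBudgetBatchSampler` and `max_tokens` are hypothetical names, and the sampler would be passed to DataLoader via `batch_sampler=`.

```python
class TokenBudgetBatchSampler:
    """Group indices so that each padded batch holds at most max_tokens tokens (sketch)."""

    def __init__(self, lengths, max_tokens):
        self.lengths = lengths
        self.max_tokens = max_tokens

    def __iter__(self):
        # Visit indices from shortest to longest to keep padding low
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batch, longest = [], 0
        for idx in order:
            new_longest = max(longest, self.lengths[idx])
            # Padded size if this index were added to the current batch
            if batch and new_longest * (len(batch) + 1) > self.max_tokens:
                yield batch
                batch, new_longest = [], self.lengths[idx]
            batch.append(idx)
            longest = new_longest
        if batch:
            yield batch

lengths = [4, 2, 6, 3, 5]
batches = list(TokenBudgetBatchSampler(lengths, max_tokens=10))
```

Note that a single sequence longer than the budget still gets a batch of its own, and this sketch defines no `__len__`, so `len(dataloader)` would not work with it.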

Is there an example somewhere of an NLP Dataset using the new scheme which adds '<sos>' and '<eos>' tokens to each sequence? Where would you suggest that should be done, in the dataset transformer pipeline or in the collate function?

You can do both with the new "scheme". But I guess adding the <sos> and <eos> token ids in the collate_fn makes more sense, since you want all the sequences to have them.

Thanks @zhangguanheng66 - for what it's worth, I ended up adding the '<sos>' and '<eos>' tokens in the transform. Then I used @bentrevett's collate_fn for the padding.
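Adding the markers in the transform amounts to one small step placed after numericalization. A sketch with hypothetical ids (`SOS_ID`, `EOS_ID`) and a hypothetical `add_sos_eos` helper; in practice the ids would be looked up in the vocabulary, e.g. `vocab['<sos>']`.

```python
# Hypothetical ids; in practice look them up in the vocabulary.
SOS_ID, EOS_ID = 2, 3

def add_sos_eos(token_ids, sos_id=SOS_ID, eos_id=EOS_ID):
    """Return a new list with start- and end-of-sequence ids around token_ids."""
    return [sos_id] + list(token_ids) + [eos_id]

toy_stoi = {"hello": 4, "world": 5}  # stand-in vocabulary for the example
ids = [toy_stoi[tok] for tok in "hello world".split()]
wrapped = add_sos_eos(ids)
```

Since every sample already carries its markers, a padding collate_fn then appends pad ids after <eos>.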

My main comment is I think you need to update build_vocab to accept Vocab kwargs - I ended up reimplementing as:

from collections import Counter
from tqdm import tqdm
from torchtext.vocab import Vocab

def build_vocab(data, transforms, index, **kwargs):
    counter = Counter()
    with tqdm(unit_scale=0, unit='lines') as t:
        for line in data:
            counter.update(transforms(line[index]))
            t.update(1)
    word_vocab = Vocab(counter, **kwargs)
    return word_vocab

src_vocab = build_vocab(train, src_tokenizer, index=0, specials=('<pad>', '<unk>', '<sos>', '<eos>'))
tgt_vocab = build_vocab(train, tgt_tokenizer, index=1, specials=('<pad>', '<unk>', '<sos>', '<eos>'))

The alternative is a bit hacky: constructing a new Vocab from the return value of build_vocab() but with different **kwargs.

Yup. We can accept a PR for that.

Could also be written as

def build_vocab(data, transforms, index, **kwargs):
    counter = Counter()
    with tqdm(unit_scale=0, unit='lines') as t:
        for line in data:
            counter.update(line)
            t.update(1)
    word_vocab = Vocab(counter, **kwargs)
    return word_vocab

def transformed(fn, index, data):
    for line in data:
        yield fn(line[index])


src_vocab = build_vocab(transformed(src_tokenizer, 0, train), specials=('<pad>', '<unk>', '<sos>', '<eos>'))
tgt_vocab = build_vocab(transformed(tgt_tokenizer, 1, train), specials=('<pad>', '<unk>', '<sos>', '<eos>'))

That way, the only required change is forwarding the arguments from build_vocab to the constructor of Vocab, which pretty much just adds more functionality to this factory function. I'd not want to see us merge the index and transforms portion of the suggested addition to build_vocab.

@cpuhrsch that makes sense, but I think you left the tokenizer (transforms) and index arguments in build_vocab, i.e.

def build_vocab(data,**kwargs):
    counter = Counter()
    with tqdm(unit_scale=0, unit='lines') as t:
        for line in data:
            counter.update(line)
            t.update(1)
    word_vocab = Vocab(counter, **kwargs)
    return word_vocab

def transformed(fn, index, data):
    for line in data:
        yield fn(line[index])

src_vocab = build_vocab(transformed(src_tokenizer, 0, train), specials=('<pad>', '<unk>', '<sos>', '<eos>'))
tgt_vocab = build_vocab(transformed(tgt_tokenizer, 1, train), specials=('<pad>', '<unk>', '<sos>', '<eos>'))

Also I actually combined 'build_vocab' from Multi30k with torchtext.vocab.build_vocab_from_iterator

It seems that torchtext.vocab.build_vocab_from_iterator should be modified to accept **kwargs and then used as you suggest above, i.e.

src_vocab = build_vocab_from_iterator(transformed(src_tokenizer, 0, train), specials=('<pad>', '<unk>', '<sos>', '<eos>'))

I think, and this is also what I meant above, that build_vocab_from_iterator should accept an iterator that yields lists of tokens, plus **kwargs that are passed to the vocabulary constructor.

Here's some feedback after playing around with the new experimental API. First, I'd also like to say that the new API is great - makes for very clean code - the addition of transforms was a good idea and makes it a lot easier to use torchtext with other libraries, such as the huggingface transformers.

As for the feedback:

  1. Experimental vocab should take a max_size argument. An integer denoting the maximum size of the created vocabulary. It is more common to build a vocabulary up to a maximum size rather than to set a minimum frequency of tokens, although I still believe min_freq should remain. This should be an argument instead of the user cutting the ordered_dict to max_size so it can be passed to functions such as build_vocab_from_iterator, vocab_from_file etc. This means that some sort of sorting with respect to token frequency will have to be done internally in the Vocab C++ class(?)

  2. Experimental vocab should take a specials argument. A list of strings, each representing a token that will always be in the vocabulary, regardless of how many times it appears in the ordered_dict and each should be appended to the vocabulary after the <unk> and <pad> tokens but before the rest of the tokens. Used for adding <sos>, <eos>, <mask> tokens. Again, this should be an argument so they can be passed to build_vocab_from_iterator, etc.

  3. Experimental vocab's unk_token argument should be optional and the vocab object should raise an error if the user tries to lookup a token that isn't in the vocab when unk_token is not set. This is useful when building a vocabulary for labels which is easy to do with the new vocab transform API.

  4. Experimental vocab's pad_token argument, from here, should also be optional. Again, for building a label vocabulary. I do believe the pad_token should be its own argument and not be in specials as it was in the legacy vocab.

  5. Experimental functional transforms should be imported with experimental.transforms.functional and not experimental.functional, i.e. there should be a transforms dir in torchtext/experimental with a transforms.py and a functional.py in it. This mirrors the way it is done in torchvision.

  6. Experimental vocab's arguments should be set as attributes of the vocab object. For example, I should be able to create a vocabulary and call vocab.unk_token to get the vocabulary's unk_token, the same with vocab.pad_token, vocab.min_freq, vocab.max_size, vocab.specials, etc.

  7. The experimental vectors' unk_tensor argument should be either a tensor or a callable that returns a tensor. At the moment I can't initialize the vectors of OOV tokens from something like a uniform or normal distribution without them all being the exact same tensor.

  8. Experimental raw text classification datasets, especially IMDB, are not actually "raw". If I'm getting the raw IMDB data then I want the labels to be "neg"/"pos" and not 0/1. This line is explicitly not making the data "raw" anymore. This would mean, for consistency, that the other text classification datasets should also have their "raw" labels, from here. However, as they are actually already stored with their labels as integers then maybe it's a bit weird to transform them back into strings. Not sure about this one.

  9. vocab_from_raw_text_file, vocab_from_file and vocab in experimental.vocab should be renamed build_vocab_from_raw_text_file, build_vocab_from_file and build_vocab. The first two for consistency with build_vocab_from_iterator and the last one to avoid confusion with the Vocab class/object. It also explains what these functions do a bit better.

Happy to discuss all of these and help with any pull requests if needed.
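Points 3 and 6 above can be illustrated with a small, purely hypothetical vocabulary class (not the experimental Vocab API): unk_token is optional, looking up an unknown token raises KeyError when it is unset, and the constructor arguments are kept as attributes.

```python
class SimpleVocab:
    """Hypothetical vocabulary with optional unk/pad tokens, for illustration only."""

    def __init__(self, tokens, unk_token=None, pad_token=None, specials=()):
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.specials = tuple(specials)
        # Order: unk, pad, other specials, then the remaining (deduplicated) tokens
        leading = [t for t in (unk_token, pad_token) if t is not None] + list(self.specials)
        self.itos = leading + [t for t in dict.fromkeys(tokens) if t not in leading]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __getitem__(self, token):
        if token in self.stoi:
            return self.stoi[token]
        if self.unk_token is None:
            raise KeyError(f"{token!r} is not in the vocabulary and no unk_token is set")
        return self.stoi[self.unk_token]

    def __len__(self):
        return len(self.itos)

labels = SimpleVocab(["neg", "pos"])
words = SimpleVocab(["film", "great"], unk_token="<unk>", pad_token="<pad>",
                    specials=("<sos>", "<eos>"))
```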

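Regarding build_vocab_from_iterator accepting an iterator of tokens plus **kwargs for the Vocab constructor: yes, this is what I would propose. Change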

def build_vocab_from_iterator(iterator):
to

def build_vocab_from_iterator(iterator,**kwargs):
    """
    Build a Vocab from an iterator.
    Arguments:
        iterator: Iterator used to build Vocab. Must yield list or iterator of tokens.
        **kwargs: keyword args passed to Vocab constructor
    """

    counter = Counter()
    with tqdm(unit_scale=0, unit='lines') as t:
        for tokens in iterator:
            counter.update(tokens)
            t.update(1)
    word_vocab = Vocab(counter,**kwargs)
    return word_vocab

I had to change the usage example slightly: as you say, build_vocab_from_iterator takes an iterator of tokens (arguably tqdm(unit_scale=0, unit='lines') is incorrect and should be tqdm(unit_scale=0, unit='tok')), so the pipeline needs to nest iteration over lines and then over the output of the tokenizer, i.e. (itertools could potentially be used here):

def transformed(fn, data, index):
    for line in data:
        for tok in fn(line[index]):
            yield tok

tgt_vocab = build_vocab_from_iterator(transformed(tgt_tokenizer, data, 0), specials=('<pad>', '<unk>', '<sos>', '<eos>'))
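One detail worth checking when deciding whether the iterator should yield whole token lists or individual tokens: collections.Counter.update iterates over whatever it is given, so a list counts whole tokens, but a bare string counts its characters. A quick illustration:

```python
from collections import Counter

by_line = Counter()
by_line.update(["the", "film"])  # a list of tokens: counts whole tokens

by_token = Counter()
by_token.update("film")  # a single string: counts the characters f, i, l, m
```

So with a `for tokens in iterator: counter.update(tokens)` loop, each item should be a list of tokens; an iterator of single tokens would need `counter.update([tok])` instead.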

Also, one other thing I noticed: TranslationDataset.__getitem__ requires transforms, whereas in PyTorch they are generally optional (i.e. if you pass None, it just returns data[idx]; otherwise it returns transform(data[idx])).
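The optional-transform convention described above can be sketched as a small map-style dataset. `PairDataset` and its arguments are hypothetical and are not torchtext's TranslationDataset.

```python
class PairDataset:
    """Map-style dataset of (src, tgt) pairs with optional per-side transforms."""

    def __init__(self, data, transforms=None):
        self.data = data
        # transforms: None, or a (src_transform, tgt_transform) pair
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        if self.transforms is None:
            return src, tgt  # already pre-processed: return as-is
        src_transform, tgt_transform = self.transforms
        return src_transform(src), tgt_transform(tgt)

raw = [("Hallo Welt", "Hello world")]
as_is = PairDataset(raw)
split = PairDataset(raw, transforms=(str.split, str.split))
```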

TranslationDataset

TranslationDataset is an abstraction used for translation datasets. It assumes the transforms are defined by users. We also provide an iterator that yields raw text strings (equivalent to passing None as transforms).

@bentrevett Thanks Ben for your valuable feedback. We will address those comments in separate issues and cc you there.

Regarding TranslationDataset: yes, it's not a big deal. But instead of passing a list of strings and transforms, you can also pass a list of tensors and None (i.e. pre-process the data). I found that applying transforms at training time considerably slowed down my training when I was using a custom (slow) tokeniser. So I applied the transforms as the file was loaded, but then I had to create my own copy of TranslationDataset, since it doesn't check whether transforms is None (or validate the constructor arguments).

You don't need to use this dataset during training if you want to avoid the overhead. Starting with the raw text iterator, you can use the transforms to process the dataset and save it as a list of tensors, then pass the list of tensors to DataLoader. This way you avoid the overhead. Another alternative is to check out the dataset and save the processed data as a list of tensors.
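Transforming once and handing the results to DataLoader works because a plain Python list of tuples already behaves as a map-style dataset (it has `__len__` and `__getitem__`). A sketch, with a hypothetical `slow_tokenizer` standing in for an expensive custom tokenizer:

```python
def slow_tokenizer(text):
    """Hypothetical stand-in for an expensive tokenizer."""
    return text.lower().split()

def preprocess(raw_pairs, src_tokenizer, tgt_tokenizer):
    """Tokenize every (src, tgt) pair once, up front, and keep the results in a list."""
    # In practice each side would also be numericalized and converted to a tensor here.
    return [(src_tokenizer(src), tgt_tokenizer(tgt)) for src, tgt in raw_pairs]

raw_pairs = [("Ein Hund", "A dog"), ("Eine Katze", "A cat")]
processed = preprocess(raw_pairs, slow_tokenizer, slow_tokenizer)
```

The resulting list can be passed directly as the dataset argument of DataLoader, so the tokenizer never runs inside the training loop.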

Oh right I wasn't aware that DataLoader takes a list of tuples - thanks for that tip, I'll give it a go.

Hi all, I have been playing around with the new API for a while. I am wondering whether there is a way to add a custom 'text_transform' to the input. For example, let's say I want to transform all strings to lowercase or truncate the text to a certain length. In my opinion, this should be passed as an argument so we can append it to the 'text_transform'.

Also, I am wondering why we are still using the old torchtext.vocab instead of the new experimental vocab in the examples. Anyway, I think it's an interesting change, and I am wondering whether there is anything I can contribute.

Thanks for the comment. For your first question, you should check out the raw text data iterator and build a text transform pipeline. This gives you more flexibility.

For your second comment, we will switch to the new vocabulary once we are done with some cleanup.
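Building your own pipeline on top of the raw text iterator can be as simple as chaining callables. A minimal sketch covering the lowercase and truncate examples above; `compose` and `truncate` are hypothetical helper names:

```python
def compose(*transforms):
    """Chain callables left to right: compose(f, g)(x) == g(f(x))."""
    def pipeline(value):
        for transform in transforms:
            value = transform(value)
        return value
    return pipeline

def truncate(max_len):
    """Return a transform that keeps at most max_len tokens."""
    return lambda tokens: tokens[:max_len]

text_transform = compose(str.lower, str.split, truncate(3))
```

Each raw string from the iterator would then be passed through `text_transform` before numericalization.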

Thank you for the reply. I wonder whether the API supports writing a Dataset object for a custom dataset. It seems to be hard to do so with the new API. For the text classification datasets, for example, the vocab building and the transform pipeline are written in the _setup_datasets function, which is not accessible to us if we want to build a custom text classification dataset.

@zhangguanheng66 I believe @KwanWaiChung's comment is very relevant, and we should make it very easy for users to understand how to write their own datasets.

Here's a minimal example of how to use your own data (here given as a very small list) to create a TextClassificationDataset:

import torch
from torchtext.experimental.datasets.text_classification import TextClassificationDataset
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor

# load data from whatever format it's saved in to an iterable of (label, text)
my_data = [('pos', 'this film is great'), ('neg', 'this film is bad'), ('neg', 'this film is awful')]

# tokenizer can be any callable function that goes from str -> list[str]
my_tokenizer = get_tokenizer('basic_english')

# build vocabulary from data
my_vocab = build_vocab_from_iterator([my_tokenizer(text) for label, text in my_data])

# how should the label be transformed?
# str -> int -> LongTensor
label_transforms = sequential_transforms(lambda x: 1 if x == 'pos' else 0, totensor(torch.long))

# how should the text be transformed?
# str -> list[str] -> list[int] -> LongTensor
text_transforms = sequential_transforms(my_tokenizer, vocab_func(my_vocab), totensor(torch.long))

# tuple the transforms
my_transforms = (label_transforms, text_transforms)

# create TextClassificationDataset with data, vocabulary and transforms
dataset = TextClassificationDataset(my_data, my_vocab, my_transforms)

The only missing step for applying this to actual data is some code that loads your data into a list of (label, text) tuples. Any desired pre-processing can be handled by writing your own tokenizer function, or any other function that fits within sequential_transforms.
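For the loading step, a minimal sketch assuming the data lives in a tab-separated file with a header row and columns named "label" and "text". The column names, the file name and the `load_label_text_pairs` helper are hypothetical; only the standard-library `csv` module is used.

```python
import csv
import io
from typing import List, TextIO, Tuple


def load_label_text_pairs(handle: TextIO) -> List[Tuple[str, str]]:
    """Read (label, text) tuples from a TSV file object that has a header row."""
    reader = csv.DictReader(handle, delimiter="\t")
    return [(row["label"], row["text"]) for row in reader]


# With a real file you would pass open("my_data.tsv", newline="", encoding="utf-8")
# (hypothetical path); an in-memory file keeps this example self-contained.
sample = io.StringIO("label\ttext\npos\tthis film is great\nneg\tthis film is bad\n")
my_data = load_label_text_pairs(sample)
print(my_data)  # [('pos', 'this film is great'), ('neg', 'this film is bad')]
```

The returned list has the same shape as `my_data` in the example above, so it can be passed to the vocabulary builder and TextClassificationDataset unchanged.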

Thanks @bentrevett, your comment explains the process very well. I will use the language modeling dataset as an example and explain again how it works.

The experimental datasets in torchtext come in two kinds:

  • Raw datasets, which are iterators over the raw strings.
  • Non-raw datasets, which generate processed data based on a user-defined vocab and tokenizer.

Who should use non-raw datasets? Users who want to load processed data with a single command should use the LM datasets, such as torchtext.experimental.datasets.WikiText103. If you have a custom vocab or tokenizer, pass it to the non-raw dataset constructor:

def WikiText103(*args, **kwargs):
    """ Defines WikiText103 datasets.
    Create language modeling dataset: WikiText103
    Separately returns the train/test/valid set
    Arguments:
        tokenizer: the tokenizer used to preprocess raw text data.
            The default one is basic_english tokenizer in fastText. spacy tokenizer
            is supported as well (see example below). A custom tokenizer is a callable
            function that takes a string as input and returns a list of tokens.
        root: Directory where the datasets are saved. Default: ".data"
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        data_select: a string or tuple for the returned datasets
            (Default: ('train', 'test','valid'))
            By default, all the three datasets (train, test, valid) are generated. Users
            could also choose any one or two of them, for example ('train', 'test') or
            just a string 'train'. If 'train' is not in the tuple or string, a vocab
            object should be provided which will be used to process valid and/or test
            data.
        single_line: whether to return all tokens in a single line.
            (Default: True)
            By default, all lines in raw text file are concatenated into a single line.
            Use `single_line = False` if one wants to get data line by line.
    Examples:
        >>> from torchtext.experimental.datasets import WikiText103
        >>> from torchtext.data.utils import get_tokenizer
        >>> tokenizer = get_tokenizer("spacy")
        >>> train_dataset, test_dataset, valid_dataset = WikiText103(tokenizer=tokenizer)
        >>> vocab = train_dataset.get_vocab()
        >>> valid_dataset, = WikiText103(tokenizer=tokenizer, vocab=vocab,
        ...                              data_select='valid')
    """
    # implementation omitted in this excerpt

If you want more flexibility, such as the "truncate the text str to a certain length" feature requested by @KwanWaiChung, you have to use the raw dataset with a custom text transform pipeline. So how do you do that? You can treat the _setup_datasets function as an example of how to set up the transform pipeline:

def _setup_datasets(dataset_name, tokenizer=None, root='.data', vocab=None,
                    data_select=('train', 'test', 'valid'), single_line=True):
    ...  # body discussed below

So what exactly is the _setup_datasets function doing? It basically just sets up the transform pipeline:

def text_transform(line):
    return torch.tensor([vocab[token] for token in tokenizer(line)], dtype=torch.long)

To build the transform pipeline, you first have to obtain a tokenizer (https://github.com/pytorch/text/blob/master/torchtext/experimental/datasets/language_modeling.py#L65-L66) and then generate a vocabulary:

if vocab is None:
    if 'train' not in data_select:
        raise TypeError("Must pass a vocab if train is not selected.")
    raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',))
    vocab = build_vocab(raw_train, tokenizer)
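To make the vocabulary step concrete, here is a dependency-free sketch of what a build_vocab(raw_data, tokenizer) step does conceptually: count tokens over the raw training lines and assign each kept token an integer id. `build_vocab_sketch`, the special tokens and the frequency/alphabetical ordering are assumptions for illustration; this is not torchtext's implementation.

```python
from collections import Counter
from typing import Callable, Dict, Iterable, List, Tuple


def build_vocab_sketch(
    raw_lines: Iterable[str],
    tokenizer: Callable[[str], List[str]],
    specials: Tuple[str, ...] = ("<unk>", "<pad>"),
    min_freq: int = 1,
) -> Dict[str, int]:
    """Hypothetical helper: count tokens and map each kept token to an integer id."""
    counter = Counter()
    for line in raw_lines:
        counter.update(tokenizer(line))
    # special tokens take the lowest ids
    stoi = {token: idx for idx, token in enumerate(specials)}
    # remaining tokens: most frequent first, ties broken alphabetically
    for token, freq in sorted(counter.items(), key=lambda item: (-item[1], item[0])):
        if freq >= min_freq and token not in stoi:
            stoi[token] = len(stoi)
    return stoi


raw_train = ["the cat sat", "the dog sat"]
vocab = build_vocab_sketch(raw_train, str.split)
print(vocab)  # {'<unk>': 0, '<pad>': 1, 'sat': 2, 'the': 3, 'cat': 4, 'dog': 5}
```

With a plain dict like this, the transform needs an explicit fallback for unseen tokens, for example `vocab.get(token, vocab["<unk>"])`.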

Finally, you pass the data and transforms to the language modeling abstraction, LanguageModelingDataset, to get a map-style dataset, which works with DataLoader:

return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform, single_line)
             for item in data_select)
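For readers writing their own dataset, here is a minimal, hypothetical map-style dataset in the spirit of LanguageModelingDataset, including the single_line behavior described in the docstring. `SimpleLMDataset` is an illustrative name, not the torchtext class. It is plain Python because torch.utils.data.DataLoader's default sampler only needs `__len__` and `__getitem__`.

```python
from typing import Callable, List, Sequence


class SimpleLMDataset:
    """Hypothetical map-style dataset: transforms raw lines once, then serves items by index."""

    def __init__(self, raw_lines: Sequence[str],
                 text_transform: Callable[[str], List[int]],
                 single_line: bool = True):
        processed = [text_transform(line) for line in raw_lines]
        if single_line:
            # flatten all lines into one long sequence of token ids
            self.data = [idx for ids in processed for idx in ids]
        else:
            # keep one list of token ids per raw line
            self.data = processed

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i: int):
        return self.data[i]


stoi = {"the": 0, "cat": 1, "sat": 2}


def transform(line: str) -> List[int]:
    return [stoi[token] for token in line.split()]


lines = ["the cat", "", "cat sat"]
single = SimpleLMDataset(lines, transform)
per_line = SimpleLMDataset(lines, transform, single_line=False)
print(len(single), single[2])      # 4 1
print(len(per_line), per_line[0])  # 3 [0, 1]
```

Applying the transform once in `__init__` keeps `__getitem__` cheap; for very large corpora one might instead transform lazily inside `__getitem__`.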

I'm working on a more hands-on tutorial that shows how to build a dataset from these building blocks.

Thanks for the detailed example, that's really clear! Can we currently load pretrained word vectors like before, or is that functionality planned for later? I'm asking because some comments above discuss issues with the experimental Vocab class, but loading pretrained vectors isn't mentioned.

The way I've been using pre-trained vectors is by loading them and then "aligning" them with the vocabulary to create a pretrained_embedding tensor which I use to replace the randomly initialized weights of my model's nn.Embedding layer. Here's a code example that follows on from the previous code:

from torchtext.experimental.vectors import GloVe

# define desired embedding dim
emb_dim = 100

# get pretrained glove vectors
glove = GloVe(name = '6B',
              dim = emb_dim)

# create a tensor used for holding the pre-trained vectors for each element of the vocab
pretrained_embedding = torch.zeros(len(my_vocab), emb_dim)

# get the pretrained vector's vocab, Dict[str, int]
pretrained_vocab = glove.vectors.get_stoi()

# iterate over your vocab's `itos` attribute, a list of tokens within the vocab
# if the token is in the pre-trained vocab, i.e. if it has a pre-trained vector
# then replace its row in the pre-trained embedding tensor with the pre-trained vector
# if the token is NOT in the pre-trained vocab, we leave it initialized to zero
for idx, token in enumerate(my_vocab.get_itos()):
    if token in pretrained_vocab:
        pretrained_vector = glove[token] # pretrained_vector is a FloatTensor pre-trained vector for `token`
        pretrained_embedding[idx] = pretrained_vector # update the appropriate row in pretrained_embedding

# at this point we have the aligned pre-trained vectors, but we need to actually use them in our model

# later on, when you've defined your model with an nn.Embedding layer called `embedding`
# replace the randomly initialized embedding with your pre-trained embedding
model.embedding.weight.data.copy_(pretrained_embedding)

I've been trying to use pad_sequence with the following collator:

import torch

def collator(batch, pad_idx):

    labels, sequences = zip(*batch)

    labels = torch.FloatTensor(labels)

    lengths = torch.LongTensor([len(sequence) for sequence in sequences])

    sequences = torch.nn.utils.rnn.pad_sequence(sequences,
                                                padding_value = pad_idx)

    masks = (sequences != pad_idx)

    return labels, sequences, lengths, masks

but this results in

RuntimeError: lengths array must be sorted in decreasing order when enforce_sorted is True. You can pass enforce_sorted=False to pack_padded_sequence and/or pack_sequence to sidestep this requirement if you do not need ONNX exportability.

I'm not sure what a good way to handle this would be, since there have been a lot of changes in torchtext recently. I do want the sequences sorted.

@satyajitghana To avoid this error, pass enforce_sorted=False to the pack_padded_sequence call.
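Since the goal was to keep the sequences sorted, one alternative to enforce_sorted=False is to sort each batch by length, longest first, inside the collator, so the lengths already satisfy pack_padded_sequence's default enforce_sorted=True. A minimal sketch, assuming each batch item is a (label, LongTensor) pair as in the collator above; `sorted_collator` is a hypothetical name.

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def sorted_collator(batch: List[Tuple[float, torch.Tensor]], pad_idx: int):
    # sort (label, sequence) pairs by sequence length, longest first,
    # so labels stay aligned with their sequences
    batch = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)
    labels, sequences = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.float)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    # pad_sequence defaults to batch_first=False: shape [max_len, batch_size]
    sequences = pad_sequence(list(sequences), padding_value=pad_idx)
    masks = sequences != pad_idx
    return labels, sequences, lengths, masks
```

To use it with a DataLoader, bind the padding index with functools.partial, for example `DataLoader(dataset, batch_size=32, collate_fn=partial(sorted_collator, pad_idx=pad_idx))`. Note that the mask assumes no real token id equals `pad_idx`.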