lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM

Encoder-Decoder

Bachstelze opened this issue · comments

The follow-up research to PaLM, Flan-PaLM, switched to the encoder-decoder T5 architecture. How would it be possible to also add an encoder to this implementation?

@Bachstelze I am curious as to where you are seeing that the Flan-PaLM architecture switched to encoder-decoder?

This is taken from the paper Scaling Instruction-Finetuned Language Models.
[screenshot from the paper]

@conceptofmind
Sorry, I got confused by this figure from UL2 and concluded that they switched completely to encoder-decoder models:
[figure from the UL2 paper]
Description: In both decoder-only and encoder-decoder setups, UL2 strikes a significantly improved balance in performance between fine-tuned discriminative tasks and prompt-based 1-shot open-ended text generation compared to previous methods. (All models are comparable in terms of computational cost, i.e., FLOPs; EncDec models have 300M and Dec models 150M parameters.)

In the end, the T5 models are the only released models (the UL2-T5 weights should also be available somewhere).

@Bachstelze yes indeed, i can get an encoder / decoder into the repository early next month. going to take a break around the new years to go enjoy some electronic music, but will be back in full force after the holidays.

If you are looking into using the UL2 training objective for enc-dec as well, you would need something like this to set up the tasks with Seqio:

import functools
import tensorflow as tf
import seqio
import t5.data
from typing import Optional, Sequence

# UL2 paper appendix code missed this function
def prepend_prompt(dataset: tf.data.Dataset,
                   output_features: seqio.preprocessors.OutputFeaturesType,
                   sequence_length: Optional[
                       seqio.preprocessors.SequenceLengthType] = None,
                   prompt_mode: str = "",
                   key: str = "inputs",
                   mode: str = "") -> tf.data.Dataset:
    """Prepends a prompt at the beginning of an input sequence."""
    del sequence_length
    if prompt_mode and mode:
        # output_features may not have inputs key
        out_keys = list(output_features.keys())
        prompt_tokens = output_features[out_keys[0]
                                        ].vocabulary.encode_tf(prompt_mode)

        def add_to_inputs(x):
            x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
            return x

        dataset = dataset.map(add_to_inputs)
    return dataset

# modified from t5.data.preprocessors because output_features may not have inputs key
def split_tokens_to_inputs_length(dataset, sequence_length,
                                  output_features, **kwargs):
    max_tokens = sequence_length['inputs']
    # output_features may not have inputs key
    out_keys = list(output_features.keys())
    if output_features[out_keys[0]].add_eos:
        # Leave room to insert an EOS token.
        max_tokens -= 1

    return t5.data.preprocessors.split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs)

# modified from t5.data.preprocessors because output_features may not have inputs key
def prefix_lm(dataset, sequence_length, output_features):
    """Prefix language modeling objective used in Raffel et al. 2019."""
    ds = dataset
    ds = t5.data.preprocessors.select_random_chunk(ds, output_features=output_features,
                                                   feature_key='targets', max_length=65536)
    ds = split_tokens_to_inputs_length(ds, output_features=output_features,
                                       sequence_length=sequence_length)
    ds = t5.data.preprocessors.denoise(
        ds,
        output_features,
        inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
        targets_fn=t5.data.preprocessors.drop_noise_tokens,
        noise_density=0.5,
        noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
    )
    return ds

# copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
# note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
def ul2_objective(dataset: tf.data.Dataset,
                  sequence_length: seqio.preprocessors.SequenceLengthType,
                  output_features: seqio.preprocessors.OutputFeaturesType,
                  use_prefix_lm_task: bool = False,
                  rates: Optional[Sequence[float]] = None,
                  mean_noise_span_lengths: Sequence[float] = (3.0,),
                  noise_densities: Sequence[float] = (0.15,),
                  shard_ds: bool = True,
                  optional_task_prefixes: Optional[Sequence[str]] = None,
                  input_feature_key: str = "inputs",
                  merge_examples_to_reduce_padding: bool = True,
                  reserved_for_packing: Optional[int] = None,
                  seed: int = 7) -> tf.data.Dataset:
    """UL2-like pre-training objectives.
    This preprocessor amounts to calling the 'span_corruption' function several
    times with different values of 'noise_density' and 'mean_noise_span_length'.
    We either shard or copy the dataset, then apply each function to each shard.
    Add S-denoising (prefixLM) using use_prefix_lm_task.
    Args:
        dataset: A tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
        sequence_length: dict mapping of feature key to int length for that feature.
        output_features: mapping of keys to features.
        use_prefix_lm_task: <bool> If True, include PrefixLM in the task mix.
        rates: <Optional<List<float>> List of rates per task. If None, tasks are sampled uniformly.
        mean_noise_span_lengths: List of mean number of tokens per masked span per example.
        noise_densities: List of what fraction of the tokens to mask.
        shard_ds: <bool> If True, shard dataset per objective.
        optional_task_prefixes: <Optional<list<str>> Strings to prepend for each corruption scheme. NOTE: If including prefixLM task, it must be the last prefix.
        input_feature_key: which feature to use from the dataset as the input text tokens.
        merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
        reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
        seed: tf.int64 for controlling the random choice of spans.
    Returns:
        a dataset
    """

    if optional_task_prefixes:  # Ensure each task has a prefix.
        num_tasks = len(noise_densities) + int(use_prefix_lm_task)
        valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
        if not valid_number_of_prefixes:
            raise ValueError(
                "Number of task prefixes must match number of tasks.")
    inputs_length = sequence_length[input_feature_key]
    input_lengths, targets_lengths = [], []
    sequence_lengths = {x: y for x, y in sequence_length.items()}
    if reserved_for_packing:
        inputs_length -= reserved_for_packing
        for x, y in sequence_length.items():
            sequence_lengths[x] = y - reserved_for_packing
    hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
    for mean_noise_span_length, noise_density in hyperparams:
        input_length, targets_length = t5.data.preprocessors.random_spans_helper(
            extra_tokens_per_span_inputs=1,
            extra_tokens_per_span_targets=1,
            inputs_length=inputs_length,
            mean_noise_span_length=mean_noise_span_length,
            noise_density=noise_density)
        input_lengths.append(input_length)
        targets_lengths.append(targets_length)

        if sequence_length["targets"] < targets_length:
            upper_bound = max(targets_lengths)
            raise ValueError(
                f'Expected max targets length for span corruption ({upper_bound}) is '
                f'greater than configured targets length '
                f"({sequence_length['targets']})")
    ds = dataset
    ds = t5.data.preprocessors.select_random_chunk(
        ds,
        output_features=output_features,
        feature_key="targets",
        max_length=65536)
    if merge_examples_to_reduce_padding:
        ds = t5.data.preprocessors.reduce_concat_tokens(
            ds, feature_key="targets", batch_size=128)
    num_shards = len(input_lengths) + int(use_prefix_lm_task)
    if shard_ds:
        ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
    else:
        ds_shards = [ds for _ in range(num_shards)]
    processed_ds = []
    hyperparams = zip(input_lengths, hyperparams, range(num_shards))
    for input_length, (noise_span_length, noise_density), i in hyperparams:
        ds = ds_shards[i]
        ds = t5.data.preprocessors.split_tokens(
            ds,
            feature_key="targets",
            min_tokens_per_segment=None,
            max_tokens_per_segment=input_length)
        ds = t5.data.preprocessors.denoise(
            ds,
            output_features,
            inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
            targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
            noise_density=noise_density,
            noise_mask_fn=functools.partial(
                t5.data.preprocessors.random_spans_noise_mask,
                mean_noise_span_length=noise_span_length),
            input_feature_key=input_feature_key)
        if optional_task_prefixes:
            ds = prepend_prompt(
                ds,
                output_features,
                prompt_mode=optional_task_prefixes[i],
                mode=optional_task_prefixes[i],
                key=input_feature_key)
        processed_ds.append(ds)
    if use_prefix_lm_task:
        ds = ds_shards[-1]
        ds = prefix_lm(
            ds, sequence_lengths, output_features)
        if optional_task_prefixes:
            ds = prepend_prompt(
                ds,
                output_features,
                prompt_mode=optional_task_prefixes[-1],
                mode=optional_task_prefixes[-1],
                key=input_feature_key)
        processed_ds.append(ds)
    ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
    return ds

import functools
import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

from ul2.ul2_objective import ul2_objective

# values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]

R_DENOISER_TOKEN_PREFIX = '[NLU]'
X_DENOISER_TOKEN_PREFIX = '[NLG]'
S_DENOISER_TOKEN_PREFIX = '[S2S]'

TaskRegistry = seqio.TaskRegistry

vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            if item[column] is not None:
                yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files,
                          seed, dataset=dataset),
        output_signature=tf.TensorSpec(
            shape=(), dtype=tf.string, name=dataset_name)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    return {**key_map, target_key: x}


dataset_name = "/researchdisk/training_dataset"
dataset_params = {"from_disk_path": dataset_name}

if "from_disk_path" in dataset_params:
    dataset = load_from_disk(dataset_params.get("from_disk_path"))
else:
    dataset = load_dataset(**dataset_params)

dataset_shapes = {"train": dataset["train"].num_rows,
                  "validation": dataset["validation"].num_rows}

TaskRegistry.add(
    "pretrain_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        functools.partial(
            ul2_objective,
            shard_ds=False,
            use_prefix_lm_task=True,  # use S-denoising
            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)]*len(R_DENOISER_SPAN_LENGTHS) + [
                0.4 / len(X_DENOISER_SPAN_LENGTHS)]*len(X_DENOISER_SPAN_LENGTHS) + [0.2],  # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]*len(R_DENOISER_SPAN_LENGTHS) + [
                X_DENOISER_TOKEN_PREFIX]*len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
            reserved_for_packing=1,  # make room for task prefix token
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy]
)

UL2 weights can be found here: https://github.com/google-research/google-research/tree/master/ul2
Link to HF repo: https://huggingface.co/Finnish-NLP/ul2-base-nl36-finnish/blob/main/tasks.py#L25
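
For reference, a minimal sketch of how the registered task could be consumed downstream with seqio, assuming the registration above has already run and that the vocabulary file and dataset path exist; the sequence lengths and split here are arbitrary placeholders:

import seqio

# look up the task registered above and build a tf.data pipeline from it
task = seqio.get_mixture_or_task("pretrain_ul2")

# the "inputs" / "targets" lengths must cover the lengths produced by ul2_objective
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="train",
    shuffle=True,
    seed=7,
)

for example in ds.take(1):
    print({k: v.shape for k, v in example.items()})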

@lucidrains How would you feel about adding Mixture-of-denoisers (ul2 objective) for initially pre-training the added encoder-decoder model? Would this be too off-topic?

@conceptofmind yea, it may be outside the scope of this project

i can always add it as a separate pretraining wrapper similar to this, but just have not totally bought into the mixture-of-denoisers objective yet. if people still use it a few months from now, i'll build a pytorch package for it in a separate repo

@lucidrains ,

Ok. Always looking forward to your implementations! If anyone deserves a Holiday break it is definitely you!

I made an attempt at an encoder-decoder T5 architecture implementation. I did not want to open a PR because I was unsure whether you wanted to add T5 or just add an encoder to PaLM.

Attempt at T5:

import torch
from torch import nn
import torch.nn.functional as F

import math

from einops import rearrange

# pre-normalization wrapper
# they use layernorm without bias

class T5LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = T5LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# gated-GELU activation function

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# feedforward layer with gated-GELU activation function

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout), # optional dropout
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

# T5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos, 
            causal = self.causal, 
            num_buckets = self.num_buckets, 
            max_distance = self.max_distance
        )
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return qk_dots + (bias * self.scale)

# T5 attention

class T5Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        num_buckets = 32,
        max_distance = 128,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.relative_position_bias = T5RelativePositionBias(
            scale = dim_head ** -0.5, 
            causal = causal, 
            num_buckets = num_buckets, 
            max_distance = max_distance, 
            heads = heads
            )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        if self.causal:
            i, j = dots.shape[-2:]

        dots = self.relative_position_bias(dots)

        if mask is not None and self.causal:
            # Causal Mask
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            dots = dots.masked_fill(causal_mask, -torch.finfo(dots.dtype).max)

        elif mask is not None:
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
        
# T5 Decoder

class T5Decoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        causal = True,
        num_buckets = 32,
        max_distance = 128,
        mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(1024, dim)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, T5Attention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, num_buckets = num_buckets, max_distance = max_distance, dropout = dropout)),
                PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout)),
            ]))

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        pos = torch.arange(n, device = device)
        x = self.token_emb(x) + self.pos_emb(pos)
        x = self.dropout(x)

        for attn, mlp in self.layers:
            x = x + attn(x, mask = mask)
            x = x + mlp(x)

        return x

# T5 Encoder

class T5Encoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        causal = False,
        num_buckets = 32,
        max_distance = 128,
        mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(1024, dim)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, T5Attention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, num_buckets = num_buckets, max_distance = max_distance, dropout = dropout)),
                PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout)),
            ]))

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        pos = torch.arange(n, device = device)
        x = self.token_emb(x) + self.pos_emb(pos)
        x = self.dropout(x)

        for attn, mlp in self.layers:
            x = x + attn(x, mask = mask)
            x = x + mlp(x)

        return x

# T5

class T5(nn.Module):
    def __init__(
        self,
        *,
        dim,
        enc_num_tokens,
        enc_depth,
        enc_heads,
        enc_dim_head,
        enc_mlp_mult,
        dec_num_tokens,
        dec_depth,
        dec_heads,
        dec_dim_head,
        dec_mlp_mult,
        dropout = 0.
    ):
        super().__init__()
        
        self.encoder = T5Encoder(
            dim = dim, 
            num_tokens = enc_num_tokens, 
            depth = enc_depth, 
            heads = enc_heads, 
            dim_head = enc_dim_head, 
            mlp_mult = enc_mlp_mult, 
            dropout = dropout
        )
        
        self.decoder = T5Decoder(
            dim = dim, 
            num_tokens = dec_num_tokens, 
            depth = dec_depth, 
            heads = dec_heads, 
            dim_head = dec_dim_head, 
            mlp_mult = dec_mlp_mult, 
            dropout = dropout
        )

    def forward(self, src, tgt, mask = None):

        x = self.encoder(src, mask = mask)
        x = self.decoder(tgt, mask = mask)
        return x


if __name__ == '__main__':
    
    model = T5(
        dim = 512,
        enc_num_tokens = 256,
        enc_depth = 6,
        enc_heads = 8,
        enc_dim_head = 64,
        enc_mlp_mult = 4,
        dec_num_tokens = 256,
        dec_depth = 6,
        dec_heads = 8,
        dec_dim_head = 64,
        dec_mlp_mult = 4,
        dropout = 0.
    )

    src = torch.randint(0, 256, (1, 1024))
    src_mask = torch.ones_like(src).bool()
    tgt = torch.randint(0, 256, (1, 1024))

    loss = model(src, tgt, mask = src_mask)
    print(loss.shape) #torch.Size([1, 1024, 512])

If this looks ok, I can look into adding something similar or just the encoder to the repository. Otherwise, I will make any necessary corrections.

Thank you,

Enrico

@conceptofmind thanks Enrico, hope you have a great new years eve too 🎊

i'm probably going to improvise a little and design a PaLM encoder / decoder, which does not exist yet

your code looks good! just a small thing; you need to pass the output of the encoder into the decoder, and the decoder will need to periodically cross attend to the encoded sequence, taking into account source sequence mask as well

@lucidrains Oh! I realized after I had left for a walk that I did not connect the output of the encoder-decoder!

    def forward(self, src, tgt, mask = None, context_mask = None):
        x = self.encoder(src, mask = mask)
        y = self.decoder(x, tgt, mask = mask, context_mask = context_mask)
        return y

I will definitely have to further review implementing cross-attention though. I will make amendments to stack self-attention and cross-attention layers with the context mask inside of the Decoder. I think the Encoder looks ok. I will update this when I hopefully have it working better.

# T5 Cross Attention

class T5CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        num_buckets = 32,
        max_distance = 128,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.relative_position_bias = T5RelativePositionBias(
            scale = dim_head ** -0.5, 
            causal = False, 
            num_buckets = num_buckets, 
            max_distance = max_distance, 
            heads = heads
            )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, mask = None, context_mask = None):
        b, n, _, h = *x.shape, self.heads
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        dots = self.relative_position_bias(dots)

        if mask is not None:
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill_(~mask, mask_value)

        if context_mask is not None:
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill_(~context_mask[:, None, :], mask_value)

        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

I previously spoke to Ofir Press and Hails (EAI) about a version of encoder-decoder PaLM with ALiBi positional bias. Ofir said he would help where he could.

Thank you again.

@conceptofmind yup, that looks good 😃 Ofir Press would be good to work with, especially if he had found a solution to applying ALiBi to encoders (as of today, it has only been shown to work for decoders)
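
As a point of reference, a minimal sketch of the decoder-style ALiBi bias from the "Train Short, Test Long" paper: per-head slopes in a geometric sequence (exact for power-of-two head counts), multiplied by the query-key distance and added to the attention logits before the causal mask and softmax. This helper is illustrative only and not part of this repository.

import torch

def alibi_slopes(heads):
    # geometric sequence of per-head slopes, e.g. 1/2, 1/4, ..., 1/256 for 8 heads
    start = 2 ** (-8 / heads)
    return torch.tensor([start ** (i + 1) for i in range(heads)])

def alibi_bias(heads, seq_len):
    slopes = alibi_slopes(heads)                     # (h,)
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]                # j - i, non-positive in the causal region
    return slopes[:, None, None] * rel[None, :, :]   # (h, i, j), added to the attention logits

# usage: sim = sim + alibi_bias(self.heads, n).to(sim.device), then apply the causal mask as usual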

hope you have fun at First Night over there at the east coast

@lucidrains I hope you have a great New Year as well!

@lucidrains Sent a small thank you / holiday gift since I appreciate you answering my questions 😄.

I know you are away now but whenever you are back and free, I did some refactoring of the previous T5 architecture implementation attempt. I am still unsure of whether this is fully implemented correctly though.

  • I commented out the fixed positional embeddings in the Encoder / Decoder / T5 class because I am not sure these are needed when including T5 relative positional bias already.
    #self.pos_emb = nn.Embedding(max_seq_len, dim)
    #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
  • Added Cross Attention taking context masking and dimensionality into consideration. I do not know whether causal masking should be used in Cross Attention, so I currently have not included it. I also commented out the relative positional bias in Cross Attention since I do not believe T5 uses it there. I think Cross Attention is implemented correctly now.
  • I tied the encoder token embedding to the decoder token embedding, self.encoder.token_emb.weight = self.decoder.token_emb.weight. I am pretty sure T5 shares or ties the token embedding weights of the encoder / decoder.
  • I added an embedding, self.embedding = nn.Embedding(enc_num_tokens, dim) and to_logits, self.to_logits = nn.Linear(dim, dec_num_tokens) to the T5 class. I do not know if these are required but it seemed pretty standard.
  • Used PreNorm but it is possible PostNorm should be used instead.
  • Switched GEGLU to nn.ReLU.
  • Added a final layer norm and Residual wrapper.

Refactored code:

import torch
from torch import nn
import torch.nn.functional as F

import math

from einops import rearrange

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# residual wrapper

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# pre-normalization wrapper
# they use layernorm without bias

class T5LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = T5LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# feedforward layer

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.ReLU(),
            nn.Dropout(dropout), # optional dropout
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

# T5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 12):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos, 
            causal = self.causal, 
            num_buckets = self.num_buckets, 
            max_distance = self.max_distance
        )
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return qk_dots + (bias * self.scale)

# T5 Self Attention

class T5SelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 12,
        dim_head = 64,
        causal = False,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.relative_position_bias = T5RelativePositionBias(
            scale = dim_head ** -0.5, 
            causal = causal,
            heads = heads
            )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        sim = self.relative_position_bias(sim)

        # mask

        mask_value = -torch.finfo(sim.dtype).max

        if mask is not None:
            # mask is (batch, seq); broadcast over heads and query positions
            sim = sim.masked_fill(~mask[:, None, None, :], mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        
        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        
        # combine heads and linear output

        return self.to_out(out)

# T5 Cross Attention

class T5CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        context_dim = None,
        heads = 12,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # self.relative_position_bias = T5RelativePositionBias(
        #     scale = dim_head ** -0.5,
        #     causal = False,
        #     heads = heads
        #     )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, mask = None, context_mask = None):
        b, n, _, h = *x.shape, self.heads

        kv_input = default(context, x)

        q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        #sim = self.relative_position_bias(sim)

        # mask

        mask_value = -torch.finfo(sim.dtype).max

        if mask is not None:
            # mask is (batch, target_len); broadcast over heads and key positions
            sim = sim.masked_fill(~mask[:, None, :, None], mask_value)

        if context_mask is not None:
            # context_mask is (batch, source_len); broadcast over heads and query positions
            sim = sim.masked_fill(~context_mask[:, None, None, :], mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        
        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        
        # combine heads and linear output

        return self.to_out(out)

# T5 Encoder

class T5Encoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        #max_seq_len,
        depth,
        heads = 12,
        dim_head = 64,
        causal = False,
        mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        #self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.layer = nn.ModuleList([])
        for _ in range(depth):
            self.layer.append(nn.ModuleList([
                Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
            ]))

        self.final_norm = T5LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.token_emb(x)
        #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))

        for attn, mlp in self.layer:
            x = attn(x, mask = mask)
            x = mlp(x)

        x = self.final_norm(x)

        return x

# T5 Decoder

class T5Decoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        #max_seq_len,
        depth,
        heads = 12,
        dim_head = 64,
        causal = True,
        mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        #self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.layer = nn.ModuleList([])
        for _ in range(depth):
            self.layer.append(nn.ModuleList([
                Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
                Residual(PreNorm(dim, T5CrossAttention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
            ]))

        self.final_norm = T5LayerNorm(dim)

    def forward(self, x, context, mask = None, context_mask = None):
        x = self.token_emb(x)
        #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))

        for attn, cross_attn, mlp in self.layer:
            x = attn(x, mask = mask)
            x = cross_attn(x, context = context, mask = mask, context_mask = context_mask)
            x = mlp(x)

        x = self.final_norm(x)

        return x

# T5

class T5(nn.Module):
    def __init__(
        self,
        *,
        dim,
        #max_seq_len,
        enc_num_tokens,
        enc_depth,
        enc_heads,
        enc_dim_head,
        enc_mlp_mult,
        dec_num_tokens,
        dec_depth,
        dec_heads,
        dec_dim_head,
        dec_mlp_mult,
        dropout = 0.,
        tie_token_emb = True
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(enc_num_tokens, dim)
        #self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.encoder = T5Encoder(
            dim = dim,
            #max_seq_len = max_seq_len, 
            num_tokens = enc_num_tokens, 
            depth = enc_depth, 
            heads = enc_heads, 
            dim_head = enc_dim_head, 
            mlp_mult = enc_mlp_mult, 
            dropout = dropout
        )
        
        self.decoder = T5Decoder(
            dim = dim,
            #max_seq_len= max_seq_len, 
            num_tokens = dec_num_tokens, 
            depth = dec_depth, 
            heads = dec_heads, 
            dim_head = dec_dim_head, 
            mlp_mult = dec_mlp_mult, 
            dropout = dropout
        )

        self.to_logits = nn.Linear(dim, dec_num_tokens)

        # tie weights
        if tie_token_emb:
            self.encoder.token_emb.weight = self.decoder.token_emb.weight

    def forward(self, src, tgt, mask = None, context_mask = None):
        x = self.embedding(src)
        #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
        x = self.encoder(src, mask = mask)
        x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
        x = self.to_logits(x)
        return x


if __name__ == '__main__':
    
    model = T5(
        dim = 768,
        #max_seq_len = 1024,
        enc_num_tokens = 512,
        enc_depth = 6,
        enc_heads = 12,
        enc_dim_head = 64,
        enc_mlp_mult = 4,
        dec_num_tokens = 512,
        dec_depth = 6,
        dec_heads = 12,
        dec_dim_head = 64,
        dec_mlp_mult = 4,
        dropout = 0.,
        tie_token_emb = True
    )

    src = torch.randint(0, 512, (1, 1024))
    src_mask = torch.ones_like(src).bool()
    tgt = torch.randint(0, 512, (1, 1024))

    loss = model(src, tgt, mask = src_mask)

    print(loss.shape) #torch.Size([1, 1024, 512])

Thank you again for all of your help! I hope you are enjoying the EDM.

@conceptofmind hey thanks Enrico! you didn't have to

yup overall the code looks good! i'm probably going to stick with the PaLM architecture with the parallel attention / feedforward blocks, but you could release the code above in a repository and it would see usage!

@lucidrains I figured that would be the best way to show gratitude.

I appreciate you reviewing the code. I will create a repository with the final T5 code above. @Bachstelze Hopefully this helps cover both ul2 and t5 now.

I can try to work on an encoder-decoder with PaLM in the meantime as well if the above does look good. I may try it for my own learning experience anyway. Unless there is a very explicit way you want it implemented, in which case I can focus on the web app UI (#15) instead.

Thank you,

Enrico

@conceptofmind Thank you for your interest and contribution!
To my knowledge, there is no research showing that a decoder-only modification performs better than an encoder-decoder architecture. The GPT paper "Improving Language Understanding by Generative Pre-Training" also used a discriminative training task, and such a discriminative training objective alone has shown improvements with Electra.
In fact, the T5 paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" showed better performance for encoder-decoder models, even with the same number of parameters. Here P refers to the number of parameters in a 12-layer base Transformer layer stack and M to the FLOPs required to process a sequence with the encoder-decoder model:
[table of architecture variants from the T5 paper]
Probing Word Translations in the Transformer and Trading Decoder for Encoder Layers even points in the opposite direction of GPT and suggests increasing the encoder and reducing the decoder layers.
If we only look at the parameters P, then we should take into account model compression like ALBERT or parameter sharing across layers in transformers.
So I think that we need a gradual comparison of model modifications to get a clear overview.
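
As a small illustration of the cross-layer parameter sharing mentioned above, an ALBERT-style model simply reuses one block's weights at every depth, so the parameter count stays at roughly one layer's worth while the compute still scales with depth. This is a toy sketch, not code from this repository:

import torch
from torch import nn

class SharedLayerEncoder(nn.Module):
    """Toy ALBERT-style encoder: a single block's weights reused across all layers."""
    def __init__(self, dim, depth, heads = 8):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(d_model = dim, nhead = heads, batch_first = True)
        self.depth = depth

    def forward(self, x):
        for _ in range(self.depth):  # the same parameters are applied `depth` times
            x = self.block(x)
        return x

model = SharedLayerEncoder(dim = 512, depth = 12)
print(sum(p.numel() for p in model.parameters()))  # parameter count of one layer, independent of depth
out = model(torch.randn(2, 128, 512))              # compute still grows linearly with depth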

take a break around the new years to go enjoy some electronic music

@lucidrains Enjoy the music! Which artists are or were playing?

@Bachstelze yup, agreed that current evidence seems to point to encoder / decoder scaling better. @AranKomat has been telling me this for the last half year

a local techno musician the first night and then Spencer Brown of Anjunabeats the next 😄

@lucidrains I like Nightwalk by Spencer Brown.

Made an attempt at PaLM encoder-decoder. I did not think I should open a PR for this so I am just going to add it to this thread. Jason Phang at EleutherAI has graciously provided me with some information too so I will be revising everything to the best of my ability. I am working on adding Flash Attention / Flash Cross Attention as well. Hopefully, some of it is useful...

import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# normalization
# they use layernorm without bias, something that pytorch does not offer


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame


class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings

        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale

        q = q * self.scale

        # similarity

        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # mask

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)

# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        ff_mult=4,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # parallel feedforward

        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        self.register_buffer("mask", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def forward(self, x, context, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre-layernorm, for queries and context

        x = self.norm(x)
        context = self.context_norm(context)

        # get queries

        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale

        q = q * self.scale

        # get key / values

        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # mask

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention

        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        out = out + self.ff(x)

        return out

# Encoder

class Encoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        dim_head,
        ff_mult=4,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
            )

    def forward(self, x, attn_mask=None):

        for attn_ff in self.layers:
            x = attn_ff(x, attn_mask=attn_mask)
        return x

# Decoder

class Decoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        dim_head,
        ff_mult=4,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            ]))

    def forward(self, x, context, attn_mask=None):

        for attn, cross_attn in self.layers:
            x = attn(x, attn_mask)
            x = cross_attn(x, context, attn_mask=attn_mask)
        return x

# PaLM

class PaLM(nn.Module):
    def __init__(
        self,
        *,
        dim,
        enc_num_tokens,
        enc_depth,
        enc_heads,
        enc_dim_head,
        dec_num_tokens,
        dec_depth,
        dec_heads,
        dec_dim_head,
        ff_mult=4,
    ):
        super().__init__()

        # embedding

        self.embedding = nn.Embedding(enc_num_tokens, dim)

        # encoder

        self.encoder = Encoder(
            dim = dim,
            depth = enc_depth,
            heads = enc_heads,
            dim_head = enc_dim_head,
            ff_mult = ff_mult,
        )

        # decoder

        self.decoder = Decoder(
            dim = dim,
            depth = dec_depth,
            heads = dec_heads,
            dim_head = dec_dim_head,
            ff_mult = ff_mult,
        )

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, dec_num_tokens, bias=False)
        )

    def forward(self, x, y, mask=None):
        x = self.embedding(x)
        y = self.embedding(y)

        x = self.encoder(x, attn_mask=mask)
        y = self.decoder(y, x, attn_mask=mask)

        return self.to_logits(y)

if __name__ == "__main__":

    model = PaLM(
        dim=512,
        enc_num_tokens=512,
        enc_depth=6,
        enc_heads=8,
        enc_dim_head=64,
        dec_num_tokens=512,
        dec_depth=6,
        dec_heads=8,
        dec_dim_head=64,
        ff_mult=4,
    )

    src = torch.randint(0, 512, (1, 1024))
    src_mask = torch.ones_like(src).bool()
    tgt = torch.randint(0, 512, (1, 1024))

    loss = model(src, tgt, mask = src_mask)

    print(loss.shape) #torch.Size([1, 1024, 512])

@Bachstelze Tsinghua University's GLM performance was great and may be worth checking out too.

Will continue to update as I build upon it.
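
Regarding the Flash Attention mention above, a minimal sketch of one possible route, assuming PyTorch 2.x: the explicit einsum / softmax core in ParallelTransformerBlock could be swapped for torch.nn.functional.scaled_dot_product_attention, which can dispatch to a fused FlashAttention kernel when the backend supports it. Note that q should not be pre-multiplied by the scale in that case, since the scaling is applied internally. This is only an illustration, not the code being contributed:

import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, causal = True):
    # q: (b, h, n, d); k, v: (b, n, d) in the multi-query layout used above
    h = q.shape[1]
    k = k.unsqueeze(1).expand(-1, h, -1, -1)  # share the single key / value head across query heads
    v = v.unsqueeze(1).expand(-1, h, -1, -1)
    # applies the 1 / sqrt(d) scaling, softmax and (optionally) the causal mask internally
    return F.scaled_dot_product_attention(q, k, v, is_causal = causal)

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 1024, 64)
v = torch.randn(1, 1024, 64)
out = sdpa_attention(q, k, v)  # (1, 8, 1024, 64), then rearrange back to (b, n, h * d)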

@conceptofmind hey yup, that's close to what i had in mind! just have to think a bit about relative positional encoding for the encoder, as the length extrapolation fix for the rotary embeddings is for decoder only

@lucidrains I had spoken to Jianlin Su of the original RoFormer paper. He stated he is going to provide advice and add blog posts (https://spaces.ac.cn/index.php) on fixing length extrapolation and integrating Rotary Embeddings in the Encoder. Hopefully, I will have more information and a semi-workable implementation regarding this soon.

@conceptofmind oh nice, yea, i'll have to look out for his blog post then! for now, i think turning off xpos and just forgoing the length extrapolation should be ok for encoder
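
As a quick illustration of "turning off xpos": plain rotary embeddings (without the xpos decay scaling) only inject relative position into the q·k dot product, so they can be applied symmetrically to queries and keys in a bidirectional encoder. The sketch below reuses the RotaryEmbedding and apply_rotary_pos_emb helpers from the encoder-decoder attempt earlier in this thread and is not code from this repository:

import torch

# assumes RotaryEmbedding and apply_rotary_pos_emb from the PaLM encoder-decoder attempt above

rotary = RotaryEmbedding(dim = 64)

q = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 1024, 64)

pos = rotary(1024, device = q.device)  # rotation angles, shape (seq, dim_head)
q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))

# no xpos scaling is applied, so q·k depends only on relative position,
# which is what a bidirectional encoder needs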

@Bachstelze this should be complete-able by week's end

just need the actor critic class to accept the PaLMEncDec and do some custom logic to return the correct action mask

@conceptofmind oh nice, yea, i'll have to look out for his blog post then! for now, i think turning off xpos and just forgoing the length extrapolation should be ok for encoder

I wish I could have made a more meaningful contribution to enc-dec 😆.

Hopefully, I will hear back again from Jianlin soon and be able to actually contribute a PR with the RoPE fix.

I spoke to Louis briefly about Google Search API (or Bing API) integration with TRLX, taking points from the Sparrow paper. I will open an issue soon regarding incorporating some learning points from Sparrow, given Letitia's video. Hopefully, it is of interest.

Thank you again!

@conceptofmind yea don't worry about it, this is a difficult project to contribute to. you already did a great job with the vit-flax repository!

re: Jianlin Su, let's wait for him to write a paper (or his preferred medium, a blog post) and see how reception is before integrating it. i've tried some of his other works, and they haven't worked out quite as well as RoPE

similarly, for search APIs, let us wait to see the results of the ongoing Bing + ChatGPT efforts before diving into a big engineering project. i'm skeptical that retrieval is all we need for solving the factual groundedness issues.

@conceptofmind one idea i had, since you had previous experience porting pytorch code over to jax, is to work with @sglucas and build out a jax version of this repository as a parallel effort. you can perhaps even try doing some small scale training with whatever compute @sglucas was allotted

@conceptofmind yea don't worry about it, this is a difficult project to contribute to. you already did a great job with the vit-flax repository!

re: Jianlin Su, let's wait for him to write a paper (or his preferred medium, a blog post) and see how reception is before integrating it. i've tried some of his other works, and they haven't worked out as well as RoPE

similarly, for search APIs, let us wait to see the results of the ongoing Bing + ChatGPT efforts before diving into a big engineering project. i'm skeptical that retrieval is all we need for solving the factual groundedness issues.

Sounds good. I will wait on the release of Jialin Su's blog posts and try to validate his approach with Jason Phang in the future as well.

I do have some doubts as well about factual groundedness using only search engine integration for retrieval, after having tested different Cohere-XL models with Google Search. I think a combination of numerous tools (Wolfram Alpha, Weather API, financial databases) and chain-of-thought prompting + instruction will be needed for better grounding alone. I am starting a working group for this with Louis in CarperAI soon, so hopefully I will be able to validate this further as well.

@conceptofmind one idea i had, since you had previous experience porting pytorch code over to jax, is to work with @sglucas and build out a jax version of this repository as a parallel effort. you can perhaps even try doing some small scale training with whatever compute @sglucas was allotted

I will begin working on a jax version port in either Flax or Haiku. I can start with what I have from the PaLM-flax (or PaLM-haiku) repository and will continue to expand upon it once you have completed the PaLM Encoder-Decoder. I will confirm in #16 now.

Thank you,

Enrico

@Bachstelze

was able to get it kind of working, but the code is now more complicated than i would like

i'll keep meditating on it, and maybe by end of the month it will be polished and simple again

import torch
from palm_rlhf_pytorch import PaLMEncDec, RewardModel, RLHFTrainer

# load your pretrained palm

palm = PaLMEncDec(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

# load your pretrained reward model

reward_model = RewardModel(
    palm.encoder,
    num_binned_output = 5
).cuda()

# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one

answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)

@Bachstelze encoder / decoder would be trained something like

import torch
from palm_rlhf_pytorch import PaLMEncDec

palm_enc_dec = PaLMEncDec(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()
prompt = torch.randint(0, 20000, (1, 2048)).cuda()  # negative values are to be masked out

loss = palm_enc_dec(decoder_seq = seq, prompt = prompt, return_loss = True)
loss.backward()

@Bachstelze this will also close out this issue you raised

Hey guys. I originally suggested for T5 / UL2 to be incorporated to the SotA LMs like ChatGPT, but I've found some arguments / data that suggest they may add more harm than benefits.

Honestly, I'm still trying to get more data, so I may change my position later, but this is what I know at this moment. But I'm not saying that LM encoder is completely useless for any purpose. They are very useful as an encoder for text-conditional multimodal models and information retrieval.

  1. The tasks that T5 / UL2 were evaluated on and excel at are limited to a subset of the traditional NLP tasks (i.e. discriminative / classification tasks, translation variants (e.g. NMT, data-to-text, etc.), and summarization variants). You can see this from the list of tasks UL2 was evaluated on. The typical real-world tasks (e.g. the ones people try with ChatGPT) are more like turn-based generative tasks with long-form output, and we have little evidence that suggests T5 / UL2 perform better than GPT in this domain.

  2. Enc-dec models are much more compute-consuming at inference than dec-only models for dialogs, since they have to re-encode the entire context each time they receive a new message, because the encoder uses bidirectional attention. This can easily increase the inference cost by 10x (see the rough sketch after this list).

  3. Some concrete data suggesting that T5 may struggle, or at least perform no better than GPT, outside of the conventional text-gen tasks T5 was evaluated on:
    a. FLAN-T5 performs much worse w/ CoT than direct prediction (Table 5 from FLAN-PaLM paper) compared with GPT counterparts.
    b. T5 XXL finetuned on GSM8K (~8% according to https://arxiv.org/abs/2212.08410) performs much worse than the ~10B GPT (~20% according to the original GSM8K paper). Admittedly, T5 isn't using a calculator and its tokenizer may not be suitable for this task, but I think it's safe to say that T5 isn't outperforming GPT on this task.
    c. T0 performs poorly on dialog task, even compared with untuned GPT (Table 5: https://arxiv.org/abs/2212.12017)
    d. AlphaCode (https://arxiv.org/abs/2203.07814) tried both enc-dec and dec-only models, and while enc-dec performs more efficiently than dec-only in the competition programming setting, which can be split into input-output pairs easily, dec-only works much better on other problem formats such as HumanEval. A relevant passage below:

The HumanEval results for all of our encoder-decoder models (including the final AlphaCode model) are significantly worse than the decoder-only models, so we do not report them here. We believe this performance difference stems from the fact that encoder-decoder models are well aligned with the competition programming setting where we have a dataset with clear inputs (programming contest problem descriptions) and outputs (solution code), as well as example tests for effective filtering. Therefore the encoder-decoder models can both learn effectively and sample efficiently. However, encoder-decoders are not well aligned with the HumanEval setting where the only training data is GitHub code which cannot easily be split into meaningful inputs and outputs. As a result, the fact that decoder-only models compute loss on all tokens (rather than only one-third of tokens) enables them to learn more efficiently in the HumanEval setting.
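To put rough numbers on point 2, here is a back-of-envelope sketch (illustrative assumptions only, not a benchmark: fixed-length turns, the enc-dec re-encodes the full context every turn, the dec-only model reuses its KV cache, and decoder-side generation cost is ignored):

# counts how many tokens each architecture has to process over a multi-turn dialog

def dialog_token_cost(num_turns=20, tokens_per_turn=200):
    enc_dec_tokens = 0
    dec_only_tokens = 0
    context = 0
    for _ in range(num_turns):
        context += tokens_per_turn
        enc_dec_tokens += context           # re-encode everything seen so far
        dec_only_tokens += tokens_per_turn  # only the new tokens, thanks to the KV cache
    return enc_dec_tokens, dec_only_tokens

enc_dec, dec_only = dialog_token_cost()
print(enc_dec / dec_only)  # ~10.5x more tokens processed under these assumptions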

@AranKomat thanks for sharing this! yea, if need be, i'll just revert some code to keep the repository simple

we can keep the issue open until we come to some conclusion

if any other researcher has an opinion, now is the time to speak up or forever hold your peace

i'll probably revert this in about a week or so back to decoder only

I trust your best judgment!

Aran pointed out that the current cost of inference may significantly outweigh the benefit in performance/accuracy (even in the case in which Enc/Dec matches Dec-only performance perfectly). I am writing up a survey, but it likely won't be completed for the next few days or so. I think it may be worth evaluating, in the near future, Flan-T5 with TRLX on a task such as summarization with RLHF, then comparing to Carper's GPT-J 6B RLHF summarization PPO model to see if the difference in performance is negligible.

It still may be worth exploring PaLM Enc/Dec (even without RLHF) on the side, so I added what you had here, in addition to the previous code I wrote above, to a separate repository.

Thank you,

Enrico

@conceptofmind This is just a sidenote, but as a matter of fact, for tasks like summarization and translation, supervised-finetuned T5 performs pretty well relative to RLHF GPT, as can be seen from OpenAI's RLHF-for-summarization paper, for example. But these are among a small number of generative tasks where we know that T5 works better. I'd instead recommend a mixture of some common tasks submitted to ChatGPT, which better reflects the real-world task distribution and does not give an unfair advantage to enc-dec models.

This encoder-decoder implementation makes it possible to test modifications like in Do Transformer Modifications Transfer Across Implementations and Applications?. Decoder-only models halve the parameter size and also simplify the software code, yet they only keep the compute-intensive decoder, whereas the encoder-decoder models reuse the encoder computation. In the abstract of Compare Encoder-Decoder, Encoder-Only, and Decoder-Only Architectures for Text Generation on Low-Resource Datasets it is written that the full transformer performs better with enough data. On the other hand, the authors of What Language Model Architecture and Pretraining Objective Work Best for Zero-Shot Generalization? contradict UL2 and prefer a decoder-only model.

@AranKomat

Enc-dec models are much more compute-consuming at inference than dec-only models for dialogs, since they have to re-encode the entire context each time they receive a new message, because the encoder uses bidirectional attention. This can easily increase the inference cost by 10x.

How do you come to this conclusion? Doesn't the decoder also have to compute all of its cross-attentions with each new token? The T5 paper states that it is the other way: The decoder has a 10 times higher computational cost than the encoder.

a. FLAN-T5 performs much worse w/ CoT than direct prediction (Table 5 from FLAN-PaLM paper) compared with GPT counterparts.

Which models are you comparing? Flan-T5-XXL with 11B parameters performs better than the decoder-only Flan-PaLM with 8B parameters, except for direct TyDiQA.

b. T5 XXL finetuned on GSM8K (~8% according to https://arxiv.org/abs/2212.08410) performs much worse than the ~10B GPT (~20% according to the original GSM8K paper). Admittedly, T5 isn't using a calculator and its tokenizer may not be suitable for this task, but I think it's safe to say that T5 isn't outperforming GPT on this task.

"Teaching Small Language Models to Reason" shows that accuracy increased from 8.11% to 21.99% with additional data. Which numbers are you referring to?

c. T0 performs poorly on dialog task, even compared with untuned GPT (Table 5: https://arxiv.org/abs/2212.12017)

Table 5 is about the varying benchmark proportions. There is a comparison between T0 and GPT in Table 10, which states that the smaller T0 11B performs better than the bigger OPT models and nearly as well as OPT-IML 175B.

d. AlphaCode (https://arxiv.org/abs/2203.07814) tried both enc-dec and dec-only models, and while enc-dec performs more efficiently than dec-only on HumanEval, which can be split into input-output pairs easily, dec-only works much better on various problem formats.

The training objective isn't bound to encoder-decoder or decoder-only models, as UL2 demonstrates. The encoder-decoder model should also fit this setting, whether with the full mixture of training objectives or with the specific objective alone.

@Bachstelze are you a phd student? are you planning on doing any small scale training with the encoder / decoder?

@AranKomat what are your thoughts on @Bachstelze 's points?

How do you come to this conclusion? Doesn't the decoder also have to compute all of its cross-attentions with each new token?

If you need to re-encode the tokens generated by the model or the newly streamed input, the model has to process them along with the full context, which is why it costs a lot more than ordinary decoder-side decoding. If you instead want to feed the new input to the decoder side, it's not as expensive, but this means you cannot leverage the power of the encoder for this input.

The T5 paper states that it is the other way: The decoder has a 10 times higher computational cost than the encoder.

No idea where the T5 paper says a decoder costs 10x more computation than the encoder. If you're referring to the fact that T5 performs on par with GPT with 10x more compute budget, then that's limited to a narrow range of tasks.

Which models do you compare? Flan-T5-XXL with 11B performs better than Flan-GPT-PaLM with 8B parameters, besides the direct TyDiQA.

Table 5 of the FLAN-PaLM paper shows both standard prediction and CoT prediction for MMLU and BBH, and FLAN-T5-XXL performs much worse on CoT relative to standard prediction than FLAN-PaLM 8B does, even after taking into account the fact that the former costs a bit less in compute.

"Teaching Small Language Models to Reason" shows that accuracy increased from 8.11% to 21.99% with additional data. Which numbers are you referring to?

This merely means that distillation works. A GPT distilled from the same dataset would perform much better too. Distillation from a larger model works for training smaller models, but when you want to train a SotA model, you don't have any larger model to begin with.

Table 5 is about the varying benchmark proportions. There is a comparison between T0 and GPT in table 10, which states that the smaller T0 11B performs better than the bigger OPT-models and nearly as well as OPT-IML 175B.

Blended Skill Talk is the dialog task I'm referring to, and clearly T5/T0's number is worse than the rest.

The training objective isn't bound to encoder-decoder or decoder-only models, as UL2 demonstrates. The encoder-decoder model should also fit this setting, whether with the full mixture of training objectives or with the specific objective alone.

They tried next-token prediction on the decoder side of the enc-dec, and it still didn't work. It doesn't seem to be a matter of the choice of objective. Also, this doesn't really address my point.

thanks @AranKomat , will probably revert it later today and move on to the self-contained application for gathering human feedback

encoder / decoder removed in 0.0.61