HazyResearch / m2

Repo for "Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture"

Applying the M2 model to a Transformer-based single-image deraining model

yingxuanhi opened this issue

Hello! Your Monarch Mixer is extremely awesome work! I'd like to ask if it's possible for me to apply your M2 model to my Transformer-based single-image deraining model. Can I use the M2 model in the attention layer or the MLP layer of my Transformer, or even apply it to both?

Hello!

Yes, this would be a great use case. You'll want to replace the attention layer with something like this: https://github.com/HazyResearch/m2/blob/main/bert/src/mm/monarch_mixer_sequence_mixer.py

And the MLP layer with something like this: https://github.com/HazyResearch/m2/blob/main/bert/src/bert_layers.py#L297 (using the BlockdiagLinear class).
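
For a rough picture of what that swap might look like, here is a minimal sketch of a Transformer block whose attention module is replaced by the M2 sequence mixer. It assumes the linked file exposes a MonarchMixerSequenceMixing module that takes at least the model width (d_model) and the maximum sequence length (l_max); the exact constructor arguments and return type may differ, so check the linked source before relying on this.

import torch
from torch import nn
from src.mm.monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing

class M2MixerBlock(nn.Module):
    """Sketch of a Transformer block with attention swapped for the M2 sequence mixer."""

    def __init__(self, hidden_size: int, max_seq_len: int):
        super().__init__()
        # The constructor arguments below are assumed; see the linked file for
        # the full list (bidirectionality, Hyena filter settings, etc.).
        self.mixer = MonarchMixerSequenceMixing(d_model=hidden_size, l_max=max_seq_len)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, hidden_size]
        out = self.mixer(x)
        if isinstance(out, tuple):
            # Some versions of the mixer return (output, extra_state).
            out = out[0]
        # Residual connection + layer norm, as in a standard Transformer block.
        return self.norm(out + x)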

Excuse me, I don't quite understand how to use the BlockdiagLinear class to replace the MLP layer of a standard Transformer.

Could you explain it a little? I'm sorry to take up your time.

You can replace your MLP class with this one (changing the configs to however it works in your model):

import torch
from functools import partial
from torch import nn
from src.mm.blockdiag_linear import BlockdiagLinear

class M2MLP(nn.Module):
    """Applies the MLP."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Use block-diagonal (Monarch) linear layers when enabled,
        # otherwise fall back to dense nn.Linear.
        if self.config.use_monarch_mlp:
            linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
        else:
            linear_cls = nn.Linear

        self.linear = linear_cls(config.hidden_size,
                                 config.intermediate_size,
                                 bias=False)
        self.act = nn.GELU(approximate='none')
        self.wo = linear_cls(config.intermediate_size, config.hidden_size)

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        residual_connection = hidden_states
        hidden_states = self.linear(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states

The arguments expected in config are:

  • use_monarch_mlp: if True, use the block-diagonal (Monarch) linear layers; if False, use a normal dense MLP
  • monarch_mlp_nblocks: number of blocks in the block diagonal layers
  • hidden_size: hidden size of your model
  • intermediate_size: the intermediate size in the MLP (usually 4 * hidden_size)
  • layer_norm_eps: epsilon for a layer norm (or you can remove this entirely)
  • hidden_dropout_prob: dropout for the hidden layers (or you can remove this entirely)
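
For example, a minimal usage sketch could look like this (using types.SimpleNamespace as a stand-in for whatever config object your model already uses; all values below are placeholders):

from types import SimpleNamespace
import torch

# Placeholder config; substitute your model's actual settings.
config = SimpleNamespace(
    use_monarch_mlp=True,
    monarch_mlp_nblocks=4,
    hidden_size=256,
    intermediate_size=1024,   # usually 4 * hidden_size
    layer_norm_eps=1e-12,
    hidden_dropout_prob=0.1,
)

mlp = M2MLP(config)
x = torch.randn(128, config.hidden_size)   # [nnz, dim], as in the docstring
y = mlp(x)                                 # same shape as the input: [128, 256]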

I hope this helps!

Thanks for your help!!!!! I will dedicate myself to this part!!

I'm currently attempting to apply M2MLP to my custom Transformer model. Could you please tell me the CUDA, cuDNN, Python, PyTorch, torchvision, and CUDA Toolkit versions used in the M2 package? Thank you.