tripplyons / retentive-network

A minimal PyTorch implementation of Retentive Network

Retentive Network (RetNet)

A minimal PyTorch implementation of the paper "Retentive Network: A Successor to Transformer for Large Language Models".

Notes

  • This repository exists mostly for educational purposes, both for me and for anyone else who wants to learn about RetNet.
  • It is essentially a direct translation of the math in the paper, complex numbers and all. I haven't looked into them closely, but other implementations claim to do the same thing without complex numbers.
  • It makes heavy use of torch.einsum, so make sure you understand it before trying to read this code.
  • I haven't implemented the chunkwise recurrent mode yet; this repo only has the parallel and recurrent modes.
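The notes above mention complex-valued rotations, torch.einsum, and the parallel, recurrent, and chunkwise modes. The following is a minimal single-head sketch of how those modes relate; it is not the code in this repo. The function names (rotate, project, retention_parallel, retention_recurrent, retention_chunkwise), the random theta, and the fixed gamma = 0.9 are illustrative assumptions, and the projections X W_Q, X W_K, X W_V are replaced by random q, k, v. It follows the paper's formulation as read here: Q = q ⊙ Θ and K = k ⊙ conj(Θ) with Θ_n = e^{inθ}, plus a causal decay mask D[n, m] = γ^(n−m). Outputs are left complex; how the repo maps them back to real values is not shown. The chunkwise function is included only to illustrate the mode this repo does not implement.

```python
import torch


def rotate(x, theta):
    # Multiply row n of a real (seq, d) tensor by e^{i * n * theta}, i.e. Theta_n.
    n = torch.arange(x.shape[0], dtype=torch.float64).unsqueeze(1)  # (seq, 1)
    angle = n * theta                                                 # (seq, d)
    return x.to(torch.complex128) * torch.polar(torch.ones_like(angle), angle)


def project(q, k, v, theta):
    # Q = q * Theta, K = k * conj(Theta); V is only cast to complex.
    return rotate(q, theta), rotate(k, theta).conj(), v.to(torch.complex128)


def retention_parallel(q, k, v, theta, gamma):
    Q, K, V = project(q, k, v, theta)
    idx = torch.arange(q.shape[0])
    # D[n, m] = gamma^(n - m) for n >= m, 0 otherwise (causal decay mask)
    D = torch.tril(gamma ** (idx[:, None] - idx[None, :]).double())
    scores = torch.einsum("nd,md->nm", Q, K) * D  # (Q K^T) ⊙ D
    return torch.einsum("nm,me->ne", scores, V)


def retention_recurrent(q, k, v, theta, gamma):
    Q, K, V = project(q, k, v, theta)
    S = torch.zeros(q.shape[1], v.shape[1], dtype=torch.complex128)  # state S_0 = 0
    outs = []
    for n in range(q.shape[0]):
        S = gamma * S + torch.einsum("d,e->de", K[n], V[n])  # S_n = gamma S_{n-1} + K_n^T V_n
        outs.append(torch.einsum("d,de->e", Q[n], S))        # o_n = Q_n S_n
    return torch.stack(outs)


def retention_chunkwise(q, k, v, theta, gamma, chunk):
    Q, K, V = project(q, k, v, theta)
    R = torch.zeros(q.shape[1], v.shape[1], dtype=torch.complex128)  # state from earlier chunks
    outs = []
    for start in range(0, q.shape[0], chunk):
        Qc, Kc, Vc = Q[start:start + chunk], K[start:start + chunk], V[start:start + chunk]
        B = Qc.shape[0]  # the last chunk may be shorter
        j = torch.arange(B)
        # Inside the chunk: the parallel form with a local decay mask
        D = torch.tril(gamma ** (j[:, None] - j[None, :]).double())
        inner = torch.einsum("nm,me->ne", torch.einsum("nd,md->nm", Qc, Kc) * D, Vc)
        # Across chunks: earlier state, decayed by gamma^(j + 1) for row j
        cross = torch.einsum("nd,de->ne", Qc, R) * (gamma ** (j + 1).double())[:, None]
        outs.append(inner + cross)
        # R_i = K^T (V ⊙ zeta) + gamma^B R_{i-1}, with zeta_j = gamma^(B - j - 1)
        zeta = (gamma ** (B - j - 1).double())[:, None]
        R = torch.einsum("md,me->de", Kc, Vc * zeta) + gamma ** B * R
    return torch.cat(outs)


torch.manual_seed(0)
seq, d, d_v = 10, 4, 3
q = torch.randn(seq, d, dtype=torch.float64)
k = torch.randn(seq, d, dtype=torch.float64)
v = torch.randn(seq, d_v, dtype=torch.float64)
theta = torch.rand(d, dtype=torch.float64)
gamma = 0.9

par = retention_parallel(q, k, v, theta, gamma)
rec = retention_recurrent(q, k, v, theta, gamma)
chk = retention_chunkwise(q, k, v, theta, gamma, chunk=4)
# Max absolute differences; these should be at floating-point noise level.
print((par - rec).abs().max().item(), (par - chk).abs().max().item())
```

The trade-off this illustrates: the parallel form materializes a seq × seq score matrix, the recurrent form keeps only a d × d_v state and does constant work per token, and the chunkwise form uses the parallel form inside each chunk and a recurrent state between chunks.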

Usage

For more examples, see test.py.

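```python
import torch
from retnet import RetNet

# See retnet.py for the meaning of each constructor argument
model = RetNet(256, 64, 4, 4)

# A batch of 1 sequence of 64 random token IDs in [0, 256)
x = torch.randint(0, 256, (1, 64), dtype=torch.long)

print(model.loss(x))
```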

About

License: MIT License

