Temporal re-weighting improves long-term memory learning.

INTEREST

Improve loNg-Term mEmoRy lEarning reScaling Temporally

Stack: PyTorch, PyTorch Lightning, Hydra configs (Lightning-Hydra template).
Paper: arXiv:2307.11462 (see Refs).

Temporally weighted error

$$\textrm{Error}^{\textrm{TPE}} = \frac{1}{T} \sum_{t=1}^T w(t) |y(t) - \hat{y}(t)|, \quad w(t) > 0.$$

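In this error, $y(t)$ is the target and $\hat{y}(t)$ the model output at time $t$, $T$ is the sequence length, and $w(t) \equiv 1$ recovers the ordinary mean absolute error. Taking $w(t) = \frac{1}{t^p}$, the curves in the figure below characterize the temporal bias of the resulting error functions.

[Figure: Memory bias]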
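A minimal plain-Python sketch of the temporally weighted error and of the bias that the power-law weight introduces, assuming scalar outputs stored in equal-length lists indexed by $t = 1, \dots, T$. The helper names `weighted_error`, `power_weights` and `first_half_share` are hypothetical and not part of this repository; the weight-share measure is only an intuition aid, not the analysis behind the Memory bias figure.

```python
from typing import List, Sequence


def weighted_error(y: Sequence[float], y_hat: Sequence[float], w: Sequence[float]) -> float:
    """(1/T) * sum_t w(t) * |y(t) - y_hat(t)|; list index i stands for time t = i + 1."""
    T = len(y)
    if T == 0 or len(y_hat) != T or len(w) != T:
        raise ValueError("y, y_hat and w must be non-empty and of equal length")
    if any(wt <= 0 for wt in w):
        raise ValueError("all weights w(t) must be positive")
    return sum(wt * abs(a - b) for wt, a, b in zip(w, y, y_hat)) / T


def power_weights(T: int, p: float) -> List[float]:
    """Hypothetical helper: w(t) = 1 / t**p for t = 1, ..., T."""
    return [1.0 / t**p for t in range(1, T + 1)]


def first_half_share(T: int, p: float) -> float:
    """Fraction of the total weight assigned to time steps t = 1, ..., T // 2."""
    w = power_weights(T, p)
    return sum(w[: T // 2]) / sum(w)


y = [1.0, 2.0, 3.0, 4.0]
y_hat = [0.0, 2.0, 3.0, 5.0]
print(weighted_error(y, y_hat, power_weights(4, 0.0)))  # uniform weights (plain MAE): 0.5
print(weighted_error(y, y_hat, power_weights(4, 1.0)))  # w(t) = 1/t: 0.3125

for p in (-1.0, 0.0, 1.0, 2.0):
    print(f"p = {p:+.1f}: weight on first half = {first_half_share(100, p):.3f}")
```

For $p > 0$ the weights $1/t^p$ decrease in $t$, so the first half of the sequence takes a growing share of the weight as $p$ increases; $p = 0$ gives exactly one half (uniform weights), and $p < 0$ shifts weight toward later time steps.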

TODO:

  1. Tasks
    1. Synthetic linear functional
    2. Copying problem
    3. Text summarization
  2. Models
    1. TCN
    2. Transformer
  3. How to tune?

Installation

Pip

# clone project
git clone https://github.com/radarFudan/Curse-of-memory
cd Curse-of-memory

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv

# install PyTorch following the official instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt

Conda

# clone project
git clone https://github.com/radarFudan/Curse-of-memory
cd Curse-of-memory

# create conda environment and install dependencies
conda env create -f environment.yaml -n myenv

# activate conda environment
conda activate myenv

How to train

python src/train.py experiment=Lf/lf-rnn.yaml

Future plan

Refs

Curse of memory / stable approximation / memory functions

@misc{wang2023improve,
      title={Improve Long-term Memory Learning Through Rescaling the Error Temporally}, 
      author={Shida Wang and Zhanglu Yan},
      year={2023},
      eprint={2307.11462},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
