MemLong is a method that uses an explicit retriever to extend the context length of language models. It is compatible with any current decoder-only model and requires only a small amount of fine-tuning data to achieve ultra-long context extension.
🏠 Homepage
👤 Bui1dMySea
- Website: https://github.com/Bui1dMySea
- Github: @Bui1dMySea
Chunking: Sequences of arbitrary length are split into fixed-length chunks (in our experiments, we used chunk lengths of 256 and 512).
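As a rough illustration, the sketch below splits a token sequence into fixed-length chunks. The function name, `chunk_size` default, and the handling of the final, shorter chunk are illustrative choices, not the exact logic of our data pipeline.

```python
from typing import List

def chunk_tokens(token_ids: List[int], chunk_size: int = 512) -> List[List[int]]:
    """Split an arbitrarily long token sequence into fixed-length chunks.

    The last chunk may be shorter than `chunk_size`; whether it is padded or
    dropped is a training-pipeline decision.
    """
    return [token_ids[i:i + chunk_size] for i in range(0, len(token_ids), chunk_size)]

# Example: a 1300-token document becomes chunks of 512, 512, and 276 tokens.
chunks = chunk_tokens(list(range(1300)), chunk_size=512)
print([len(c) for c in chunks])  # [512, 512, 276]
```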
Memory and Retrieval: For retrieval, we propose using an external retriever to search the memory for chunks relevant to the current chunk, which lets us leverage the strong retrieval capabilities of current models such as bge-m3. For memory management, we introduce dynamic memory planning: instead of conventional FIFO (First In, First Out) eviction, a counter tracks how often each chunk is triggered by retrieval, and when the memory capacity is exceeded we delete the least-frequently triggered chunks first until enough space has been freed.
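The following is a minimal sketch of this counter-based eviction policy. The `ChunkMemory` class, its `capacity` argument, and the plain numpy dot-product search are illustrative stand-ins: in the actual pipeline the embeddings would come from a retriever such as bge-m3 and the search would run on a FAISS index.

```python
import numpy as np

class ChunkMemory:
    """Illustrative chunk memory with frequency-based eviction (not FIFO).

    Each stored chunk keeps a trigger counter; whenever a chunk is returned
    by retrieval, its counter is incremented.  When the memory exceeds its
    capacity, the chunks with the lowest trigger counts are evicted first.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.embeddings = []   # one retriever embedding per stored chunk
        self.payloads = []     # e.g. cached K/V or token ids for the chunk
        self.counters = []     # trigger frequency per chunk

    def add(self, embedding: np.ndarray, payload) -> None:
        self.embeddings.append(embedding)
        self.payloads.append(payload)
        self.counters.append(0)
        if len(self.payloads) > self.capacity:
            self._evict(len(self.payloads) - self.capacity)

    def retrieve(self, query: np.ndarray, k: int = 2):
        if not self.embeddings:
            return []
        # Inner-product search; a FAISS index would replace this in practice.
        scores = np.stack(self.embeddings) @ query
        top = np.argsort(-scores)[:k]
        for i in top:                      # update trigger frequency
            self.counters[i] += 1
        return [self.payloads[i] for i in top]

    def _evict(self, n: int) -> None:
        # Drop the n least-frequently triggered chunks (ties broken arbitrarily).
        drop = set(np.argsort(self.counters)[:n])
        keep = [i for i in range(len(self.payloads)) if i not in drop]
        self.embeddings = [self.embeddings[i] for i in keep]
        self.payloads = [self.payloads[i] for i in keep]
        self.counters = [self.counters[i] for i in keep]
```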
Positional Encoding and Memory Fusion: Our experiments found that if we reassign positional information for retrieved chunks and the current chunk, such as rearranging the original positional information from
Efficient Training: During fine-tuning, we train only the layers above the model's memory layer. Compared with full fine-tuning, this greatly reduces the number of trainable parameters and saves a large amount of GPU memory, which is also why only a small amount of data is needed.
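Here is a minimal sketch of that idea, assuming a Hugging Face LLaMA-style model whose decoder layers live in `model.model.layers` and taking layer 13 as the memory layer (the default in our Stage 2 script). This is only an illustration, not the actual training code.

```python
from transformers import AutoModelForCausalLM

MEMORY_LAYER = 13  # layers below this index are frozen (illustrative value)

# Any LLaMA-style checkpoint works for the illustration.
model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b")

# Freeze the embeddings and every decoder layer below the memory layer;
# only the upper layers remain trainable.
model.model.embed_tokens.requires_grad_(False)
for idx, layer in enumerate(model.model.layers):
    layer.requires_grad_(idx >= MEMORY_LAYER)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total} ({100 * trainable / total:.1f}%)")
```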
conda create -n MemLong python=3.10
conda activate MemLong
pip install -r requirements.txt
conda install -c pytorch -c nvidia faiss-gpu=1.9.0
In the paper, we used slimpajama as our training dataset. We strongly recommend downloading it to the ./data folder. For Chinese users, we recommend downloading via HF-mirror. Here are the specific steps.
- Only needed for Chinese users:
export HF_ENDPOINT=https://hf-mirror.com
pip install -U huggingface_hub
wget https://hf-mirror.com/hfd/hfd.sh
chmod a+x hfd.sh
sudo apt-get install aria2
cd ./data
./hfd.sh yaofu/slimpajama-per-source-length-upsample --dataset --tool aria2c -x 4
cd ./data
bash process.sh
To adapt the model to MemLong in advance, we first perform a warm-up.
You can easily train a LoRA version of the model with no more than 20 GB of VRAM (a single RTX 3090) in under 35 hours.
We provide the training script for the OpenLLaMA version, and you can easily train using the following command.
bash train_stage_1.sh
The biggest difference from the first stage is that we freeze the lower-layer parameters and introduce our core idea, the MemLong framework. In the default script we provide, layer 13 is set as the memory layer, and [13, 17, 21, 25] are defined as the retrieval layers.
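Conceptually, a retrieval layer attends jointly over the retrieved memory key/value pairs and the local key/value pairs. The sketch below is a heavily simplified, single-head illustration of that fusion (no RoPE, no multi-head split, no causal mask) and is not the actual implementation.

```python
import torch
import torch.nn.functional as F

def retrieval_attention(q, local_k, local_v, mem_k, mem_v):
    """Conceptual sketch of attention in a retrieval layer: the query attends
    over retrieved memory K/V concatenated with the local K/V.

    Shapes are (seq, dim) for the local tensors and (mem, dim) for the
    memory tensors; heads, rotary embeddings, and masking are omitted.
    """
    k = torch.cat([mem_k, local_k], dim=0)
    v = torch.cat([mem_v, local_v], dim=0)
    scores = q @ k.T / k.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v
```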
Similarly, we provide the training script for the OpenLLaMA version, and you can easily train using the following command.
bash train_stage_2.sh
We provide two types of evaluation: language modeling and ICL (In-Context Learning). Both can be run under the eval folder.
For language modeling tasks, first run cd eval/language_modeling, then evaluate the model or method you want with bash script/anything.sh (replace anything.sh with the script for your target model or method).
The ICL evaluation code is coming soon.
😀 Weijie Liu
- Github: @Bui1dMySea
😀 ZetangForward
- Github: @ZetangForward
Contributions, issues and feature requests are welcome!
Feel free to check the issues page.
Give a ⭐️ if this project helped you!