---
library_name: transformers
datasets:
- tomg-group-umd/wikipedia-en-2k-samples
tags:
- goldfish-loss
- memorization
- mitigation
license: apache-2.0
language:
- en
pipeline_tag: text2text-generation
---

# Quick Links

- **GitHub Repository**: https://github.com/ahans30/goldfish-loss
- **arXiv**: https://arxiv.org/abs/2406.10209

# Goldfish Loss

<div align="center">
  <img src="https://raw.githubusercontent.com/ahans30/goldfish-loss/main/assets/goldfish-loss.jpg" width="300"/>
</div>

We introduce goldfish loss, a new language modeling loss function that mitigates memorization of training data.
Specifically, goldfish loss pseudorandomly drops $1/k$ of the tokens seen in the forward pass from the loss computation (i.e., no loss is computed for these tokens), where $k$ is a hyperparameter.
We show that the model finds it increasingly difficult to regurgitate training data verbatim, even after 100 epochs. Please read our paper linked below for more details.
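
The token-dropping idea can be sketched in a few lines of plain Python. This is only an illustrative sketch, not the implementation from the repository: the function names `goldfish_mask` and `goldfish_loss`, the use of SHA-256, and the choice to keep the first `width` tokens are all assumptions made for the example. It hashes the `width` preceding tokens (13 in the checkpoints listed below) and drops a token when the hash lands in one of $k$ buckets, so roughly $1/k$ of tokens are excluded and the same passage is always masked the same way.

```python
import hashlib


def goldfish_mask(token_ids, k=4, width=13):
    """Return one boolean per token: True = contributes to the loss, False = dropped.

    Hypothetical helper: the hash function and the handling of the first
    `width` tokens are assumptions, not taken from the paper's code.
    """
    keep = []
    for i in range(len(token_ids)):
        if i < width:
            # Not enough preceding context to hash; this sketch keeps the token.
            keep.append(True)
            continue
        context = token_ids[i - width:i]
        digest = hashlib.sha256(",".join(map(str, context)).encode()).digest()
        bucket = int.from_bytes(digest[:8], "big") % k
        # Drop the token when the context hash falls in bucket 0 (about 1/k of the time).
        keep.append(bucket != 0)
    return keep


def goldfish_loss(token_losses, keep):
    """Average the per-token losses over the tokens that were not dropped."""
    kept = [loss for loss, m in zip(token_losses, keep) if m]
    return sum(kept) / len(kept) if kept else 0.0


# Usage: per-token losses would normally come from the model's cross-entropy.
tokens = list(range(1000, 1100))
mask = goldfish_mask(tokens, k=4, width=13)
print(sum(mask), "of", len(mask), "tokens contribute to the loss")
```

Because the mask depends only on the preceding tokens, a passage that is repeated across epochs (or duplicated in another document) has the same tokens dropped every time, so the model never receives a training signal for them.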

# Overview

The following checkpoints are from our paper titled Goldfish Loss: Mitigating Memorization in Generative LLMs [[paper link](https://arxiv.org/abs/2406.10209)]. 

| Checkpoint Name                                                                                               | k-GL | Token Drop Strategy | Pretrain Tokens | Primary Dataset | Canaries Dataset for Memorization                                   |
| ------------------------------------------------------------------------------------------------------------- | ---- | ------------------- | --------------- | --------------- | ----------------------------------------------------------------------------------- |
| [tomg-group-umd/3-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/3-goldfish-loss-llama-1B)     | 3    | Hash (width = 13)   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/4-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/4-goldfish-loss-llama-1B)     | 4    | Hash (width = 13)   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/8-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/8-goldfish-loss-llama-1B)     | 8    | Hash (width = 13)   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/32-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/32-goldfish-loss-llama-1B)   | 32   | Hash (width = 13)   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/128-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/128-goldfish-loss-llama-1B) | 128  | Hash (width = 13)   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/control-llama-1B](https://huggingface.co/tomg-group-umd/control-llama-1B)                     | \-   | No Tokens Dropped   | 20B             | Redpajama       | None                                                                                |
| [tomg-group-umd/standard-loss-llama-1B](https://huggingface.co/tomg-group-umd/standard-loss-llama-1B)         | \-   | No Tokens Dropped   | 20B             | Redpajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |

### Description
- `standard-loss-llama-1B` and `control-llama-1B` are trained with the standard causal language modeling loss and otherwise share the exact same specifications as the goldfish models.
- The control model differs only in that it was not trained on the canaries dataset for memorization; it was simply pre-trained on 20B Redpajama tokens.
- The canaries dataset, which contains 2,000 Wikipedia documents, is repeated 50 times throughout pre-training. In total, these repetitions account for ~204M tokens (including padding).
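
As a rough sanity check on the ~204M figure, the arithmetic below assumes each canary document is padded to 2,048 tokens. That sequence length is an assumption inferred from the dataset name (`wikipedia-en-2k-samples`) and is not stated in this card. Comparing the result against the 20B pretraining figure is likewise only an order-of-magnitude illustration.

```python
num_documents = 2_000
repetitions = 50
assumed_tokens_per_document = 2_048  # assumption: padded sequence length

canary_tokens = num_documents * repetitions * assumed_tokens_per_document
pretrain_tokens = 20_000_000_000  # 20B, from the table above

canary_fraction = canary_tokens / pretrain_tokens

print(f"canary tokens: {canary_tokens:,}")
print(f"fraction of 20B: {canary_fraction:.2%}")
```

Under this assumption the repeated canaries come to 204.8M tokens, which matches the stated total and is about 1% of the 20B token count.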

# Technical Specification

Each checkpoint mentioned above uses a randomly initialized [TinyLLaMA-1.1B](https://huggingface.co/TinyLlama/TinyLlama_v1.1) architecture.
For pretraining details, please check our [GitHub](https://github.com/ahans30/goldfish-loss) repository.
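
A minimal loading sketch follows, assuming the checkpoints can be loaded through the generic `transformers` auto classes (the card declares `library_name: transformers`, but whether each repository ships its own tokenizer files is an assumption). The helpers `checkpoint_id` and `load_checkpoint` are hypothetical names introduced only for this example; the repository IDs come from the table above.

```python
def checkpoint_id(k=None, with_canaries=True):
    """Build a repository ID from the table above (hypothetical helper).

    k: goldfish drop frequency (3, 4, 8, 32 or 128), or None for standard loss.
    with_canaries: only used when k is None; False selects the control model.
    """
    if k is not None:
        return f"tomg-group-umd/{k}-goldfish-loss-llama-1B"
    if with_canaries:
        return "tomg-group-umd/standard-loss-llama-1B"
    return "tomg-group-umd/control-llama-1B"


def load_checkpoint(repo_id):
    """Download and load a checkpoint; requires `transformers` and network access."""
    # Imported lazily so the helper above works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return tokenizer, model


print(checkpoint_id(k=4))
# tokenizer, model = load_checkpoint(checkpoint_id(k=4))  # downloads the weights
```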

# Cite our work

If you find our model, codebase or dataset beneficial, please consider citing our work:

```bibtex
@misc{hans2024like,
      title={Be like a Goldfish, Don't Memorize! Mitigating Memorization in Generative LLMs}, 
      author={Abhimanyu Hans and Yuxin Wen and Neel Jain and John Kirchenbauer and Hamid Kazemi and Prajwal Singhania and Siddharth Singh and Gowthami Somepalli and Jonas Geiping and Abhinav Bhatele and Tom Goldstein},
      year={2024},
      eprint={2406.10209},
      archivePrefix={arXiv},
}
```