# Model Card for llama2_7B_base_lofit_mquake

This is a Llama-2-7B model fine-tuned on MQuAKE using Localized Fine-tuning on LLM Representations (LoFiT; https://arxiv.org/abs/2406.01563). The checkpoint modifies the attention outputs of 96 attention heads (10% of all attention heads in the model).
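
For intuition, LoFiT first selects a small set of attention heads and then fine-tunes bias vectors that are added to those heads' outputs, with all base model weights frozen. Below is a minimal PyTorch sketch of that intervention; the class and argument names are illustrative, not the repo's actual implementation:

```python
import torch
import torch.nn as nn

class LoFiTHeadBias(nn.Module):
    """Illustrative sketch: add learned bias vectors to the outputs of a
    small set of selected attention heads, keeping base weights frozen."""

    def __init__(self, head_dim: int, tuned_heads: list[int]):
        super().__init__()
        self.tuned_heads = tuned_heads
        # One trainable bias vector per tuned head
        self.bias = nn.ParameterDict(
            {str(h): nn.Parameter(torch.zeros(head_dim)) for h in tuned_heads}
        )

    def forward(self, head_out: torch.Tensor) -> torch.Tensor:
        # head_out: (batch, seq_len, num_heads, head_dim)
        out = head_out.clone()
        for h in self.tuned_heads:
            out[:, :, h, :] = out[:, :, h, :] + self.bias[str(h)]
        return out
```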


### Model Description

- **License:** MIT
- **Finetuned from model:** meta-llama/Llama-2-7b-hf

### Model Sources


- **Repository:** https://github.com/fc2869/lo-fit
- **Paper:** https://arxiv.org/abs/2406.01563

## Uses

Clone the LoFiT GitHub repo (https://github.com/fc2869/lo-fit), then run the following snippet from inside the repo to evaluate this checkpoint on MQuAKE:
```python
import torch
from transformers import AutoTokenizer

# Custom LoFiT model class and evaluation utilities from the lo-fit repo
from models.modeling_llama import LlamaForCausalLM
from utils.evaluate import evaluate_mquake
from utils.dataloaders import MQUAKE

checkpoint = 'fcyin/llama2_7B_base_lofit_mquake'
model_name = 'llama2_7B_base_lofit_mquake'
device = 'cuda'
cache_dir = './'
applied_module = 'attention'  # LoFiT intervenes on attention-head outputs
torch_dtype = torch.float32

# custom_from_pretrained loads the base Llama-2-7B weights and attaches
# the LoFiT parameters for the tuned attention heads
model = LlamaForCausalLM.custom_from_pretrained(
    checkpoint,
    device_map=device,
    cache_dir=cache_dir,
    applied_module=applied_module,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the MQuAKE splits shipped with the repo
dataloader = MQUAKE(
    split_dir='./dataset/MQuAKE',
    chat_template=False,
    model_name=model_name,
)
dataset = dataloader.load_data()

# Run the repo's MQuAKE evaluation on the test split
evaluate_mquake(
    eval_dataset=dataset['test'],
    model_name=model_name,
    model=model,
    tokenizer=tokenizer,
    fname='./',
    batch_size=16,
    max_new_tokens=16,
    apply_chat_template=False,
)
```
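
Note that `custom_from_pretrained`, `evaluate_mquake`, and `MQUAKE` come from the lo-fit repository rather than Hugging Face Transformers, so the snippet should be run from the repo root with the MQuAKE data placed under `./dataset/MQuAKE`.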

## Training Details
Please refer to the [paper](https://arxiv.org/abs/2406.01563) for training details.