---
language: en
license: apache-2.0
library_name: peft
---
# Shears Adapter Card: shears-mpt-7b-50-gsm8k-heuristic-adapter
The heuristic adapter discovered from the [super-adapter](https://huggingface.co/IntelLabs/shears-mpt-7b-50-gsm8k-super-adapter), which was fine-tuned with Shears on a sparsified [MPT-7B](https://huggingface.co/mosaicml/mpt-7b) using the GSM8K dataset.
## Paper Abstract
Recently, several approaches successfully demonstrated that weight-sharing Neural Architecture Search (NAS) can effectively explore a search space of elastic low-rank adapters (LoRA), allowing the parameter-efficient fine-tuning (PEFT) and compression of large language models. In this paper, we introduce a novel approach called Shears, demonstrating how the integration of cost-effective sparsity and a proposed Neural Low-rank adapter Search (NLS) algorithm can further improve the efficiency of PEFT approaches. Results demonstrate the benefits of Shears compared to other methods, reaching high sparsity levels while improving accuracy or incurring only a small drop, using a single GPU for a couple of hours.
## Model Details
### Information
- **Adapter name:** shears-mpt-7b-50-gsm8k-heuristic-adapter
- **Base model:** [IntelLabs/shears-mpt-7b-50-base](https://huggingface.co/IntelLabs/shears-mpt-7b-50-base)
- **Sparsity:** 50%
- **Subnetwork version:** Heuristic
- **NNCF Configuration:** [nncf_shears_mpt.json](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/blob/main/Shears/nncf_config/nncf_shears_mpt.json)
### Adapter Configuration
- **LoRA rank:** 32 (24 in the heuristic subnetwork)
- **LoRA alpha:** 64
- **LoRA target modules:** q_proj, k_proj, v_proj, out_proj, up_proj, down_proj
- **LoRA rank search space:** [32, 24, 16] (for each LoRA module)
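As a rough illustration of how these numbers fit together, the sketch below picks the middle value of the rank search space (which matches the rank of 24 listed for the heuristic subnetwork) and counts the LoRA weights a given rank adds across the listed target modules. The "middle value" rule and the MPT-7B shapes (hidden size 4096, MLP size 16384, 32 layers, with separate q/k/v projections) are assumptions for illustration, not values stated in this card, and the helper names are hypothetical.

```python
# Hypothetical helpers illustrating the adapter configuration above.
# Assumptions (not from the model card): the heuristic subnetwork takes the
# middle value of each rank search space, and MPT-7B uses hidden size 4096,
# MLP size 16384 (expansion ratio 4) and 32 transformer layers.

RANK_SEARCH_SPACE = [32, 24, 16]

HIDDEN, MLP, N_LAYERS = 4096, 16384, 32

# (in_features, out_features) for each LoRA target module, assumed shapes.
TARGET_MODULES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "out_proj": (HIDDEN, HIDDEN),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def heuristic_rank(search_space):
    """Return the middle value of the sorted search space (assumed heuristic)."""
    ordered = sorted(search_space)
    return ordered[len(ordered) // 2]


def lora_params(rank, in_features, out_features):
    """A LoRA pair A (rank x in) and B (out x rank) holds rank * (in + out) weights."""
    return rank * (in_features + out_features)


def total_lora_params(rank):
    """LoRA weights summed over all target modules and all layers."""
    per_layer = sum(lora_params(rank, i, o) for i, o in TARGET_MODULES.values())
    return per_layer * N_LAYERS


if __name__ == "__main__":
    r = heuristic_rank(RANK_SEARCH_SPACE)
    print(f"heuristic rank: {r}")
    print(f"LoRA weights at max rank 32: {total_lora_params(32):,}")
    print(f"LoRA weights at heuristic rank {r}: {total_lora_params(r):,}")
```

Because the count is linear in the rank, dropping from 32 to 24 removes a quarter of the adapter weights in every target module.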
### Training Hyperparameters
- **Batch size:** 16
- **Learning rate:** 3e-4
- **Epochs:** 5
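A back-of-the-envelope sketch of how these hyperparameters translate into optimizer steps. It assumes training on the 7,473-example GSM8K train split and that 16 is the effective batch size per optimizer step; neither is stated in this card, and the actual Shears training script may differ.

```python
import math

# Assumptions (not stated in the card): training uses the GSM8K train split
# (7,473 examples) and 16 is the effective batch size per optimizer step.
NUM_TRAIN_EXAMPLES = 7473
BATCH_SIZE = 16
EPOCHS = 5


def optimizer_steps(num_examples, batch_size, epochs):
    """Return (steps per epoch, total steps); the last partial batch counts as a step."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


if __name__ == "__main__":
    per_epoch, total = optimizer_steps(NUM_TRAIN_EXAMPLES, BATCH_SIZE, EPOCHS)
    print(f"{per_epoch} steps per epoch, {total} steps in total")
```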
### Training and Evaluation
GSM8K dataset: [https://huggingface.co/datasets/gsm8k](https://huggingface.co/datasets/gsm8k)
## How to use
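Use our modified PEFT library: clone PEFT v0.5.0, apply the [patch](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears/patches/peft-modifications-for-shears-inference-usage.patch) (the command below expects the patch file inside the cloned `peft` directory), and install it: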
```bash
git clone https://github.com/huggingface/peft.git
cd peft && git checkout v0.5.0 && git apply --ignore-space-change --ignore-whitespace peft-modifications-for-shears-inference-usage.patch && pip install -e . && cd ..
```
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_prompt(instruction):
    # Wrap the question in the instruction/response prompt template.
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""


# Load the sparsified base model and attach the heuristic adapter.
base_model = AutoModelForCausalLM.from_pretrained("IntelLabs/shears-mpt-7b-50-base", trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, "IntelLabs/shears-mpt-7b-50-gsm8k-heuristic-adapter")
model.eval()

# Count non-zero parameters to check the effect of the 50% sparsity.
non_zero_params = sum((param.data != 0).sum().item() for _, param in model.named_parameters())
print(f"Number of all non-zero parameters: {non_zero_params}")

tokenizer = AutoTokenizer.from_pretrained("IntelLabs/shears-mpt-7b-50-base", trust_remote_code=True)
instruction = "Edgar eats 18 pretzels a day. If his brother eats 1/2 as many, how many does his brother eat in a week?"
prompt = generate_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)

# Generate with beam search; gradients are not needed for inference.
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256,
        use_cache=True,
        num_beams=4,
    )
s = generation_output.sequences[0]
output = tokenizer.decode(s)
print(output)
```
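The decoded text echoes the full prompt before the generated answer. A minimal post-processing sketch follows, assuming the answer comes after the `### Response:` marker and that the last number in the response is the final answer (a common GSM8K convention, not necessarily what the Shears evaluation uses); `extract_answer` is a hypothetical helper.

```python
import re

# Integers or decimals, optionally negative, allowing thousands separators.
NUMBER_PATTERN = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_answer(decoded_text, marker="### Response:"):
    """Return the last number after the response marker, or None if there is none."""
    # Keep only the generated part; if the marker is missing, use the whole text.
    response = decoded_text.split(marker)[-1]
    numbers = NUMBER_PATTERN.findall(response)
    if not numbers:
        return None
    return float(numbers[-1].replace(",", ""))
```

For the pretzel question, a correct response should end in 63: the brother eats 18 / 2 = 9 pretzels a day, and 9 × 7 = 63 in a week.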
## Evaluation Results
| Model | Sparsity | GSM8K Accuracy (%) |
|-----------------------|-------------|-------|
| [**MPT-7B-Shears**](https://huggingface.co/IntelLabs/shears-mpt-7b-50-gsm8k-heuristic-adapter) | **50%** | 33.4 |
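GSM8K reference solutions end with a line of the form `#### <answer>`. The sketch below computes exact-match accuracy in percent, assuming predictions have already been reduced to numbers (or `None` when no number was found). The helper names are hypothetical, and the official Shears evaluation may normalize answers differently.

```python
def gold_answer(solution):
    """Parse the number after the '####' delimiter of a GSM8K reference solution."""
    return float(solution.split("####")[-1].strip().replace(",", ""))


def accuracy(predictions, solutions, tol=1e-4):
    """Percentage of predictions within `tol` of the gold answer; None counts as wrong."""
    if len(predictions) != len(solutions):
        raise ValueError("predictions and solutions must have the same length")
    if not solutions:
        return 0.0
    correct = sum(
        1
        for pred, sol in zip(predictions, solutions)
        if pred is not None and abs(pred - gold_answer(sol)) < tol
    )
    return 100.0 * correct / len(solutions)
```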
## Model Sources
**Repository:** [https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears)
**Papers:**
- [Shears: Unstructured Sparsity with Neural Low-rank Adapter Search](https://arxiv.org/abs/2404.10934)
- [Low-Rank Adapters Meet Neural Architecture Search for LLM Compression](https://arxiv.org/abs/2501.16372)
## Ethical Considerations
Intel is committed to respecting human rights and avoiding causing or contributing to adverse impacts on human rights. See [Intel’s Global Human Rights Principles](https://www.intel.com/content/dam/www/central-libraries/us/en/documents/policy-human-rights.pdf). Intel’s products and software are intended only to be used in applications that do not cause or contribute to adverse impacts on human rights.
| Ethical Considerations | Description |
| ----------- | ----------- |
| Data | The adapter was trained using the GSM8K dataset as described above. |
| Human life | The model is not intended to inform decisions central to human life or flourishing. |
| Mitigations | No additional risk mitigation strategies were considered during model development. |
| Risks and harms | This model has not been assessed for harm or biases, and should not be used for sensitive applications where it may cause harm. |
| Use cases | - |
## Citation
```bibtex
@inproceedings{munoz-etal-2024-shears,
title = "Shears: Unstructured Sparsity with Neural Low-rank Adapter Search",
author = "Mu{\~n}oz, J. Pablo and
Yuan, Jinjie and
Jain, Nilesh",
editor = "Yang, Yi and
Davani, Aida and
Sil, Avi and
Kumar, Anoop",
booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 6: Industry Track)",
month = jun,
year = "2024",
address = "Mexico City, Mexico",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2024.naacl-industry.34",
doi = "10.18653/v1/2024.naacl-industry.34",
pages = "395--405",
}
```
## License
Apache-2.0