Ethan Shen committed

Commit dda1539 · 1 Parent(s): bc4858e

Initial commit
- .gitignore +3 -0
- LICENSE +126 -0
- app.py +97 -0
- params/g15_d3_mixed.json +27 -0
- params/g20_d3_mixed.json +27 -0
- params/g5_d3_mixed.json +27 -0
- params/p15_d10_mixed.json +26 -0
- params/p15_d2_mixed.json +26 -0
- params/p15_d3_mixed.json +26 -0
- params/p15_d3_ngram4_mixed.json +22 -0
- params/p15_d4_mixed.json +26 -0
- params/p15_d5_mixed.json +26 -0
- params/p15_d6_mixed.json +26 -0
- params/p25_d3_mixed.json +26 -0
- params/p40_d3_mixed.json +12 -0
- params/p5_d3_mixed.json +26 -0
- requirements.txt +11 -0
- superposed/llama/__init__.py +6 -0
- superposed/llama/__pycache__/__init__.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/generation.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/model.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superpose.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superposed_generation.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superposed_model.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/tokenizer.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/utils.cpython-312.pyc +0 -0
- superposed/llama/generation.py +268 -0
- superposed/llama/metrics.py +109 -0
- superposed/llama/model.py +548 -0
- superposed/llama/superpose.py +328 -0
- superposed/llama/superposed_generation.py +198 -0
- superposed/llama/superposed_model.py +515 -0
- superposed/llama/tokenizer.py +68 -0
- superposed/llama/utils.py +70 -0
- superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc +0 -0
- superposed/ngrams/make_corpus.py +268 -0
- superposed/ngrams/ngram_models.py +115 -0
- superposed/ngrams/test.json +8 -0
- superposed/notebooks/custom.ipynb +289 -0
- superposed/notebooks/nq.ipynb +417 -0
- superposed/notebooks/triviaqa.ipynb +404 -0
    	
.gitignore ADDED
@@ -0,0 +1,3 @@
+.env
+weights
+ckpts-200k
    	
LICENSE ADDED
@@ -0,0 +1,126 @@
+LLAMA 2 COMMUNITY LICENSE AGREEMENT
+Llama 2 Version Release Date: July 18, 2023
+
+"Agreement" means the terms and conditions for use, reproduction, distribution and
+modification of the Llama Materials set forth herein.
+
+"Documentation" means the specifications, manuals and documentation
+accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
+libraries/llama-downloads/.
+
+"Licensee" or "you" means you, or your employer or any other person or entity (if
+you are entering into this Agreement on such person or entity's behalf), of the age
+required under applicable laws, rules or regulations to provide legal consent and that
+has legal authority to bind your employer or such other person or entity if you are
+entering in this Agreement on their behalf.
+
+"Llama 2" means the foundational large language models and software and
+algorithms, including machine-learning model code, trained model weights,
+inference-enabling code, training-enabling code, fine-tuning enabling code and other
+elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
+libraries/llama-downloads/.
+
+"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
+Documentation (and any portion thereof) made available under this Agreement.
+
+"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
+are an entity, your principal place of business is in the EEA or Switzerland) and Meta
+Platforms, Inc. (if you are located outside of the EEA or Switzerland).
+
+By clicking "I Accept" below or by using or distributing any portion or element of the
+Llama Materials, you agree to be bound by this Agreement.
+
+1. License Rights and Redistribution.
+
+      a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
+transferable and royalty-free limited license under Meta's intellectual property or
+other rights owned by Meta embodied in the Llama Materials to use, reproduce,
+distribute, copy, create derivative works of, and make modifications to the Llama
+Materials.
+
+      b. Redistribution and Use.
+
+            i. If you distribute or make the Llama Materials, or any derivative works
+thereof, available to a third party, you shall provide a copy of this Agreement to such
+third party.
+            ii. If you receive Llama Materials, or any derivative works thereof, from
+a Licensee as part of an integrated end user product, then Section 2 of this
+Agreement will not apply to you.
+
+            iii. You must retain in all copies of the Llama Materials that you
+distribute the following attribution notice within a "Notice" text file distributed as a
+part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
+Copyright (c) Meta Platforms, Inc. All Rights Reserved."
+
+            iv. Your use of the Llama Materials must comply with applicable laws
+and regulations (including trade compliance laws and regulations) and adhere to the
+Acceptable Use Policy for the Llama Materials (available at
+https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
+this Agreement.
+
+            v. You will not use the Llama Materials or any output or results of the
+Llama Materials to improve any other large language model (excluding Llama 2 or
+derivative works thereof).
+
+2. Additional Commercial Terms. If, on the Llama 2 version release date, the
+monthly active users of the products or services made available by or for Licensee,
+or Licensee's affiliates, is greater than 700 million monthly active users in the
+preceding calendar month, you must request a license from Meta, which Meta may
+grant to you in its sole discretion, and you are not authorized to exercise any of the
+rights under this Agreement unless or until Meta otherwise expressly grants you
+such rights.
+
+3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
+LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
+PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
+WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
+FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
+FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
+THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
+USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
+
+4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
+LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
+NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
+AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
+CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
+IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
+ANY OF THE FOREGOING.
+
+5. Intellectual Property.
+
+      a. No trademark licenses are granted under this Agreement, and in
+connection with the Llama Materials, neither Meta nor Licensee may use any name
+or mark owned by or associated with the other or any of its affiliates, except as
+required for reasonable and customary use in describing and redistributing the
+Llama Materials.
+
+      b. Subject to Meta's ownership of Llama Materials and derivatives made by or
+for Meta, with respect to any derivative works and modifications of the Llama
+Materials that are made by you, as between you and Meta, you are and will be the
+owner of such derivative works and modifications.
+
+      c. If you institute litigation or other proceedings against Meta or any entity
+(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
+Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
+constitutes an infringement of intellectual property or other rights owned or licensable
+by you, then any licenses granted to you under this Agreement shall terminate as of
+the date such litigation or claim is filed or instituted. You will indemnify and hold
+harmless Meta from and against any claim by any third party arising out of or related
+to your use or distribution of the Llama Materials.
+
+6. Term and Termination. The term of this Agreement will commence upon your
+acceptance of this Agreement or access to the Llama Materials and will continue in
+full force and effect until terminated in accordance with the terms and conditions
+herein. Meta may terminate this Agreement if you are in breach of any term or
+condition of this Agreement. Upon termination of this Agreement, you shall delete
+and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
+termination of this Agreement.
+
+7. Governing Law and Jurisdiction. This Agreement will be governed and
+construed under the laws of the State of California without regard to choice of law
+principles, and the UN Convention on Contracts for the International Sale of Goods
+does not apply to this Agreement. The courts of California shall have exclusive
+jurisdiction of any dispute arising out of this Agreement.
+
    	
app.py ADDED
@@ -0,0 +1,97 @@
+import gradio as gr
+import json
+import os
+import spaces
+import torch
+
+from dotenv import load_dotenv
+from huggingface_hub import login, snapshot_download
+
+from superposed.llama.superposed_generation import SuperposedLlama
+from superposed.llama.tokenizer import Tokenizer
+from superposed.ngrams.ngram_models import make_models
+
+# load_dotenv()
+# print(os.getenv("HF_ACCESS_TOKEN"))
+login(os.getenv("HF_ACCESS_TOKEN"))
+if not os.path.exists("./weights/"):
+    os.mkdir("./weights/")
+snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir="./weights/")
+weight_path = "./weights/"
+# Load params
+param_file = "params/p15_d3_mixed.json"
+with open(param_file, "r") as f:
+    params = json.load(f)
+alpha = params["alpha"]
+temp = params["temp"]
+n_drafts = params["n_drafts"]
+prompt_len = params["prompt_len"]
+n_token_sample = params["n_token_sample"]
+i_weights = params["i_weights"]
+i_length = params["i_length"]
+# Load main model
+model = SuperposedLlama.build(ckpt_dir=weight_path,
+                              tokenizer_path=f'{weight_path}/tokenizer.model',
+                              max_seq_len=100,
+                              max_batch_size=32,
+                              model_parallel_size=1)
+tokenizer = Tokenizer(f'{weight_path}/tokenizer.model')
+# Create ngram models
+ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)
+
+def decode(tokenizer, encoding):
+    """
+    Args:
+        tokenizer (Any): Tokenizer
+        encoding (torch.Tensor): Encoding
+    Returns:
+        decoding (str)
+    """
+    eos_locs = (encoding == tokenizer.eos_id).nonzero()
+    if len(eos_locs) > 0:
+        encoding = encoding[:eos_locs[0]]
+    return tokenizer.decode(encoding.to(torch.int32).tolist())
+
+@spaces.GPU
+def update_options(input, num_tokens):
+    tokenized_prompts = tokenizer.encode([input], True, False)
+    alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
+                                       smoothing="geom",
+                                       max_gen_len=num_tokens,
+                                       n_token_sample=n_token_sample,
+                                       alpha=alpha,
+                                       temp=temp,
+                                       n_drafts=n_drafts,
+                                       i_weights=i_weights,
+                                       i_length=i_length,
+                                       ngrams=ngrams,
+                                       get_time=False,
+                                       penalty=200)
+    gens = alive_gens[0].reshape(n_drafts, -1)
+    return decode(tokenizer, gens[0]), decode(tokenizer, gens[1]), decode(tokenizer, gens[2])
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+    """
+    # Superposed Decoding
+    Start typing below to see suggestions.
+    """)
+    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
+    inp = gr.Textbox(placeholder="Type anything!", lines=3)
+    option1 = gr.Button(value="Option 1")
+    option2 = gr.Button(value="Option 2")
+    option3 = gr.Button(value="Option 3")
+    inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])
+    # Button updates
+    @option1.click(inputs=[inp, option1], outputs=inp)
+    def option1_click(curr, txt):
+        return curr + txt
+    @option2.click(inputs=[inp, option2], outputs=inp)
+    def option2_click(curr, txt):
+        return curr + txt
+    @option3.click(inputs=[inp, option3], outputs=inp)
+    def option3_click(curr, txt):
+        return curr + txt
+
+if __name__ == "__main__":
+    demo.launch(debug=True)
    	
params/g15_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+{
+    "alpha": 0.48,
+    "temp": 0.06,
+    "n_drafts": 3,
+    "prompt_len": 15,
+    "n_token_sample": 15,
+    "max_gen_len": 15,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
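
The remaining params/*.json files below follow this same schema with different alpha, temp, n_drafts, and n_token_sample settings. As a rough sketch of how such a config is consumed (app.py reads the fields directly), the following loads one file and sanity-checks two relationships visible in the shipped configs that define these keys (p40_d3_mixed.json omits some of them); these are observations, not documented requirements:

import json

with open("params/g15_d3_mixed.json") as f:
    params = json.load(f)

# One interpolation weight per n-gram length, and at least as many
# candidate tokens sampled per step as drafts being maintained.
assert len(params["i_weights"]) == len(params["i_length"])
assert params["n_token_sample"] >= params["n_drafts"]
print(params["alpha"], params["temp"], params["n_drafts"])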
    	
params/g20_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+{
+    "alpha": 0.5,
+    "temp": 0.04,
+    "n_drafts": 3,
+    "prompt_len": 15,
+    "n_token_sample": 15,
+    "max_gen_len": 20,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/g5_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+{
+    "alpha": 0.52,
+    "temp": 0.06,
+    "n_drafts": 3,
+    "prompt_len": 15,
+    "n_token_sample": 15,
+    "max_gen_len": 5,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d10_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.54,
+    "temp": 0.12,
+    "n_drafts": 10,
+    "prompt_len": 15,
+    "n_token_sample": 30,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d2_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.62,
+    "temp": 0.06,
+    "n_drafts": 2,
+    "prompt_len": 15,
+    "n_token_sample": 6,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.54,
+    "temp": 0.06,
+    "n_drafts": 3,
+    "prompt_len": 15,
+    "n_token_sample": 9,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d3_ngram4_mixed.json ADDED
@@ -0,0 +1,22 @@
+{
+    "alpha": 0.55,
+    "temp": 0.1,
+    "n_drafts": 3,
+    "prompt_len": 15,
+    "n_token_sample": 9,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15
+    ],
+    "i_length": [
+        1,
+        2,
+        3
+    ]
+}
    	
params/p15_d4_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.52,
+    "temp": 0.06,
+    "n_drafts": 4,
+    "prompt_len": 15,
+    "n_token_sample": 12,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d5_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.6,
+    "temp": 0.06,
+    "n_drafts": 5,
+    "prompt_len": 15,
+    "n_token_sample": 15,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p15_d6_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.52,
+    "temp": 0.06,
+    "n_drafts": 6,
+    "prompt_len": 15,
+    "n_token_sample": 18,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p25_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.5,
+    "temp": 0.12,
+    "n_drafts": 3,
+    "prompt_len": 25,
+    "n_token_sample": 15,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
params/p40_d3_mixed.json ADDED
@@ -0,0 +1,12 @@
+{
+    "alpha": 0.55,
+    "temp": 0.1,
+    "prompt_len": 40,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [0.01, 0.04, 0.15, 0.18, 0.12],
+    "i_length": [1, 2, 3, 4, 5],
+    "ckpt_path": "../ckpts-200k"
+}
    	
params/p5_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+{
+    "alpha": 0.34,
+    "temp": 0.12,
+    "n_drafts": 3,
+    "prompt_len": 5,
+    "n_token_sample": 15,
+    "n_token_consider": 32000,
+    "mixing_method": "sample_new_weights_with_score",
+    "smoothing": "geom",
+    "sample_tokens": 0,
+    "sample_beams": 0,
+    "i_weights": [
+        0.01,
+        0.04,
+        0.15,
+        0.18,
+        0.12
+    ],
+    "i_length": [
+        1,
+        2,
+        3,
+        4,
+        5
+    ]
+}
    	
requirements.txt ADDED
@@ -0,0 +1,11 @@
+datasets==2.19.0
+fairscale==0.4.13
+loguru==0.7.2
+nltk==3.8.1
+numpy==1.26.4
+Requests==2.32.2
+sentencepiece==0.2.0
+setuptools==58.2.0
+torch==2.3.0
+tqdm==4.66.4
+transformers==4.37.2
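
A quick sketch for checking an environment against these pins, using only the standard library (the lowercased names below are assumed to match each project's importlib.metadata distribution name):

from importlib.metadata import version, PackageNotFoundError

pins = {
    "datasets": "2.19.0", "fairscale": "0.4.13", "loguru": "0.7.2",
    "nltk": "3.8.1", "numpy": "1.26.4", "requests": "2.32.2",
    "sentencepiece": "0.2.0", "setuptools": "58.2.0", "torch": "2.3.0",
    "tqdm": "4.66.4", "transformers": "4.37.2",
}
for name, want in pins.items():
    try:
        have = version(name)
        status = "ok" if have == want else f"mismatch (installed {have})"
    except PackageNotFoundError:
        status = "missing"
    print(f"{name}=={want}: {status}")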
    	
superposed/llama/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from .generation import Llama, Dialog
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer
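
These re-exports let downstream code import the core classes from the package root rather than the individual submodules:

# Equivalent to importing from superposed.llama.generation,
# superposed.llama.model, and superposed.llama.tokenizer directly.
from superposed.llama import Llama, Dialog, ModelArgs, Transformer, Tokenizer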
    	
superposed/llama/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (335 Bytes).

superposed/llama/__pycache__/generation.cpython-312.pyc ADDED
Binary file (13.9 kB).

superposed/llama/__pycache__/model.cpython-312.pyc ADDED
Binary file (26.7 kB).

superposed/llama/__pycache__/superpose.cpython-312.pyc ADDED
Binary file (19.1 kB).

superposed/llama/__pycache__/superposed_generation.cpython-312.pyc ADDED
Binary file (10.1 kB).

superposed/llama/__pycache__/superposed_model.cpython-312.pyc ADDED
Binary file (25.9 kB).

superposed/llama/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (3.26 kB).

superposed/llama/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.97 kB).
    	
superposed/llama/generation.py ADDED
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, TypedDict
+
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_rank,
+    initialize_model_parallel,
+    model_parallel_is_initialized,
+)
+
+from superposed.llama.model import ModelArgs, Transformer
+from superposed.llama.tokenizer import Tokenizer
+from superposed.llama.utils import *
+
+Role = Literal["system", "user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+class CompletionPrediction(TypedDict, total=False):
+    generation: str
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+class ChatPrediction(TypedDict, total=False):
+    generation: Message
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
+UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
+
+
+class Llama:
+    @staticmethod
+    def build(
+        ckpt_dir: str,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        device=None,
+        model_parallel_size: Optional[int] = None,
+        seed: int = 1,
+    ) -> "Llama":
+        """
+        Build a Llama instance by initializing and loading a pre-trained model.
+
+        Args:
+            ckpt_dir (str): Path to the directory containing checkpoint files.
+            tokenizer_path (str): Path to the tokenizer file.
+            max_seq_len (int): Maximum sequence length for input text.
+            max_batch_size (int): Maximum batch size for inference.
+            device (Optional[Any]): Device override; if None, the local CUDA device is selected.
+            model_parallel_size (Optional[int], optional): Number of model parallel processes.
+                If not provided, it's determined from the environment. Defaults to None.
+
+        Returns:
+            Llama: An instance of the Llama class with the loaded model and tokenizer.
+
+        Raises:
+            AssertionError: If there are no checkpoint files in the specified directory,
+                or if the model parallel size does not match the number of checkpoint files.
+
+        Note:
+            This method initializes the distributed process group, sets the device to CUDA,
+            and loads the pre-trained model and tokenizer.
+
+        """
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group("nccl")
+        if not model_parallel_is_initialized():
+            if model_parallel_size is None:
+                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+            initialize_model_parallel(model_parallel_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        print(local_rank)
+        # torch.cuda.set_device(local_rank)
+        if device is None:
+            torch.cuda.set_device(local_rank)
| 99 | 
            +
                        device = f"cuda:{local_rank}"
         | 
| 100 | 
            +
                    # seed must be the same in all processes
         | 
| 101 | 
            +
                    torch.manual_seed(seed)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    if local_rank > 0:
         | 
| 104 | 
            +
                        sys.stdout = open(os.devnull, "w")
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    start_time = time.time()
         | 
| 107 | 
            +
                    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         | 
| 108 | 
            +
                    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         | 
| 109 | 
            +
                    assert model_parallel_size == len(
         | 
| 110 | 
            +
                        checkpoints
         | 
| 111 | 
            +
                    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         | 
| 112 | 
            +
                    ckpt_path = checkpoints[get_model_parallel_rank()]
         | 
| 113 | 
            +
                    checkpoint = torch.load(ckpt_path, map_location="cpu")
         | 
| 114 | 
            +
                    with open(Path(ckpt_dir) / "params.json", "r") as f:
         | 
| 115 | 
            +
                        params = json.loads(f.read())
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    model_args: ModelArgs = ModelArgs(
         | 
| 118 | 
            +
                        max_seq_len=max_seq_len,
         | 
| 119 | 
            +
                        max_batch_size=max_batch_size,
         | 
| 120 | 
            +
                        **params,
         | 
| 121 | 
            +
                    )
         | 
| 122 | 
            +
                    tokenizer = Tokenizer(model_path=tokenizer_path)
         | 
| 123 | 
            +
                    model_args.vocab_size = tokenizer.n_words
         | 
| 124 | 
            +
                    torch.set_default_tensor_type(torch.cuda.HalfTensor)
         | 
| 125 | 
            +
                    model = Transformer(model_args)
         | 
| 126 | 
            +
                    model.load_state_dict(checkpoint, strict=False)
         | 
| 127 | 
            +
                    print(f"Loaded in {time.time() - start_time:.2f} seconds")
         | 
| 128 | 
            +
                    return Llama(model, tokenizer, device)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def __init__(self, model: Transformer, tokenizer: Tokenizer, device):
         | 
| 131 | 
            +
                    self.model = model.to(device).eval()
         | 
| 132 | 
            +
                    self.tokenizer = tokenizer
         | 
| 133 | 
            +
                    self.device = device
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                @torch.inference_mode()
         | 
| 136 | 
            +
                def generate(
         | 
| 137 | 
            +
                    self,
         | 
| 138 | 
            +
                    prompt_tokens: List[List[int]],
         | 
| 139 | 
            +
                    max_gen_len: int,
         | 
| 140 | 
            +
                    temperature: float = 0.6,
         | 
| 141 | 
            +
                    top_p: float = 0.9,
         | 
| 142 | 
            +
                    logprobs: bool = True,
         | 
| 143 | 
            +
                    grade: bool = False
         | 
| 144 | 
            +
                ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
         | 
| 145 | 
            +
                    """
         | 
| 146 | 
            +
                    Generate text sequences based on provided prompts using the language generation model.
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    Args:
         | 
| 149 | 
            +
                        prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
         | 
| 150 | 
            +
                        max_gen_len (int): Maximum length of the generated text sequence.
         | 
| 151 | 
            +
                        temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
         | 
| 152 | 
            +
                        top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
         | 
| 153 | 
            +
                        logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
         | 
| 154 | 
            +
                        echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    Returns:
         | 
| 157 | 
            +
                        Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    Note:
         | 
| 160 | 
            +
                        This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
         | 
| 161 | 
            +
                        If logprobs is True, token log probabilities are computed for each generated token.
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    """
         | 
| 164 | 
            +
                    params = self.model.params
         | 
| 165 | 
            +
                    bsz = len(prompt_tokens)
         | 
| 166 | 
            +
                    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    min_prompt_len = min(len(t) for t in prompt_tokens)
         | 
| 169 | 
            +
                    max_prompt_len = max(len(t) for t in prompt_tokens)
         | 
| 170 | 
            +
                    # assert min_prompt_len == max_prompt_len
         | 
| 171 | 
            +
                    prompt_len = min_prompt_len
         | 
| 172 | 
            +
                    assert max_prompt_len <= params.max_seq_len
         | 
| 173 | 
            +
                    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    pad_id = self.tokenizer.pad_id
         | 
| 176 | 
            +
                    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
         | 
| 177 | 
            +
                    for k, t in enumerate(prompt_tokens):
         | 
| 178 | 
            +
                        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
         | 
| 179 | 
            +
                    if logprobs:
         | 
| 180 | 
            +
                        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
         | 
| 181 | 
            +
                    prev_pos = 0
         | 
| 182 | 
            +
                    eos_reached = torch.tensor([False] * bsz, device=self.device)
         | 
| 183 | 
            +
                    input_text_mask = tokens != pad_id
         | 
| 184 | 
            +
                    if grade:
         | 
| 185 | 
            +
                        pad_mask = tokens == pad_id
         | 
| 186 | 
            +
                        tokens = torch.where(tokens == pad_id, 0, tokens)
         | 
| 187 | 
            +
                        logits = self.model.forward(tokens, prev_pos, False)
         | 
| 188 | 
            +
                        tokens[pad_mask] = pad_id
         | 
| 189 | 
            +
                        token_logprobs = -F.cross_entropy(
         | 
| 190 | 
            +
                            input=logits[:, :-1, :].transpose(1, 2),
         | 
| 191 | 
            +
                            target=tokens[:, 1:],
         | 
| 192 | 
            +
                            reduction="none",
         | 
| 193 | 
            +
                            ignore_index=pad_id,
         | 
| 194 | 
            +
                        )
         | 
| 195 | 
            +
                        #if pad_id in tokens:
         | 
| 196 | 
            +
                        #    print(pad_id)
         | 
| 197 | 
            +
                        #    print(tokens)
         | 
| 198 | 
            +
                        #    print(token_logprobs)
         | 
| 199 | 
            +
                        return token_logprobs
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    for cur_pos in range(min_prompt_len, total_len):
         | 
| 202 | 
            +
                        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, False)
         | 
| 203 | 
            +
                        if temperature > 0:
         | 
| 204 | 
            +
                            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
         | 
| 205 | 
            +
                            next_token = sample_top_p(probs, top_p)
         | 
| 206 | 
            +
                        else:
         | 
| 207 | 
            +
                            next_token = torch.argmax(logits[:, -1], dim=-1)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                        next_token = next_token.reshape(-1)
         | 
| 210 | 
            +
                        # only replace token if prompt has already been generated
         | 
| 211 | 
            +
                        next_token = torch.where(
         | 
| 212 | 
            +
                            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
         | 
| 213 | 
            +
                        )
         | 
| 214 | 
            +
                        tokens[:, cur_pos] = next_token
         | 
| 215 | 
            +
                        if logprobs:
         | 
| 216 | 
            +
                            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
         | 
| 217 | 
            +
                                input=logits.transpose(1, 2),
         | 
| 218 | 
            +
                                target=tokens[:, prev_pos + 1 : cur_pos + 1],
         | 
| 219 | 
            +
                                reduction="none",
         | 
| 220 | 
            +
                                ignore_index=pad_id,
         | 
| 221 | 
            +
                            )                
         | 
| 222 | 
            +
                        eos_reached |= (~input_text_mask[:, cur_pos]) & (
         | 
| 223 | 
            +
                            next_token == self.tokenizer.eos_id
         | 
| 224 | 
            +
                        )
         | 
| 225 | 
            +
                        prev_pos = cur_pos
         | 
| 226 | 
            +
                        if all(eos_reached):
         | 
| 227 | 
            +
                            break
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # seq_len = torch.sum(tokens != pad_id, dim=1)
         | 
| 230 | 
            +
                    # return tokens, torch.exp(-1 * torch.sum(logprobs, dim=1) / (seq_len - prompt_len)), torch.exp(-1 * torch.sum(custom_logprobs, dim=1) / )
         | 
| 231 | 
            +
                    if logprobs:
         | 
| 232 | 
            +
                        token_logprobs = token_logprobs.tolist()
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    out_ppl = []
         | 
| 235 | 
            +
                    for i, toks in enumerate(tokens.tolist()):
         | 
| 236 | 
            +
                        if logprobs:
         | 
| 237 | 
            +
                            probs = token_logprobs[i][prompt_len : len(prompt_tokens[i]) + max_gen_len]
         | 
| 238 | 
            +
                        # cut to eos tok if any
         | 
| 239 | 
            +
                        if self.tokenizer.eos_id in toks:
         | 
| 240 | 
            +
                            eos_idx = toks.index(self.tokenizer.eos_id)
         | 
| 241 | 
            +
                            probs = probs[:eos_idx] if logprobs else None
         | 
| 242 | 
            +
                        out_ppl.append(torch.exp(-1 * torch.sum(torch.tensor(probs)) / len(probs)))
         | 
| 243 | 
            +
                    return tokens, torch.tensor(out_ppl) if logprobs else None
         | 
| 244 | 
            +
             | 
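# Usage sketch (added for clarity; hypothetical paths, requires downloaded
# Llama-2 weights and a CUDA device, so shown as comments only):
#
#   llama = Llama.build(ckpt_dir="weights/7B", tokenizer_path="weights/tokenizer.model",
#                       max_seq_len=128, max_batch_size=4)
#   prompt = llama.tokenizer.encode("The capital of France is", bos=True, eos=False)
#   tokens, ppl = llama.generate([prompt], max_gen_len=16, temperature=0.6, top_p=0.9)
#   # `tokens` is the padded (bsz, seqlen) token matrix; `ppl` holds one
#   # perplexity per sequence.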
def sample_top_p(probs, p, s=1):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.
        s (int, optional): Number of samples to draw per distribution. Defaults to 1.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=s)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
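# Illustrative self-test (added sketch, not in the original commit): traces the
# nucleus cutoff of `sample_top_p` on a toy distribution. A token is masked when
# the cumulative mass before it already exceeds p, so with p = 0.9 only the
# first three tokens remain candidates after masking and renormalization.
if __name__ == "__main__":
    toy = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
    toy_sort, _ = torch.sort(toy, dim=-1, descending=True)
    toy_sum = torch.cumsum(toy_sort, dim=-1)   # [0.50, 0.80, 0.95, 1.00]
    print(toy_sum - toy_sort > 0.9)            # [[False, False, False, True]]
    print(sample_top_p(toy, 0.9, s=3))         # three indices drawn from {0, 1, 2}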
    	
        superposed/llama/metrics.py
    ADDED
    
import torch
import nltk
from nltk.translate.bleu_score import SmoothingFunction
from tqdm import tqdm

def calculate_perplexity(model, tokens, prompt_len, bsz=1, marker=False):
    """
    Calculate perplexity of given tokens using the provided model, ignoring padding tokens.
    Args:
        model: Llama model
        tokens (List[List[int]] or torch.Tensor): Input tokens (n_prompt * n_draft, seqlen)
        prompt_len (int): Prefix length
        bsz (int): Batch size
        marker (bool): Whether to show a progress bar
    Returns:
        Perplexity across all generations (n_prompt * n_drafts)
    """
    it = range(0, len(tokens), bsz)
    if marker:
        it = tqdm(it)
    ppl = torch.zeros(len(tokens))
    for start in it:
        end = start + bsz
        data = tokens[start : end]
        if not isinstance(data, list):
            data = data.tolist()
        # Remove any padding tokens (-1) in generations
        for d_idx in range(len(data)):
            cur = data[d_idx]
            if -1 in cur:
                data[d_idx] = cur[:cur.index(-1)]
        # Calculate cross entropy loss on tokens
        ce_loss = model.generate(data, max_gen_len=0, temperature=-1, top_p=-1, grade=True)
        # Cut off everything past `prompt_len`
        ce_loss = ce_loss[:, prompt_len-1:]  # Subtract 1 because the first token (start token) is removed
        # Calculate perplexity
        lengths = (ce_loss != 0).sum(dim=-1)
        mean = ce_loss.sum(dim=-1) / lengths
        ppl[start : end] = torch.exp(-1 * mean)
    return ppl
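# Sanity check (added for clarity): `ce_loss` holds per-token log-probabilities,
# so a sequence scored at a constant log-prob of -2.0 per token has
# perplexity exp(2.0), roughly 7.39:
#   torch.exp(-1 * torch.tensor([-2.0, -2.0]).mean())  # tensor(7.3891)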
def calculate_diversity(generations, k=4):
    """
    Calculate diversity of generations using Self-BLEU.
    Args:
        generations (List[List[List[int]]]): Tokenized input
        k (int, Optional): Maximum n-gram order to use for BLEU
    Returns:
        Average Self-BLEU across all generations (float)
    """
    nltk.download('punkt')  # Can be deleted once downloaded
    smooth = SmoothingFunction()
    bleus = []

    for drafts in generations:
        tokenized_drafts = []
        # Stringify tokens, truncating at the first padding token (-1)
        for d in drafts:
            if -1 in d:
                d = d[:d.index(-1)]
            tokenized_drafts.append([str(n) for n in d])
        # Calculate Self-BLEU
        minlength = min([len(g) for g in tokenized_drafts])
        minlength = min(minlength, k)
        weights = tuple(1. / minlength for _ in range(minlength))
        for i in range(len(drafts)):
            # Create source and reference (all other drafts)
            src = tokenized_drafts[i]
            ref = tokenized_drafts[:i] + tokenized_drafts[i+1:]
            tmp = nltk.translate.bleu_score.sentence_bleu(references=ref,
                                                          hypothesis=src,
                                                          weights=weights,
                                                          smoothing_function=smooth.method1)
            bleus.append(tmp)
    bleus = torch.Tensor(bleus)
    return torch.mean(bleus)
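# Interpretation note (added for clarity): each draft is scored against its
# sibling drafts, so identical drafts drive the mean Self-BLEU toward 1.0,
# while fully disjoint drafts drive it toward 0.0; lower means more diverse.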
def calculate_ngram_repetition(sequences):
    """
    Calculate uniqueness scores of `sequences`.
    Args:
        sequences (List[List[int]]): Generated sequences
    Returns:
        (unigram_uniqueness, bigram_uniqueness, trigram_uniqueness)
    """
    u_total = 0
    b_total = 0
    t_total = 0
    # Iterate through all sequences indiscriminately
    for gen in sequences:
        if -1 in gen:
            gen = gen[:gen.index(-1)]
        unigrams, bigrams, trigrams = [], [], []
        o = [str(i) for i in gen]
        # Create lists of n-grams for the generation
        for i in range(len(o)):
            unigrams.append(o[i])
        for i in range(len(o) - 1):
            bigrams.append(o[i] + '_' + o[i + 1])
        for i in range(len(o) - 2):
            trigrams.append(o[i] + '_' + o[i + 1] + '_' + o[i + 2])
        # Calculate uniqueness of the generation
        u, b, t = len(set(unigrams)) / len(unigrams), len(set(bigrams)) / len(bigrams), len(set(trigrams)) / len(trigrams)
        u_total += u
        b_total += b
        t_total += t
    return u_total / len(sequences), b_total / len(sequences), t_total / len(sequences)
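# Illustrative self-test (added sketch, not in the original commit): the n-gram
# uniqueness ratios need no model, so they can be checked directly.
if __name__ == "__main__":
    seqs = [[1, 2, 3, 1, 2, 3],   # repeats 1-2-3: 3/6 unigrams, 3/5 bigrams, 3/4 trigrams unique
            [4, 5, 6, 7, -1, 0]]  # truncated at the first -1, leaving [4, 5, 6, 7]: all ratios 1.0
    print(calculate_ngram_repetition(seqs))  # (0.75, 0.8, 0.875)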
    	
        superposed/llama/model.py
    ADDED
    
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
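# In formula form (comment added for clarity): RMSNorm(x) = w * x / sqrt(mean(x^2) + eps),
# i.e. LayerNorm without mean-centering or bias. For example, x = [3.0, 4.0] has
# mean(x^2) = 12.5, so both entries are scaled by roughly 0.283 before the
# learned weight is applied.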
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
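# Shape sketch (comment added for clarity): precompute_freqs_cis(dim=8, end=4)
# builds dim // 2 = 4 rotation frequencies for 4 positions and returns a
# (4, 4) complex64 tensor of unit-magnitude phasors exp(i * t * freq).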
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
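# Intuition (comment added for clarity): consecutive pairs of head dimensions are
# viewed as complex numbers, and multiplying by the unit phasor freqs_cis rotates
# each pair by a position-dependent angle; this is how RoPE encodes position.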
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
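# Shape sketch (comment added for clarity): with grouped-query attention,
# repeat_kv(x, 4) expands a (bs, seqlen, 8, head_dim) key/value tensor to
# (bs, seqlen, 32, head_dim) so every query head has a matching key/value head.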
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
            class Attention(nn.Module):
         | 
| 177 | 
            +
                """Multi-head attention module."""
         | 
| 178 | 
            +
                def __init__(self, args: ModelArgs):
         | 
| 179 | 
            +
                    """
         | 
| 180 | 
            +
                    Initialize the Attention module.
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    Args:
         | 
| 183 | 
            +
                        args (ModelArgs): Model configuration parameters.
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    Attributes:
         | 
| 186 | 
            +
                        n_kv_heads (int): Number of key and value heads.
         | 
| 187 | 
            +
                        n_local_heads (int): Number of local query heads.
         | 
| 188 | 
            +
                        n_local_kv_heads (int): Number of local key and value heads.
         | 
| 189 | 
            +
                        n_rep (int): Number of repetitions for local heads.
         | 
| 190 | 
            +
                        head_dim (int): Dimension size of each attention head.
         | 
| 191 | 
            +
                        wq (ColumnParallelLinear): Linear transformation for queries.
         | 
| 192 | 
            +
                        wk (ColumnParallelLinear): Linear transformation for keys.
         | 
| 193 | 
            +
                        wv (ColumnParallelLinear): Linear transformation for values.
         | 
| 194 | 
            +
                        wo (RowParallelLinear): Linear transformation for output.
         | 
| 195 | 
            +
                        cache_k (torch.Tensor): Cached keys for attention.
         | 
| 196 | 
            +
                        cache_v (torch.Tensor): Cached values for attention.
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    """
         | 
| 199 | 
            +
                    super().__init__()
         | 
| 200 | 
            +
                    self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         | 
| 201 | 
            +
                    model_parallel_size = fs_init.get_model_parallel_world_size()
         | 
| 202 | 
            +
                    self.n_local_heads = args.n_heads // model_parallel_size
         | 
| 203 | 
            +
                    self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         | 
| 204 | 
            +
                    self.n_rep = self.n_local_heads // self.n_local_kv_heads
         | 
| 205 | 
            +
                    self.head_dim = args.dim // args.n_heads
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    self.wq = ColumnParallelLinear(
         | 
| 208 | 
            +
                        args.dim,
         | 
| 209 | 
            +
                        args.n_heads * self.head_dim,
         | 
| 210 | 
            +
                        bias=False,
         | 
| 211 | 
            +
                        gather_output=False,
         | 
| 212 | 
            +
                        init_method=lambda x: x,
         | 
| 213 | 
            +
                    )
         | 
| 214 | 
            +
                    self.wk = ColumnParallelLinear(
         | 
| 215 | 
            +
                        args.dim,
         | 
| 216 | 
            +
                        self.n_kv_heads * self.head_dim,
         | 
| 217 | 
            +
                        bias=False,
         | 
| 218 | 
            +
                        gather_output=False,
         | 
| 219 | 
            +
                        init_method=lambda x: x,
         | 
| 220 | 
            +
                    )
         | 
| 221 | 
            +
                    self.wv = ColumnParallelLinear(
         | 
| 222 | 
            +
                        args.dim,
         | 
| 223 | 
            +
                        self.n_kv_heads * self.head_dim,
         | 
| 224 | 
            +
                        bias=False,
         | 
| 225 | 
            +
                        gather_output=False,
         | 
| 226 | 
            +
                        init_method=lambda x: x,
         | 
| 227 | 
            +
                    )
         | 
| 228 | 
            +
                    self.wo = RowParallelLinear(
         | 
| 229 | 
            +
                        args.n_heads * self.head_dim,
         | 
| 230 | 
            +
                        args.dim,
         | 
| 231 | 
            +
                        bias=False,
         | 
| 232 | 
            +
                        input_is_parallel=True,
         | 
| 233 | 
            +
                        init_method=lambda x: x,
         | 
| 234 | 
            +
                    )
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    self.cache_k = torch.zeros(
         | 
| 237 | 
            +
                        (
         | 
| 238 | 
            +
                            args.max_batch_size,
         | 
| 239 | 
            +
                            args.max_seq_len,
         | 
| 240 | 
            +
                            self.n_local_kv_heads,
         | 
| 241 | 
            +
                            self.head_dim,
         | 
| 242 | 
            +
                        )
         | 
| 243 | 
            +
                    ).cuda()
         | 
| 244 | 
            +
                    self.cache_v = torch.zeros(
         | 
| 245 | 
            +
                        (
         | 
| 246 | 
            +
                            args.max_batch_size,
         | 
| 247 | 
            +
                            args.max_seq_len,
         | 
| 248 | 
            +
                            self.n_local_kv_heads,
         | 
| 249 | 
            +
                            self.head_dim,
         | 
| 250 | 
            +
                        )
         | 
| 251 | 
            +
                    ).cuda()
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                def forward(
         | 
| 254 | 
            +
                    self,
         | 
| 255 | 
            +
                    x: torch.Tensor,
         | 
| 256 | 
            +
                    start_pos: int,
         | 
| 257 | 
            +
                    freqs_cis: torch.Tensor,
         | 
| 258 | 
            +
                    mask: Optional[torch.Tensor],
         | 
| 259 | 
            +
                    beam: Optional[bool] = None,
         | 
| 260 | 
            +
                    n_beams: Optional[int] = None,
         | 
| 261 | 
            +
                    attention_change_ids: Optional[torch.Tensor] = None
         | 
| 262 | 
            +
                ):
         | 
| 263 | 
            +
                    """
         | 
| 264 | 
            +
                    Forward pass of the attention module.
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    Args:
         | 
| 267 | 
            +
                        x (torch.Tensor): Input tensor.
         | 
| 268 | 
            +
                        start_pos (int): Starting position for caching.
         | 
| 269 | 
            +
                        freqs_cis (torch.Tensor): Precomputed frequency tensor.
         | 
| 270 | 
            +
                        mask (torch.Tensor, optional): Attention mask tensor.
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    Returns:
         | 
| 273 | 
            +
                        torch.Tensor: Output tensor after attention.
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    """
         | 
| 276 | 
            +
                    bsz, seqlen, _ = x.shape
         | 
| 277 | 
            +
                    _, max_seq_len, n_local_kv_heads, head_dim = self.cache_k.shape
         | 
| 278 | 
            +
                    # KV Cache updates for beam search
         | 
| 279 | 
            +
                    if beam:
         | 
| 280 | 
            +
                        # Extract used cache values
         | 
| 281 | 
            +
                        used_cache_k = self.cache_k[:bsz]
         | 
| 282 | 
            +
                        used_cache_v = self.cache_v[:bsz]
         | 
| 283 | 
            +
                        # Reshape to apply change ids
         | 
| 284 | 
            +
                        t_cache_k = used_cache_k.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
         | 
| 285 | 
            +
                        t_cache_v = used_cache_v.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
         | 
| 286 | 
            +
                        used_cache_k = torch.take_along_dim(t_cache_k, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
         | 
| 287 | 
            +
                        used_cache_v = torch.take_along_dim(t_cache_v, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
         | 
| 288 | 
            +
                        # Update cache
         | 
| 289 | 
            +
                        self.cache_k[:bsz] = used_cache_k.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)
         | 
| 290 | 
            +
                        self.cache_v[:bsz] = used_cache_v.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  # (bs, n_local_heads, seqlen, seqlen)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # (bs, n_local_heads, seqlen, seqlen)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
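        # The matmul/softmax/matmul sequence above is standard scaled
        # dot-product attention; on PyTorch >= 2.0 an equivalent formulation
        # (sketch) would be
        #   output = F.scaled_dot_product_attention(xq, keys, values, attn_mask=mask)
        # with the same (bs, n_local_heads, seqlen, head_dim) layouts.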


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
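    # Worked example of the hidden-dimension rounding above (hypothetical
    # 7B-style values: dim = 4096, hidden_dim = 4 * 4096 = 16384,
    # multiple_of = 256, ffn_dim_multiplier = None):
    #   hidden_dim = int(2 * 16384 / 3)            -> 10922
    #   hidden_dim = 256 * ((10922 + 255) // 256)  -> 11008
    # The forward pass is the SwiGLU variant: w2(silu(w1(x)) * w3(x)).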


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        beam: Optional[bool],
        n_beams: Optional[int] = None,
        attention_change_ids: Optional[torch.Tensor] = None
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
            beam (bool, optional): Whether to reorder the attention KV cache for beam search.
            n_beams (int, optional): Number of beams per prompt.
            attention_change_ids (torch.Tensor, optional): Beam reordering indices passed to attention.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
        if beam:
            h = x + self.attention.forward(
                self.attention_norm(x), start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids
            )
        else:
            h = x + self.attention.forward(
                self.attention_norm(x), start_pos, freqs_cis, mask
            )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )
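        # Shape sketch (hypothetical 7B-style values: dim = 4096, n_heads = 32,
        # max_seq_len = 2048): head_dim = 4096 // 32 = 128, so in the reference
        # Llama implementation freqs_cis is a complex tensor of shape
        # (2 * 2048, 128 // 2) = (4096, 64), one rotation per position and per
        # pair of head dimensions.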


    @torch.inference_mode()
    def forward(self,
                tokens: torch.Tensor,
                start_pos: int,
                beam: bool,
                n_beams: Optional[int] = None,
                attention_change_ids: Optional[torch.Tensor] = None,
                verbose: Optional[bool] = False):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.
            beam (bool): Whether to apply beam-search KV cache reordering in each layer.
            n_beams (int, optional): Number of beams per prompt.
            attention_change_ids (torch.Tensor, optional): Beam reordering indices for the KV cache.
            verbose (bool): Whether to return intermediate hidden layer states.

        Returns:
            torch.Tensor or (torch.Tensor, Dict): Output logits after applying the Transformer model.

        """
        ### ANALYSIS CODE ###
        if verbose:
            states = {"layers": [], "tokens": tokens}
        #

        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        ### ANALYSIS CODE ###
        if verbose:
            states["layers"].append(h)
        #

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
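            # Mask sketch (hypothetical seqlen = 3, start_pos = 0): triu with
            # diagonal=1 yields
            #   [[0, -inf, -inf],
            #    [0,    0, -inf],
            #    [0,    0,    0]]
            # i.e. each position may attend only to itself and earlier positions.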

        for layer in self.layers:
            if not beam:
                h = layer(h, start_pos, freqs_cis, mask, beam)
            else:
                h = layer(h, start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids)
            ### ANALYSIS CODE ###
            if verbose:
                states["layers"].append(h)
            #
        h = self.norm(h)
        # If differences are wanted, subtract them from the [-1] position of the embedding vectors at the end of each iteration

        ### ANALYSIS CODE ###
        if verbose:
            states["layers"].append(h)
        #

        output = self.output(h).float()

        if verbose:
            return output, states
        else:
            return output
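        # Usage sketch (hypothetical names and shapes): for `tokens` of shape
        # (bsz, seqlen),
        #   logits = model(tokens, start_pos=0, beam=False)
        # returns logits of shape (bsz, seqlen, vocab_size); with verbose=True
        # the call also returns the recorded per-layer hidden states.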

superposed/llama/superpose.py
ADDED

@@ -0,0 +1,328 @@
# Implementation loosely based on https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L554
import requests
import time
from datetime import datetime, timedelta
from typing import Optional, Literal

import torch
import torch.nn as nn
from transformers import LlamaTokenizer

from superposed.llama.utils import *
from superposed.ngrams.ngram_models import NGram

INF = 1. * 1e7

# Test by scaling # beams & verify work
class Superpose(nn.Module):
    def __init__(self,
                 initial_tokens,
                 tokenizer,
                 vocab_size,
                 smoothing: Optional[Literal["geom", "all"]] = None,
                 alpha=None,
                 verbose=False,
                 i_weights=None,
                 i_length=None,
                 ngrams=None,
                 sample_beams=False,
                 sample_tokens=False,
                 get_time=False,
                 penalty=200):  # default no effect
        """
        Initialize a beam search class.

        Args:
            initial_tokens (torch.Tensor): Initial tokens (n_prompts, n_drafts, seqlen)
            tokenizer (Tokenizer): Llama tokenizer
            vocab_size (int): Total vocab size
            smoothing (str, optional): Smoothing method ("geom" for default, "all" for only ngram, None for no ngram)
            alpha (float): Interpolation weight between model and ngram log probabilities
            verbose (bool): Whether to print information
            i_weights (list): Interpolation weights for the ngram models
            i_length (list): Ngram context lengths corresponding to i_weights
            ngrams (list): Ngram models, ordered [bigram, trigram, ..., sevengram]
            sample_beams (bool): Whether to sample beams
            sample_tokens (bool): Whether to sample tokens
            get_time (bool): Whether to time ngram lookups
            penalty (float): Penalty applied to drafts where no interpolation occurred
        """
        super().__init__()
        # primary parameters
        self.n_prompts, self.n_drafts, _ = initial_tokens.shape
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self.alive_seq = initial_tokens
        self.fin_seq = initial_tokens
        self.smoothing = smoothing
        self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts)
        self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"))
        self.alpha = alpha
        self.verbose = verbose
        self.penalty = penalty
        # devices
        self.cpu = torch.device('cpu')
        self.gpu = torch.device('cuda')
        # Interpolation length and weights
        self.interpolation_weights = i_weights
        self.i_length = i_length
        # N-grams
        self.bigram = ngrams[0] if len(ngrams) >= 1 else None
        self.trigram = ngrams[1] if len(ngrams) >= 2 else None
        self.fourgram = ngrams[2] if len(ngrams) >= 3 else None
        self.fivegram = ngrams[3] if len(ngrams) >= 4 else None
        self.sixgram = ngrams[4] if len(ngrams) >= 5 else None
        self.sevengram = ngrams[5] if len(ngrams) >= 6 else None
        # Timing
        self.get_time = get_time
        self.lookup_time = None
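
    # Construction sketch (hypothetical values, not from the original file):
    #   superpose = Superpose(initial_tokens, tokenizer, vocab_size=32000,
    #                         smoothing="geom", alpha=0.5,
    #                         i_weights=[0.4, 0.3, 0.3], i_length=[1, 2, 3],
    #                         ngrams=[bigram, trigram, fourgram])
    # Each forward() call then advances every draft by one token and returns
    # the mixing weights used to superpose the drafts' next tokens.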

    def forward(self, probs, still_prompt, is_first, cur_pos, n_token_sample):
        """
        Apply beam decoding to update generations.

        Args:
            probs (torch.Tensor): Next token probability distribution
            still_prompt (torch.Tensor): Flags of prompts that should not generate yet (n_prompts, )
            is_first (torch.Tensor): Flags of prompts that are on their first generation (n_prompts, )
            cur_pos (int): Current generation position
            n_token_sample (int): Number of tokens from model distribution to use

        Returns:
            if standard beam search:
                attention_change_ids (torch.Tensor): New indices in kv cache (n_prompts, n_drafts)
            if mixed:
                token_weights (torch.Tensor): Mixing weights (n_prompts, vocab_size)
        """
        # Adjust input probabilities
        probs = self.get_top_k(probs, 32000, n_token_sample)
        reshaped_probs = probs.reshape(self.n_prompts, 1, -1)
        reshaped_probs = reshaped_probs.repeat(1, self.n_drafts, 1)
        # Ngram smoothing
        if self.smoothing is not None:
            if self.smoothing == "geom":
                ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=probs)
                # Make mask and normalize
                prob_mask = reshaped_probs != 0
                ngram_probs *= prob_mask
                # Calculate logprobs and interpolate distributions
                llm_log_probs = torch.log(reshaped_probs)
                ngram_log_probs = torch.log(ngram_probs)
                log_probs = (1 - self.alpha) * llm_log_probs + self.alpha * ngram_log_probs
                # Apply penalty to drafts where no interpolation occurred
                is_all_inf = (log_probs != float("-inf")).sum(dim=-1, keepdims=True) == 0
                log_probs = torch.where(is_all_inf, (1 - self.alpha) * llm_log_probs - self.penalty, log_probs)
            elif self.smoothing == "all":
                ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=None)
                log_probs = torch.log(ngram_probs)
        else:
            log_probs = torch.log(reshaped_probs)
        curr_log_probs = self.alive_log_probs.unsqueeze(dim=2) + log_probs  # [n_prompts, n_drafts, vocab_size]
        # Warn if nan
        if torch.any(torch.isnan(curr_log_probs)).item():
            raise RuntimeWarning("nan in sequence log probs")
        # Potential sequences
        flat_curr_log_probs = curr_log_probs.reshape(-1, self.vocab_size * self.n_drafts)
        topk_log_probs, topk_idx = torch.topk(flat_curr_log_probs, 2 * self.n_drafts, dim=-1)
        topk_beam_id = topk_idx // self.vocab_size  # [n_prompts, 2 * n_drafts]
        topk_idx = topk_idx % self.vocab_size  # [n_prompts, 2 * n_drafts]
        # First timestep uses top-k next tokens
        is_first_idx = is_first.nonzero(as_tuple=True)[0]
        if len(is_first_idx) != 0:
            first_time_log_probs = log_probs[is_first_idx][:, 0, :].squeeze(dim=1)
            first_time_log_probs, first_time_topk_idx = torch.topk(first_time_log_probs, 2 * self.n_drafts, dim=1)
            topk_idx[is_first_idx] = first_time_topk_idx
            topk_log_probs[is_first_idx] = self.alive_log_probs[is_first_idx, 0].unsqueeze(dim=1) + first_time_log_probs
        # New sequences
        topk_seq = torch.take_along_dim(self.alive_seq, topk_beam_id.unsqueeze(2), dim=1)  # [n_prompts, 2 * n_drafts, seqlen]
        topk_seq[:, :, cur_pos] = topk_idx
        topk_finished = topk_idx == self.tokenizer.eos_id
        # Only update sequences for those that have begun generating
        new_alive_seq, new_alive_log_probs = self.grow_alive(topk_seq, topk_log_probs, topk_finished)
        new_fin_seq, new_fin_log_probs = self.grow_fin(topk_seq, topk_log_probs, topk_finished)
        still_prompt_probs = still_prompt.reshape(-1, 1)
        still_prompt_seqs = still_prompt.reshape(-1, 1, 1)
        self.alive_seq = torch.where(still_prompt_seqs, self.alive_seq, new_alive_seq)
        self.alive_log_probs = torch.where(still_prompt_probs, self.alive_log_probs, new_alive_log_probs)
        self.fin_seq = torch.where(still_prompt_seqs, self.fin_seq, new_fin_seq)
        self.fin_log_probs = torch.where(still_prompt_probs, self.fin_log_probs, new_fin_log_probs)
        # Create superposition matrix and return it
        topk_idx = self.alive_seq[:, :, cur_pos].reshape(self.n_prompts, -1)
        token_weights = self.superposition_matrix(topk_idx)
        return token_weights
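
    # Index arithmetic sketch: each entry of the flattened
    # (n_drafts * vocab_size) score vector encodes (draft, token) as
    #   flat = draft * vocab_size + token.
    # With vocab_size = 32000 (hypothetical), a flat index of 64005 decomposes
    # as 64005 // 32000 = 2 (draft id) and 64005 % 32000 = 5 (token id),
    # which is exactly the topk_beam_id / topk_idx split above.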

    def grow_alive(self, topk_seq, topk_log_probs, topk_finished):
        """
        Extend running generations.
        Args:
            topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, seqlen)
            topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
            topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)
        Returns:
            new_alive_seq, new_alive_log_probs
        """
        topk_log_probs = topk_log_probs + topk_finished * -INF
        new_alive_log_probs, new_alive_idx = torch.topk(topk_log_probs, self.n_drafts, dim=1)
        new_alive_seq = torch.take_along_dim(topk_seq, new_alive_idx.unsqueeze(2), dim=1)
        return new_alive_seq, new_alive_log_probs

    def grow_fin(self, topk_seq, topk_log_probs, topk_finished):
        """
        Update stopped generations.
        Args:
            topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, seqlen)
            topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
            topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)

        Returns:
            new_fin_seq, new_fin_log_probs
        """
        topk_log_probs = topk_log_probs + ~topk_finished * -INF
        new_fin_seq = torch.cat([self.fin_seq, topk_seq], dim=1)
        new_fin_log_probs = torch.cat([self.fin_log_probs, topk_log_probs], dim=1)
        new_fin_log_probs, new_fin_idx = torch.topk(new_fin_log_probs, self.n_drafts, dim=1)
        new_fin_seq = torch.take_along_dim(new_fin_seq, new_fin_idx.unsqueeze(2), dim=1)
        return new_fin_seq, new_fin_log_probs

    def get_top_k(self, probs, m, k):
        """
        Zero out all but top-k tokens in a probability distribution.
        Args:
            probs (torch.Tensor): Probability distribution tensor.
            m (int): Number of tokens to consider (only relevant when sampling).
            k (int): Number of tokens to sample/keep.
        Returns:
            torch.Tensor: New probability distribution based on renormalized probabilities.
        """
        n_prompts, _ = probs.shape
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        top_k_mask = torch.arange(probs.shape[-1])
        top_k_mask = top_k_mask.expand(probs.shape[0], -1)
        top_k_mask = top_k_mask >= m  # Set to 1 past the first m elements
        probs_sort[top_k_mask] = 0.0  # Zero wherever mask = 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
        # Set all other probs to 0
        new_probs_map = torch.zeros(probs.shape).bool()
        new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
        new_probs = torch.where(new_probs_map, probs, 0)
        # Renormalize
        new_probs.div_(new_probs.sum(dim=-1, keepdim=True))
        return new_probs
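
    # Worked example (hypothetical 4-token vocabulary): for a single prompt with
    #   probs = [0.5, 0.3, 0.1, 0.1], m = 4, k = 2
    # the top-2 tokens (ids 0 and 1) are kept and renormalized, giving
    #   [0.625, 0.375, 0.0, 0.0].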

    def superposition_matrix(self, tokens):
        """
        Create superposition matrix based on provided tokens.
        Args:
            tokens (torch.Tensor): Tokens to mix (n_prompts, n_drafts)
        Returns:
            Superposition matrix (n_prompts, vocab_size)
        """
        # Create superposition matrix
        mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
        # Convert draft log probs to probabilities
        weightings = log_prob_to_prob(self.alive_log_probs)
        # Update probabilities in superposition matrix with draft probabilities
        for p_idx in range(self.n_prompts):
            for d_idx in range(self.n_drafts):
                tok_idx = tokens[p_idx][d_idx]
                mixing_matrix[p_idx][tok_idx] += weightings[p_idx][d_idx]
        # Renormalize
        mixing_matrix.div_(mixing_matrix.sum(dim=-1, keepdims=True))
        return mixing_matrix
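
    # A vectorized equivalent of the accumulation loop above (sketch, assuming
    # integer draft tokens) would use scatter_add_:
    #   mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
    #   mixing_matrix.scatter_add_(1, tokens.long(), weightings)
    #   mixing_matrix.div_(mixing_matrix.sum(dim=-1, keepdims=True))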

    def ngram_probs(self, alive_seq, cur_pos, probs):
        """
        Calculate and return next token distribution using ngram models.
        Args:
            alive_seq (torch.Tensor): Current drafts (n_prompts, n_drafts, seqlen)
            cur_pos (int): Current timestep
            probs (torch.Tensor): Current next probability distribution from model (n_prompts, vocab_size).
            As described in the paper, only tokens with nonzero probability in `probs` are considered for the
            ngram distribution. However, passing in `None` as `probs` will consider all tokens.
        Returns:
            Next token distribution for each draft (n_prompts, n_drafts, vocab_size)
        """
        if self.get_time:
            # Start timer
            start_time = datetime.now()
        # Create distribution matrix
        next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000)
        if probs is not None:
            # Loop over all prefixes
            for p_idx in range(len(alive_seq)):
                # List of possible tokens for the prefix
                nz = torch.nonzero(probs[p_idx, :], as_tuple=True)[0].tolist()
                # Generate next token distribution
                for draft_idx in range(self.n_drafts):
                    i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
                    new_i_weights = self.interpolation_weights[:i_mask]
                    new_i_length = self.i_length[:i_mask]
                    # For each next token
                    for nt in nz:
                        # Calculate probability using ngram interpolation
                        for i, weight in zip(new_i_length, new_i_weights):
                            if cur_pos - i >= 0:
                                key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
                                if i == 1:
                                    prob = self.bigram.prob(key, nt)
                                elif i == 2:
                                    prob = self.trigram.prob(key, nt)
                                elif i == 3:
                                    prob = self.fourgram.prob(key, nt)
                                elif i == 4:
                                    prob = self.fivegram.prob(key, nt)
                                elif i == 5:
                                    prob = self.sixgram.prob(key, nt)
                                elif i == 6:
                                    prob = self.sevengram.prob(key, nt)
                                # Accumulate only when an i-length prefix exists
                                if prob >= 0:
                                    next_token_probs[p_idx, draft_idx, nt] += weight * prob
        else:
            for p_idx in range(len(alive_seq)):
                for draft_idx in range(self.n_drafts):
                    i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
                    new_i_weights = self.interpolation_weights[:i_mask]
                    new_i_length = self.i_length[:i_mask]
                    for i, weight in zip(new_i_length, new_i_weights):
                        if cur_pos - i >= 0:
                            key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
                            if i == 1:
                                ntd = self.bigram.ntd(key)
                            elif i == 2:
                                ntd = self.trigram.ntd(key)
                            elif i == 3:
                                ntd = self.fourgram.ntd(key)
                            elif i == 4:
                                ntd = self.fivegram.ntd(key)
                            elif i == 5:
                                ntd = self.sixgram.ntd(key)
                            elif i == 6:
                                ntd = self.sevengram.ntd(key)
                            if ntd is not None:
                                next_token_probs[p_idx, draft_idx, :] += weight * ntd
        if self.get_time:
            total_time = datetime.now() - start_time
            self.lookup_time = total_time if self.lookup_time is None else self.lookup_time + total_time
        return next_token_probs
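
    # Interpolation sketch (hypothetical settings i_length = [1, 2, 3] and
    # i_weights = [0.4, 0.3, 0.3]): once cur_pos >= 3, each candidate token nt
    # is scored as
    #   P_ngram(nt) = 0.4 * P_bigram(nt | last 1 token)
    #               + 0.3 * P_trigram(nt | last 2 tokens)
    #               + 0.3 * P_fourgram(nt | last 3 tokens)
    # and, under "geom" smoothing, this distribution is log-interpolated with
    # the model's distribution via alpha in forward().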

    def return_results(self, prompt_len=None):
        """
        Return generations and perplexities.

        Args:
            prompt_len (int): Length of prompt in tokens. If None, then ppl is not calculated.
        Returns:
            (self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl)
            OR
            (self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl), self.lookup_time
        """
        # PPL
        alive_ppl = 0
        fin_ppl = 0
        if prompt_len is not None:
            alive_ppl = torch.exp(self.alive_log_probs / (-1 * (self.alive_seq.size(dim=-1) - prompt_len)))
            # Fin ppl
            fin_seq_lengths = (self.fin_seq != self.tokenizer.pad_id).sum(dim=-1)
            fin_ppl = torch.exp(self.fin_log_probs / (-1 * (fin_seq_lengths - prompt_len)))
            # Mark empty finished slots (ppl of 0 from -inf log probs) as infinite
            fin_ppl = fin_ppl.masked_fill(fin_ppl == 0, float("inf"))
        # Return results, optionally with accumulated ngram lookup time
        if not self.get_time:
            return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl)
        else:
            return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl), self.lookup_time
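
    # Perplexity here is exp(-log P(generation) / n_generated_tokens); e.g.
    # (hypothetical numbers) a draft with total log prob -6.0 over 12 generated
    # tokens has ppl = exp(6.0 / 12) ~= 1.65.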

superposed/llama/superposed_generation.py
ADDED

@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from superposed.llama.model import ModelArgs
from superposed.llama.superposed_model import SuperposedTransformer
from superposed.llama.tokenizer import Tokenizer
from superposed.llama.superpose import Superpose
from superposed.llama.utils import *
from superposed.ngrams.ngram_models import make_models

class SuperposedLlama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        device = None,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device is None:
            torch.cuda.set_device(local_rank)
            device = torch.cuda.current_device()
        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        # Set up superposed decoding
        model = SuperposedTransformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")
        return SuperposedLlama(model, tokenizer, device)

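    # Illustrative single-process launch sketch (an addition, not part of the
    # commit). `build` assumes torch.distributed is available; outside torchrun
    # the rendezvous variables must be set first. All values are hypothetical:
    #
    #   os.environ.setdefault("RANK", "0")
    #   os.environ.setdefault("WORLD_SIZE", "1")
    #   os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    #   os.environ.setdefault("MASTER_PORT", "29500")
    #   llama = SuperposedLlama.build(
    #       ckpt_dir="weights/llama-2-7b",            # hypothetical paths
    #       tokenizer_path="weights/tokenizer.model",
    #       max_seq_len=512,
    #       max_batch_size=8,
    #   )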
    def __init__(self, model: SuperposedTransformer, tokenizer: Tokenizer, device):
        print(device)
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer
        self.device = device

    @torch.inference_mode()
    def sup_generate(
        self,
        prompt_tokens: List[List[int]],
        smoothing,
        max_gen_len: int,
        n_token_sample: int,
        alpha: float, # weight on bigram probs
        temp: float,
        n_drafts: int = 1, # number of beams
        verbose: bool = False,
        i_weights = None,
        i_length = None,
        ngrams = None,
        get_time: bool = False,
        penalty = 200
    ):
        """
        Run multi-sequence generation using superposed embeddings.
        Args:
            prompt_tokens (List[List[int]]): Initial tokenized prompts
            smoothing: Smoothing method passed to Superpose
            max_gen_len (int): Maximum number of tokens to generate
            n_token_sample (int): Number of tokens sampled for superposition at each step
            alpha (float): Weight on n-gram probabilities during interpolation
            temp (float): Temperature
            n_drafts (int): Number of drafts
            verbose (bool): Whether to save intermediate embeddings for analysis
            i_weights (List[float]): Ngram interpolation weights
            i_length (List[int]): Ngram models to interpolate (1 for bigram, 2 for trigram, etc.)
            ngrams (Tuple): Ngram models
            get_time (bool): Return information on time spent doing Ngram lookup
            penalty (float): Penalty on uninterpolated drafts
        Returns:
            (alive_seq, alive_ppl), (fin_seq, fin_ppl): Tuples of shape (n_prompts, n_drafts, seqlen)
            and (n_prompts, n_drafts) for sequences still generating and sequences that have finished.
        """
        # Check batch size and prompt lengths
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        prompt_len = min_prompt_len
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id

        # Initialize token tensor and pad where necessary
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)

        # If no generation is possible
        if min_prompt_len == total_len:
            raise RuntimeError("no generation possible")

        # Initialize decoding object
        initial_tokens = tokens.unsqueeze(1).repeat(1, n_drafts, 1)
        superpose = Superpose(initial_tokens,
                              tokenizer=self.tokenizer,
                              vocab_size=params.vocab_size,
                              smoothing=smoothing,
                              alpha=alpha,
                              i_weights=i_weights,
                              i_length=i_length,
                              ngrams=ngrams,
                              get_time=get_time,
                              penalty=penalty)
        unseen_first = torch.ones(bsz)
        # Superposition matrix
        token_weights = torch.zeros(bsz, self.model.vocab_size)
        if verbose:
            state_list = []
        prev_pos = 0
        # Begin inference
        for cur_pos in range(min_prompt_len, total_len):
            input_text_mask = tokens != pad_id
            # Take model step (regular embeddings on the first step, superposed afterwards)
            if cur_pos == min_prompt_len:
                token_weights = None
            logits = self.model.forward(tokens[:, prev_pos:cur_pos],
                                        start_pos=prev_pos,
                                        token_weights=token_weights,
                                        verbose=verbose)
            if verbose:
                logits, states = logits
            # Softmax
            if temp > 0:
                probs = torch.softmax(logits[:, -1] / temp, dim=-1)
            else:
                raise RuntimeError("Temperature must be greater than 0 while mixing")
            if verbose:
                states["end_probs"] = probs
                state_list.append(states)
            # Flag prompts on first generation
            is_first = torch.mul(tokens[:, cur_pos] == pad_id, unseen_first)
            unseen_first[is_first.nonzero(as_tuple=True)[0]] = 0
            # Flag prompts not yet generating
            still_prompt = input_text_mask[:, cur_pos]
            # Superposition pass
            token_weights = superpose(probs, still_prompt, is_first, cur_pos, n_token_sample)
            # Do not superpose for prompts not yet generating
            keep_idx = input_text_mask[:, cur_pos].ravel().nonzero()
            keep_token_weights = torch.zeros_like(token_weights)
            keep_token_weights[keep_idx, tokens[keep_idx, cur_pos]] = 1
            token_weights = torch.where(input_text_mask[:, cur_pos].unsqueeze(1).expand(-1, self.model.vocab_size),
                                        keep_token_weights, token_weights)
            prev_pos = cur_pos
        results = superpose.return_results(prompt_len)
        if verbose:
            torch.save(state_list, "../embeddings.pt")
        return results
    	
superposed/llama/superposed_model.py
ADDED
@@ -0,0 +1,515 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048

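# Illustrative instantiation (an addition, not part of the commit): for the
# Llama-2-7B checkpoint, params.json supplies dim=4096, n_layers=32, n_heads=32
# and the remaining fields keep their defaults, e.g.:
#
#   args = ModelArgs(max_seq_len=512, max_batch_size=8, dim=4096, n_layers=32, n_heads=32)
#   args.dim // args.n_heads   # head_dim = 128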
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

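# Illustrative shape/value check (an addition, not part of the commit):
#
#   norm = RMSNorm(dim=8)
#   x = torch.randn(2, 5, 8)
#   y = norm(x)            # shape (2, 5, 8)
#   y.pow(2).mean(-1)      # ~1.0 per position, since weight starts at ones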
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.

    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

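# Illustrative check (an addition, not part of the commit): for head_dim 128
# and a 64-position horizon, the table has one row per position and one complex
# rotation per coordinate pair.
#
#   fc = precompute_freqs_cis(dim=128, end=64)
#   fc.shape, fc.dtype                               # (torch.Size([64, 64]), torch.complex64)
#   torch.allclose(fc.abs(), torch.ones_like(fc.abs()))  # True: unit-magnitude rotations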
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

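# Illustrative shape check (an addition, not part of the commit): queries/keys
# of shape (bsz, seqlen, n_heads, head_dim) are rotated pairwise in the complex
# plane and keep their shapes. freqs_cis must be (seqlen, head_dim // 2).
#
#   xq = torch.randn(2, 16, 32, 128)
#   xk = torch.randn(2, 16, 32, 128)
#   fc = precompute_freqs_cis(128, 16)               # shape (16, 64)
#   q2, k2 = apply_rotary_emb(xq, xk, freqs_cis=fc)
#   q2.shape == xq.shape and k2.shape == xk.shape    # True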
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

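# Illustrative example (an addition, not part of the commit): with grouped-query
# attention, e.g. 8 KV heads serving 32 query heads, n_rep = 4:
#
#   kv = torch.randn(2, 16, 8, 128)    # (bs, seqlen, n_kv_heads, head_dim)
#   repeat_kv(kv, 4).shape             # torch.Size([2, 16, 32, 128])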
class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor]
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

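# Illustrative shape walk-through (an addition, not part of the commit), for
# bsz=2, seqlen=1 (one decoding step), start_pos=16, n_local_heads=32, head_dim=128:
#
#   xq:     (2, 1, 32, 128)  -> transpose -> (2, 32, 1, 128)
#   keys:   cache slice (2, 17, 32, 128) -> transpose -> (2, 32, 17, 128)
#   scores: (2, 32, 1, 17) = q @ k^T / sqrt(128), softmax over the 17 cached positions
#   output: (2, 32, 1, 128) -> reshape -> (2, 1, 4096) -> wo -> (2, 1, 4096)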
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

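# Worked example of the hidden-dim arithmetic above (an addition, not part of
# the commit): for dim=4096 the block is built with hidden_dim = 4*4096 = 16384,
# then int(2 * 16384 / 3) = 10922, rounded up to a multiple_of=256 boundary:
#
#   multiple_of = 256
#   hidden_dim = int(2 * (4 * 4096) / 3)                           # 10922
#   multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # 11008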
class MixedTransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor]
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class SuperposedTransformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.tok_mixing_embeddings = ColumnParallelLinear(
            params.vocab_size, params.dim, bias=False, init_method=lambda x: x
        )  # the dims here are a formality; the weight is overwritten below
        self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(MixedTransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

| 462 | 
            +
                @torch.inference_mode()
         | 
| 463 | 
            +
                def forward(self, 
         | 
| 464 | 
            +
                            tokens: torch.Tensor, 
         | 
| 465 | 
            +
                            start_pos: int, 
         | 
| 466 | 
            +
                            token_weights: Optional[torch.Tensor], 
         | 
| 467 | 
            +
                            verbose: Optional[bool] = False):
         | 
| 468 | 
            +
                    """
         | 
| 469 | 
            +
                    Perform a forward pass through the Transformer model.
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                    Args:
         | 
| 472 | 
            +
                        tokens (torch.Tensor): Input token indices.
         | 
| 473 | 
            +
                        start_pos (int): Starting position for attention caching.
         | 
| 474 | 
            +
                        token_weights (torch.Tensor): Superposition matrix.
         | 
| 475 | 
            +
                        verbose (bool): Whether to return intermediate hidden layer states
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    Returns:
         | 
| 478 | 
            +
                        torch.Tensor or (torch.Tensor, Dict): Output logits after applying the Transformer model.
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    """
         | 
| 481 | 
            +
                    if verbose:
         | 
| 482 | 
            +
                        states = {"layers": [], "weights": None}
         | 
| 483 | 
            +
                    _bsz, seqlen = tokens.shape
         | 
| 484 | 
            +
                    if token_weights is not None:
         | 
| 485 | 
            +
                        h = self.tok_mixing_embeddings(token_weights.half()).unsqueeze(1)
         | 
| 486 | 
            +
                    else:
         | 
| 487 | 
            +
                        h = self.tok_embeddings(tokens)   
         | 
| 488 | 
            +
                    self.freqs_cis = self.freqs_cis.to(h.device)
         | 
| 489 | 
            +
                    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
         | 
| 490 | 
            +
                    if verbose:
         | 
| 491 | 
            +
                        states["layers"].append(h)
         | 
| 492 | 
            +
                        states["weights"] = token_weights
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    mask = None
         | 
| 495 | 
            +
                    if seqlen > 1:
         | 
| 496 | 
            +
                        mask = torch.full(
         | 
| 497 | 
            +
                            (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
         | 
| 498 | 
            +
                        )
         | 
| 499 | 
            +
                        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    for layer in self.layers:
         | 
| 502 | 
            +
                            h = layer(h, start_pos, freqs_cis, mask)
         | 
| 503 | 
            +
                            if verbose:
         | 
| 504 | 
            +
                                states["layers"].append(h)
         | 
| 505 | 
            +
                            
         | 
| 506 | 
            +
                    h = self.norm(h)
         | 
| 507 | 
            +
                    if verbose:
         | 
| 508 | 
            +
                        states["layers"].append(h)
         | 
| 509 | 
            +
                    
         | 
| 510 | 
            +
                    output = self.output(h).float()
         | 
| 511 | 
            +
                    
         | 
| 512 | 
            +
                    if verbose:
         | 
| 513 | 
            +
                        return output, states
         | 
| 514 | 
            +
                    else:   
         | 
| 515 | 
            +
                        return output
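A minimal sketch (not part of the commit) of what the mixing embedding computes; the shapes and token IDs below are assumptions for illustration:

# tok_mixing_embeddings is the token embedding matrix transposed, so applying
# it to a row of token weights yields a weighted average of token embeddings.
import torch

vocab_size, dim = 32000, 4096
emb = torch.randn(vocab_size, dim)    # stand-in for tok_embeddings.weight
weights = torch.zeros(1, vocab_size)  # superposition row for one prompt
weights[0, 450] = 0.7                 # hypothetical token ID with 0.7 mass
weights[0, 278] = 0.3                 # hypothetical token ID with 0.3 mass
h = weights @ emb                     # (1, dim) superposed embedding; the
                                      # forward pass then unsqueezes a seq dim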
        superposed/llama/tokenizer.py
    ADDED
    
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()


class Tokenizer:
    """tokenizing and encoding/decoding text using SentencePiece."""
    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a SentencePiece model.

        Args:
            model_path (str): The path to the SentencePiece model file.
        """
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.

        Returns:
            List[int]: A list of token IDs.
        """
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.sp_model.decode(t)
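A short usage sketch; the model path below is an assumption:

tok = Tokenizer(model_path="weights/tokenizer.model")  # hypothetical path
ids = tok.encode("Hello world", bos=True, eos=False)   # [bos_id, ...]
text = tok.decode(ids)                                 # round-trips to "Hello world"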
        superposed/llama/utils.py
    ADDED
    
import torch

def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to a probability distribution and normalize.
    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
        temp (float): Softmax temperature
    Returns:
        Probability distribution (n_prompts, n_drafts, vocab_size)
    """
    # subtract the max for numerical stability (softmax is shift-invariant)
    log_probs = log_probs - torch.max(log_probs, dim=-1, keepdim=True)[0]
    probs = torch.softmax(log_probs / temp, dim=-1)
    return probs

def decode(tokenizer, encoding):
    """
    Decode a list of tokens to a string.
    Args:
        tokenizer (Any): Tokenizer
        encoding (torch.Tensor): Encoding
    Returns:
        decoding (str)
    """
    # truncate at the first padding token (-1), if any
    pad_locs = (encoding == -1).nonzero()
    if len(pad_locs) > 0:
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.
    Args:
        gens (n_prompts, n_drafts, seq_len): Generations to print
        logprobs (n_prompts, n_drafts): Log probs of each generation
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt
        prompt_len (int): Number of tokens in prompt
        output_file (file): Destination stream
    """
    n_prompts, n_drafts, seq_len = gens.shape
    gens = gens.reshape(-1, seq_len)
    logprobs = logprobs.flatten()
    count = 0
    for i in range(len(gens)):
        d = decode(tokenizer, gens[i])
        # first draft of this prompt
        if i % n_drafts == 0:
            count = 0
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gens[i][:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {logprobs[i]} {count}: {d}", file=output_file)
        count += 1

def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for debugging.
    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer
        output_file (file): Destination stream
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(n_drafts+2, dim=-1)
            print(f"\t\tTokens: {[tokenizer.decode([i.item()]) for i in next_token_idx]}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)
        superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc
    ADDED
    
Binary file (5.53 kB)
        superposed/ngrams/make_corpus.py
    ADDED
    
import multiprocessing
import argparse
import os
import pickle
import glob
import json
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, LlamaTokenizer
from loguru import logger


def create_corpuses(
    ckpt_path,
    start_doc,
    end_doc,
    dataset,
    tokenizer,
    train_bigram: bool,
    train_trigram: bool,
    train_fourgram: bool,
    train_fivegram: bool,
    train_sixgram: bool,
    train_sevengram: bool
):
  bigram_corpus = {}
  trigram_corpus = {}
  fourgram_corpus = {}
  fivegram_corpus = {}
  sixgram_corpus = {}
  sevengram_corpus = {}

  bigram_corpus_counts = {}
  trigram_corpus_counts = {}
  fourgram_corpus_counts = {}
  fivegram_corpus_counts = {}
  sixgram_corpus_counts = {}
  sevengram_corpus_counts = {}

  iterations = end_doc - start_doc
  for i in tqdm(range(iterations)):
    t = dataset[start_doc + i]["text"]
    encoded_text = tokenizer.encode(t)
    for start_idx in range(1, len(encoded_text)):  # count from first real token to eos
      pOne = encoded_text[start_idx-1] if start_idx >= 1 else None
      pTwo = encoded_text[start_idx-2] if start_idx >= 2 else None
      pThree = encoded_text[start_idx-3] if start_idx >= 3 else None
      pFour = encoded_text[start_idx-4] if start_idx >= 4 else None
      pFive = encoded_text[start_idx-5] if start_idx >= 5 else None
      pSix = encoded_text[start_idx-6] if start_idx >= 6 else None

      token = encoded_text[start_idx]
      # bigram
      if train_bigram and start_idx >= 1:
        prior = pOne
        if prior not in bigram_corpus:
          bigram_corpus[prior] = {}
          bigram_corpus_counts[prior] = 0
        bigram_corpus[prior][token] = bigram_corpus[prior].get(token, 0) + 1
        bigram_corpus_counts[prior] += 1
      # trigram
      if train_trigram and start_idx >= 2:
        prior = (pTwo, pOne)
        if prior not in trigram_corpus:
          trigram_corpus[prior] = {}
          trigram_corpus_counts[prior] = 0
        trigram_corpus[prior][token] = trigram_corpus[prior].get(token, 0) + 1
        trigram_corpus_counts[prior] += 1
      # fourgram
      if train_fourgram and start_idx >= 3:
        prior = (pThree, pTwo, pOne)
        if prior not in fourgram_corpus:
          fourgram_corpus[prior] = {}
          fourgram_corpus_counts[prior] = 0
        fourgram_corpus[prior][token] = fourgram_corpus[prior].get(token, 0) + 1
        fourgram_corpus_counts[prior] += 1
      # fivegram
      if train_fivegram and start_idx >= 4:
        prior = (pFour, pThree, pTwo, pOne)
        if prior not in fivegram_corpus:
          fivegram_corpus[prior] = {}
          fivegram_corpus_counts[prior] = 0
        fivegram_corpus[prior][token] = fivegram_corpus[prior].get(token, 0) + 1
        fivegram_corpus_counts[prior] += 1
      # sixgram
      if train_sixgram and start_idx >= 5:
        prior = (pFive, pFour, pThree, pTwo, pOne)
        if prior not in sixgram_corpus:
          sixgram_corpus[prior] = {}
          sixgram_corpus_counts[prior] = 0
        sixgram_corpus[prior][token] = sixgram_corpus[prior].get(token, 0) + 1
        sixgram_corpus_counts[prior] += 1
      # sevengram
      if train_sevengram and start_idx >= 6:
        prior = (pSix, pFive, pFour, pThree, pTwo, pOne)
        if prior not in sevengram_corpus:
          sevengram_corpus[prior] = {}
          sevengram_corpus_counts[prior] = 0
        sevengram_corpus[prior][token] = sevengram_corpus[prior].get(token, 0) + 1
        sevengram_corpus_counts[prior] += 1
  save_corpus(ckpt_path, bigram_corpus, trigram_corpus, fourgram_corpus, fivegram_corpus, sixgram_corpus, sevengram_corpus, start_doc, end_doc)
  save_counts(ckpt_path, bigram_corpus_counts, trigram_corpus_counts, fourgram_corpus_counts, fivegram_corpus_counts, sixgram_corpus_counts, sevengram_corpus_counts, start_doc, end_doc)

def merge_corpus_helper(c1, c2):
  """
  Merge the corpuses c1 and c2, returning the merged result.
  """
  for prior in c2:
    # if share prior
    if prior in c1:
      c1_prior = c1[prior]
      c2_prior = c2[prior]
      for token in c2_prior:
        # if share token
        if token in c1_prior:
          c1_prior[token] += c2_prior[token]
        # else just use c2's
        else:
          c1_prior[token] = c2_prior[token]
    else:
      # else just use c2's
      c1[prior] = c2[prior]
  return c1

def merge_counts_helper(c1, c2):
  """
  Merge the count corpuses c1 and c2, returning the merged result.
  """
  for prior in c2:
    if prior in c1:
      c1[prior] += c2[prior]
    else:
      c1[prior] = c2[prior]
  return c1

def save_corpus(save_dir, b_d, t_d, fo_d, fi_d, si_d, se_d, start_doc, end_doc):
  """
  Save corpuses b_d (bigram) to se_d (sevengram), where the corpus contains mappings
  {prefix : {next_token1: ct, next_token2: ct, ...}}.
  """
  prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
  for p, corpus in zip(prefixes, [b_d, t_d, fo_d, fi_d, si_d, se_d]):
    with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
      pickle.dump(corpus, f)

def save_counts(save_dir, b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct, start_doc, end_doc):
  """
  Save count corpuses b_ct (bigram) to se_ct (sevengram), where each count
  corpus contains mappings {prefix : total}.
  """
  prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
  for p, corpus in zip(prefixes, [b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct]):
    with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
      pickle.dump(corpus, f)

def merge_corpuses(ckpt_path):
  """
  Helper to merge corpuses in `ckpt_path`, where `ckpt_path` might contain
  multiple bigram, trigram, etc. corpuses from each process.
  """
  prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
  for prefix in prefixes:
    if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
      os.remove(f"{ckpt_path}/{prefix}_final.pkl")
    corpus = None
    for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
      with open(filepath, "rb") as f:
        current = pickle.load(f)
        if corpus is None:
          corpus = current
        else:
          corpus = merge_corpus_helper(corpus, current)
      os.remove(filepath)
    with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
      pickle.dump(corpus, f)

def merge_counts(ckpt_path):
  """
  Helper to merge count corpuses in `ckpt_path`, where `ckpt_path` might contain
  multiple bigram, trigram, etc. count corpuses from each process.
  """
  prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
  for prefix in prefixes:
    if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
      os.remove(f"{ckpt_path}/{prefix}_final.pkl")

    counts = None
    for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
      with open(filepath, "rb") as f:
        current = pickle.load(f)
        if counts is None:
          counts = current
        else:
          counts = merge_counts_helper(counts, current)
      os.remove(filepath)
    with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
      pickle.dump(counts, f)


if __name__ == "__main__":
  # Input arguments
  parser = argparse.ArgumentParser()
  parser.add_argument("ckpt_path", type=str, help="Path to store ngram models")
  parser.add_argument("start_doc", type=str, help="# of first document")
  parser.add_argument("end_doc", type=str, help="# of last document")
  parser.add_argument("c", type=int, help="number of processes")
  parser.add_argument("--tok_name", type=str, help="name of HF tokenizer, or llama", default="llama")
  for arg_name in ["--bigram", "--trigram", "--fourgram", "--fivegram", "--sixgram", "--sevengram"]:
    parser.add_argument(arg_name, type=str, help=f"Whether to make a {arg_name.lstrip('-')} model")
  parser.add_argument("--dset_name", type=str, help="name of HF dataset")
  parser.add_argument("--dset_path", type=str, help="path to dataset")
  # Parse arguments
  args = parser.parse_args()
  start_doc_ovr = int(args.start_doc)
  end_doc_ovr = int(args.end_doc)
  n_cores = args.c
  tok_name = args.tok_name
  ckpt_path = args.ckpt_path
  dset_name = args.dset_name
  dset_path = args.dset_path
  if not dset_name and not dset_path:
    raise RuntimeError("Please provide a dataset")
  if not os.path.exists(ckpt_path):
    os.makedirs(ckpt_path)
  logger.info(f"{start_doc_ovr} {end_doc_ovr} {n_cores}")

  # Load dataset and tokenizer
  if dset_name:
    ds = load_dataset(dset_name, cache_dir="../../../datasets/")["train"].shuffle(seed=42)
  else:
    with open(dset_path, "r") as f:
      ds = json.load(f)["train"]
  if tok_name == "llama":
    # REPLACE WITH YOUR OWN PATH
    tokenizer = LlamaTokenizer.from_pretrained("../../7B_HF", add_bos_token=False)
  else:
    tokenizer = AutoTokenizer.from_pretrained(tok_name)

  # Start running
  num_processes = n_cores
  total_docs = end_doc_ovr - start_doc_ovr
  docs_per_c = total_docs // num_processes
  processes = []
  for core in range(n_cores):
    start_doc = core * docs_per_c  # relative start doc
    end_doc = (core + 1) * docs_per_c if core < n_cores - 1 else total_docs  # relative end doc
    logger.info(f"Starting core {core} from document {start_doc} to {end_doc}")
    process = multiprocessing.Process(target=create_corpuses,
                                      args=(ckpt_path,
                                            start_doc_ovr + start_doc,
                                            start_doc_ovr + end_doc,
                                            ds, tokenizer,
                                            args.bigram,
                                            args.trigram,
                                            args.fourgram,
                                            args.fivegram,
                                            args.sixgram,
                                            args.sevengram))
    processes.append(process)
    process.start()
  for process in processes:
    process.join()
  logger.info("Finished Saving")
  logger.info("Merging...")
  merge_corpuses(ckpt_path)
  merge_counts(ckpt_path)
  logger.info("Merged.")
        superposed/ngrams/ngram_models.py
    ADDED
    
import pickle
import sys

import torch

class NGram():
  def __init__(self, corpus, corpus_counts, type):
    self.corpus = corpus
    self.counts = corpus_counts
    self.type = type

  def prob(self, key, next):
    """
    Args:
      key (tuple): tuple of token IDs forming the prior
      next (int): next token ID
    Returns:
      Probability of `next` given `key`, or -1 if the prior is unseen
    """
    l = len(key)
    # each model type expects a prior of a fixed length
    expected = {"bigram": 1, "trigram": 2, "fourgram": 3,
                "fivegram": 4, "sixgram": 5, "sevengram": 6}
    assert l == expected[self.type]
    if self.type == "bigram":
      key = key[0]

    if key in self.corpus:
      count = self.corpus[key].get(next, 0)
      total = sum(self.corpus[key].values())
      return count / total
    else:
      return -1

  def ntd(self, key, vocab_size=32000):
    """
    Args:
      key (tuple): tuple of token IDs forming the prior
    Returns:
      prob_tensor (torch.Tensor): (vocab_size, ) of full next token probabilities
    """
    if key in self.corpus:
      prob_tensor = torch.zeros(vocab_size)
      total = sum(self.corpus[key].values())
      for next_token in self.corpus[key]:
        prob_tensor[next_token] = self.corpus[key][next_token] / total
      return prob_tensor
    else:
      return None

def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
  """
  Loads and returns a list corresponding to bigram through sevengram models,
  containing the models whose parameters are `True`. See below for expected corpus names.
  Args:
    ckpt_path (str): Location of ngram models
    bigram-sevengram: Which models to load
  Returns:
    List of n-gram models
  """
  models = []
  if bigram:
    print("Making bigram...")
    with open(f"{ckpt_path}/b_d_final.pkl", "rb") as f:
      bigram = pickle.load(f)
    bigram_model = NGram(bigram, None, "bigram")
    models.append(bigram_model)
    print(sys.getsizeof(bigram))

  if trigram:
    print("Making trigram...")
    with open(f"{ckpt_path}/t_d_final.pkl", "rb") as f:
      trigram = pickle.load(f)
    trigram_model = NGram(trigram, None, "trigram")
    models.append(trigram_model)
    print(sys.getsizeof(trigram))

  if fourgram:
    print("Making fourgram...")
    with open(f"{ckpt_path}/fo_d_final.pkl", "rb") as f:
      fourgram = pickle.load(f)
    fourgram_model = NGram(fourgram, None, "fourgram")
    models.append(fourgram_model)
    print(sys.getsizeof(fourgram))

  if fivegram:
    print("Making fivegram...")
    with open(f"{ckpt_path}/fi_d_final.pkl", "rb") as f:
      fivegram = pickle.load(f)
    fivegram_model = NGram(fivegram, None, "fivegram")
    models.append(fivegram_model)
    print(sys.getsizeof(fivegram))

  if sixgram:
    print("Making sixgram...")
    with open(f"{ckpt_path}/si_d_final.pkl", "rb") as f:
      sixgram = pickle.load(f)
    sixgram_model = NGram(sixgram, None, "sixgram")
    models.append(sixgram_model)
    print(sys.getsizeof(sixgram))

  if sevengram:
    print("Making sevengram...")
    with open(f"{ckpt_path}/se_d_final.pkl", "rb") as f:
      sevengram = pickle.load(f)
    sevengram_model = NGram(sevengram, None, "sevengram")
    models.append(sevengram_model)

  return models
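A usage sketch (the checkpoint path and token IDs are assumptions):

models = make_models("ckpts", bigram=True, trigram=True, fourgram=False,
                     fivegram=False, sixgram=False, sevengram=False)
bigram_model, trigram_model = models
p = trigram_model.prob((450, 278), 263)  # P(263 | 450, 278), or -1 if the prior is unseen
dist = trigram_model.ntd((450, 278))     # (32000,) next-token distribution, or None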
        superposed/ngrams/test.json
    ADDED
    
{
    "train": [
        {"text": "Hi my name is"},
        {"text": "This is a story of"},
        {"text": "In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying"},
        {"text": "There is one class of AutoModel for each task, and for each backend (PyTorch, TensorFlow, or Flax)."}
    ]
}
        superposed/notebooks/custom.ipynb
    ADDED
    
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "119805f4-8589-4379-ad87-a7bad4c0e658",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
      "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
      "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
      "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
      "2024-05-30 03:09:58.230601: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-05-30 03:09:58.280835: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-05-30 03:10:03.250651: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "\n",
    "import evaluate\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from eval import *\n",
    "from superposed.llama.metrics import *\n",
    "from superposed.llama.generation import Llama\n",
    "from superposed.llama.superposed_generation import SuperposedLlama\n",
    "from superposed.llama.tokenizer import Tokenizer\n",
    "from superposed.ngrams.ngram_models import make_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "51c15900-c8b8-46d9-a884-6842a391ef48",
   "metadata": {},
   "outputs": [],
   "source": [
    "sup_device = torch.device(\"cuda:0\")\n",
    "tokenizer = Tokenizer('../../7B/tokenizer.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9817d9a4-ad64-41c6-b87b-b1e422b836a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
     ]
    }
   ],
   "source": [
    "# Params\n",
    "param_file = \"../../params/p15_d3_mixed.json\"\n",
    "with open(param_file, \"r\") as f:\n",
    "    params = json.load(f)\n",
    "    print(f\"Parameters: {params}\")\n",
    "alpha = params[\"alpha\"]\n",
    "temp = params[\"temp\"]\n",
    "n_drafts = params[\"n_drafts\"]\n",
    "prompt_len = params[\"prompt_len\"]\n",
    "n_token_sample = params[\"n_token_sample\"]\n",
    "i_weights = params[\"i_weights\"]\n",
    "i_length = params[\"i_length\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9c99098e-a38b-4c78-a0e9-8c80309830bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making bigram...\n",
      "1310800\n",
      "Making trigram...\n",
      "671088728\n",
      "Making fourgram...\n",
      "2684354648\n",
      "Making fivegram...\n",
      "5368709200\n",
      "Making sixgram...\n",
      "5368709200\n"
     ]
    }
   ],
   "source": [
    "# Create ngram models\n",
    "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c3331332-242c-4e98-9f11-58c6dc0ef581",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> initializing model parallel with size 1\n",
      "> initializing ddp with size 1\n",
      "> initializing pipeline with size 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 25.15 seconds\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "weight_path = \"../../7B/\"\n",
    "model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
    "                         tokenizer_path=f'{weight_path}/tokenizer.model', \n",
    "                         max_seq_len=100, \n",
    "                         max_batch_size=32,\n",
    "                         device=sup_device,\n",
    "                         model_parallel_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2b48c23-d6a3-43b1-ad4c-54524aacfda6",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5093373b-bf76-47e3-8f99-1045b60f29c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(tokenizer, encoding):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        tokenizer (Any): Tokenizer\n",
    "        encoding (torch.Tensor): Encoding\n",
    "    Returns:\n",
    "        decoding (str)\n",
    "    \"\"\"\n",
    "    eos_locs = (encoding == tokenizer.eos_id).nonzero()\n",
    "    if len(eos_locs) > 0:\n",
    "        encoding = encoding[:eos_locs[0]]\n",
    "    return tokenizer.decode(encoding.to(torch.int32).tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "18703b19-f3e9-46e4-ab1c-c6d3b403c6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    \"Hi my name is\",\n",
    "    \"The Seattle Seahawks were Super Bowl\",\n",
    "    \"Penguins are birds native to\"\n",
    "]\n",
    "tokenized_prompts = tokenizer.encode(prompts, True, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d39cd735-9480-4979-ac92-bbd470f75570",
   "metadata": {},
   "outputs": [],
   "source": [
    "alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts, \n",
    "                                        smoothing=\"geom\",\n",
    "                                        max_gen_len=10, \n",
    "                                        n_token_sample=n_token_sample,\n",
    "                                        alpha=alpha, \n",
    "                                        temp=temp,\n",
    "                                        n_drafts=n_drafts,\n",
    "                                        i_weights=i_weights,\n",
    "                                        i_length=i_length,\n",
    "                                        ngrams=ngrams,\n",
    "                                        get_time=False,\n",
    "                                        penalty=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "cfefa793-e49e-483a-a504-5cc9e23f619d",
   "metadata": {},
   "outputs": [],
   "source": [
    "gens = alive_gens[0].reshape(len(prompts) * n_drafts, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5abf87ab-2ee0-4204-868b-1215abf0c8aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hi\n",
      "my name\n",
      "is L\n",
      "inda,\n",
      "I am\n",
      "a \n",
      "40\n",
      "year old\n",
      "woman who\n"
     ]
    }
   ],
   "source": [
    "for i in gens:\n",
    "    print(decode(tokenizer, i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73dc3cc-baa5-468d-bdd1-827465bdeb62",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
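
Taken together, the cells above reduce to the following standalone script. This is a minimal sketch under the notebook's own assumptions (Llama weights in ../../7B/, n-gram checkpoints in ../../ckpts-200k, and the p15_d3_mixed.json parameter file); sup_generate returns drafts shaped (n_prompts, n_drafts, seq_len), so the reshape yields one row per draft, which is why the printed continuations interleave n_drafts variants per prompt.

# Minimal sketch of the custom.ipynb flow as a script.
# Paths below are the notebook's assumptions, not fixed by the repo.
import json

import torch

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models

with open("../../params/p15_d3_mixed.json") as f:  # assumed param file
    params = json.load(f)

ngrams = make_models("../../ckpts-200k", bigram=True, trigram=True,
                     fourgram=True, fivegram=True, sixgram=True, sevengram=False)
tokenizer = Tokenizer("../../7B/tokenizer.model")
model = SuperposedLlama.build(ckpt_dir="../../7B/",
                              tokenizer_path="../../7B/tokenizer.model",
                              max_seq_len=100,
                              max_batch_size=32,
                              device=torch.device("cuda:0"),
                              model_parallel_size=1)

prompts = ["Hi my name is", "Penguins are birds native to"]
alive_gens, _ = model.sup_generate(prompt_tokens=tokenizer.encode(prompts, True, False),
                                   smoothing="geom",
                                   max_gen_len=10,
                                   n_token_sample=params["n_token_sample"],
                                   alpha=params["alpha"],
                                   temp=params["temp"],
                                   n_drafts=params["n_drafts"],
                                   i_weights=params["i_weights"],
                                   i_length=params["i_length"],
                                   ngrams=ngrams,
                                   get_time=False,
                                   penalty=200)

# One row per draft; truncate each draft at the first EOS before decoding,
# mirroring the notebook's decode() helper.
gens = alive_gens[0].reshape(len(prompts) * params["n_drafts"], -1)
for row in gens:
    eos_locs = (row == tokenizer.eos_id).nonzero()
    if len(eos_locs) > 0:
        row = row[:eos_locs[0]]
    print(tokenizer.decode(row.to(torch.int32).tolist()))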
    	
superposed/notebooks/nq.ipynb ADDED
@@ -0,0 +1,417 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "from datetime import datetime\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "from eval import *\n",
    "from superposed.llama.metrics import *\n",
    "from superposed.llama.generation import Llama\n",
    "from superposed.llama.superposed_generation import SuperposedLlama\n",
    "from superposed.llama.tokenizer import Tokenizer\n",
    "from superposed.ngrams.ngram_models import make_models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "nq = load_dataset(\"nq_open\")[\"validation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
     ]
    }
   ],
   "source": [
    "# Params\n",
    "param_file = \"../../params/p15_d3_mixed.json\"\n",
    "with open(param_file, \"r\") as f:\n",
    "    params = json.load(f)\n",
    "    print(f\"Parameters: {params}\")\n",
    "alpha = params[\"alpha\"]\n",
    "temp = params[\"temp\"]\n",
    "n_drafts = params[\"n_drafts\"]\n",
    "prompt_len = params[\"prompt_len\"]\n",
    "n_token_sample = params[\"n_token_sample\"]\n",
    "i_weights = params[\"i_weights\"]\n",
    "i_length = params[\"i_length\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making bigram...\n",
      "1310800\n",
      "Making trigram...\n",
      "671088728\n",
      "Making fourgram...\n",
      "2684354648\n",
      "Making fivegram...\n",
      "5368709200\n",
      "Making sixgram...\n",
      "5368709200\n"
     ]
    }
   ],
   "source": [
    "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sup_device = torch.device(\"cuda:0\")\n",
    "reg_device = torch.device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> initializing model parallel with size 1\n",
      "> initializing ddp with size 1\n",
      "> initializing pipeline with size 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 33.68 seconds\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# load superposed\n",
    "weight_path = \"../../7B/\"\n",
    "sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
    "                                 tokenizer_path=f'{weight_path}/tokenizer.model', \n",
    "                                 max_seq_len=1000, \n",
    "                                 max_batch_size=16,\n",
    "                                 device=sup_device,\n",
    "                                 model_parallel_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Loaded in 22.47 seconds\n"
     ]
    }
   ],
   "source": [
    "# load regular\n",
    "reg_model = Llama.build(ckpt_dir=weight_path, \n",
    "                    tokenizer_path=f'{weight_path}/tokenizer.model', \n",
    "                    max_seq_len=1000, \n",
    "                    max_batch_size=16,\n",
    "                    device=reg_device, # reg_device,\n",
    "                    model_parallel_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_types = [\"greedy\", \"superposed\", \"regular\"]\n",
    "model_type = model_types[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_nq(model_type, question, max_gen_len):\n",
    "    question = \"Answer these questions:\\n\\nQ: \" + question + \"?\\nA:\"\n",
    "    text_len = len(question) # for truncating\n",
    "    prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
    "    if model_type == \"regular\" or model_type == \"greedy\":\n",
    "        if model_type == \"regular\":\n",
    "            input = [question for _ in range(n_drafts)]\n",
    "            print(input)\n",
    "            sequences, _ = evaluate_nucleus_losses(data=input,\n",
    "                                                   model=reg_model,\n",
    "                                                   tokenizer=tokenizer,\n",
    "                                                   prompt_len=prompt_len,\n",
    "                                                   max_gen_len=max_gen_len,\n",
    "                                                   temp=0.6,\n",
    "                                                   bsz=8,\n",
    "                                                   marker=False)\n",
    "        else:\n",
    "            sequences, _ = evaluate_nucleus_losses(data=[question],\n",
    "                                       model=reg_model,\n",
    "                                       tokenizer=tokenizer,\n",
    "                                       prompt_len=prompt_len,\n",
    "                                       max_gen_len=max_gen_len,\n",
    "                                       temp=0,\n",
    "                                       bsz=8,\n",
    "                                       marker=False)\n",
    "        n_pd, seq_len = sequences.shape\n",
    "    elif model_type == \"superposed\":\n",
    "        sequences, _ = evaluate_mixed_losses(data=[question],\n",
    "                                                   model=sup_model,\n",
    "                                                   tokenizer=tokenizer,\n",
    "                                                   prompt_len=prompt_len,\n",
    "                                                   max_gen_len=max_gen_len,\n",
    "                                                   alpha=alpha,\n",
    "                                                   temp=temp,\n",
    "                                                   n_drafts=n_drafts,\n",
    "                                                   n_token_sample=n_token_sample,\n",
    "                                                   smoothing=None, # Use greedy\n",
    "                                                   bsz=8,\n",
    "                                                   i_weights=i_weights,\n",
    "                                                   i_length=i_length,\n",
    "                                                   ngrams=ngrams,\n",
    "                                                   marker=False)\n",
    "        n_p, n_d, seq_len = sequences.shape\n",
    "    # Process results\n",
    "    sequences = sequences.reshape(-1, seq_len).tolist()\n",
    "    for d_idx in range(len(sequences)):\n",
    "        draft = sequences[d_idx]\n",
    "        if -1 in draft:\n",
    "            draft = draft[:draft.index(-1)]\n",
    "        sequences[d_idx] = draft\n",
    "    decoded_seq = tokenizer.decode(sequences)\n",
    "    answers = []\n",
    "    for s in decoded_seq:\n",
    "        answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
    "    return answers\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluation\n",
    "predictions = []\n",
    "print(f\"Precision from 1 to {n_drafts}\")\n",
    "for sample in tqdm(nq):\n",
    "    # Adaptively determine max generation length\n",
    "    longest = 0\n",
    "    shortest = 1000\n",
    "    for answer in sample[\"answer\"]:\n",
    "        tmp = tokenizer.encode([answer], False, False)[0]\n",
    "        if len(tmp) > longest:\n",
    "            longest = len(tmp)\n",
    "        if len(tmp) < shortest:\n",
    "            shortest = len(tmp)\n",
    "    question = sample[\"question\"]\n",
    "    answer = evaluate_nq(model_type, question, max_gen_len=shortest+3)\n",
    "    predictions.append({\"question\": question, \"answer\": answer})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate results into precisions\n",
    "precisions = {}\n",
    "for i in range(1, n_drafts+1):\n",
    "    prec = str(i)\n",
    "    responses = []\n",
    "    for result in predictions:\n",
    "        responses.append({\"question\": result[\"question\"], \"answer\": result[\"answer\"][:i]})\n",
    "    precisions[prec] = responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'question': 'when was the last time anyone was on the moon', 'answer': ['2019', '2019', '2019-', '2019-', '1019']}\n",
      "================\n",
      "{'question': \"who wrote he ain't heavy he's my brother lyrics\", 'answer': ['The song was written by', 'The lyr was written by', 'The Hol was written by', 'Neil song was written by', 'Neil lyr was written by']}\n",
      "================\n",
      "{'question': 'how many seasons of the bastard executioner are there', 'answer': ['1', 'There1', 'there1', '1', 'There1']}\n",
      "================\n",
      "{'question': 'when did the eagles win last super bowl', 'answer': ['2018', 'The2018', '1018', '2017', 'the2018']}\n",
      "================\n",
      "{'question': \"who won last year's ncaa women's basketball\", 'answer': ['the university of connecticut', 'The university of connecticut', 'university of connecticut', 'the University of connecticut', 'The University of connecticut']}\n",
      "================\n",
      "{'question': 'when did the isle of wight become an island', 'answer': ['1207', 'when1207', '1287', '1277', 'when1287']}\n",
      "================\n",
      "{'question': 'love yourself by justin bieber is about who', 'answer': ['love yourself by justin b', 'love yourself is justin b', 'Justin yourself by justin b', 'Justin yourself is justin b', 'It yourself by justin b']}\n",
      "================\n",
      "{'question': 'who was the ruler of england in 1616', 'answer': ['James I', 'James I of', 'King I', 'j I', 'James I']}\n",
      "================\n",
      "{'question': 'what is the hot coffee mod in san andreas', 'answer': ['The Hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a mod for Grand', 'The hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a modification that Grand', 'It Hot Coffee mod is a modification for Grand']}\n",
      "================\n",
      "{'question': 'what is the maximum data rate for the 802.11a standard select one', 'answer': ['54 Mbps', '54Mbps', '54 mbps', '54 Mbps', '54 Mbps']}\n",
      "================\n"
     ]
    }
   ],
   "source": [
    "# Print some results\n",
    "counter = 0\n",
    "for k in predictions:\n",
    "    if counter >= 10:\n",
    "        break\n",
    "    print(k)\n",
    "    counter += 1\n",
    "    print(\"================\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['1', '2', '3', '4', '5'])\n"
     ]
    }
   ],
   "source": [
    "# Save results\n",
    "os.makedirs(\"../../nq/\", exist_ok=True)\n",
    "print(precisions.keys())\n",
    "for prec in range(1, n_drafts+1):\n",
    "    out_path = f\"../../nq/eval_{model_type}_{prec}_test.jsonl\"\n",
| 390 | 
            +
                "    with open(out_path, \"w\") as f:\n",
         | 
| 391 | 
            +
                "        for obj in precisions[str(prec)]:    \n",
         | 
| 392 | 
            +
                "            f.write(json.dumps(obj) + \"\\n\")"
         | 
| 393 | 
            +
               ]
         | 
| 394 | 
            +
              }
         | 
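For later analysis, a minimal read-back sketch for one of the JSONL files written above (assuming the same `model_type` value used when saving and an unchanged working directory; this helper cell is illustrative, not part of the committed notebook):

    # Read back the precision-1 results saved above (illustrative check)
    import json

    model_type = "superposed"  # must match the value used when writing
    with open(f"../../nq/eval_{model_type}_1_test.jsonl") as f:
        rows = [json.loads(line) for line in f]
    print(len(rows))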
| 395 | 
            +
             ],
         | 
| 396 | 
            +
             "metadata": {
         | 
| 397 | 
            +
              "kernelspec": {
         | 
| 398 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 399 | 
            +
               "language": "python",
         | 
| 400 | 
            +
               "name": "python3"
         | 
| 401 | 
            +
              },
         | 
| 402 | 
            +
              "language_info": {
         | 
| 403 | 
            +
               "codemirror_mode": {
         | 
| 404 | 
            +
                "name": "ipython",
         | 
| 405 | 
            +
                "version": 3
         | 
| 406 | 
            +
               },
         | 
| 407 | 
            +
               "file_extension": ".py",
         | 
| 408 | 
            +
               "mimetype": "text/x-python",
         | 
| 409 | 
            +
               "name": "python",
         | 
| 410 | 
            +
               "nbconvert_exporter": "python",
         | 
| 411 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 412 | 
            +
               "version": "3.11.5"
         | 
| 413 | 
            +
              }
         | 
| 414 | 
            +
             },
         | 
| 415 | 
            +
             "nbformat": 4,
         | 
| 416 | 
            +
             "nbformat_minor": 4
         | 
| 417 | 
            +
            }
         | 
    	
        superposed/notebooks/triviaqa.ipynb
    ADDED
    
    | @@ -0,0 +1,404 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": 1,
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "outputs": [
         | 
| 8 | 
            +
                {
         | 
| 9 | 
            +
                 "name": "stderr",
         | 
| 10 | 
            +
                 "output_type": "stream",
         | 
| 11 | 
            +
                 "text": [
         | 
| 12 | 
            +
                  "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
         | 
| 13 | 
            +
                  "  from .autonotebook import tqdm as notebook_tqdm\n",
         | 
| 14 | 
            +
                  "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
         | 
| 15 | 
            +
                  "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
         | 
| 16 | 
            +
                  "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
         | 
| 17 | 
            +
                  "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
         | 
| 18 | 
            +
                  "2024-05-30 01:35:17.813978: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
         | 
| 19 | 
            +
                  "2024-05-30 01:35:20.452213: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
         | 
| 20 | 
            +
                  "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
         | 
| 21 | 
            +
                  "2024-05-30 01:35:41.833487: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
         | 
| 22 | 
            +
                 ]
         | 
| 23 | 
            +
                }
         | 
| 24 | 
            +
               ],
         | 
| 25 | 
            +
               "source": [
         | 
| 26 | 
            +
                "%load_ext autoreload\n",
         | 
| 27 | 
            +
                "%autoreload 2\n",
         | 
| 28 | 
            +
                "\n",
         | 
| 29 | 
            +
                "import copy\n",
         | 
| 30 | 
            +
                "import json\n",
         | 
| 31 | 
            +
                "import pickle\n",
         | 
| 32 | 
            +
                "import os\n",
         | 
| 33 | 
            +
                "import random\n",
         | 
| 34 | 
            +
                "import re\n",
         | 
| 35 | 
            +
                "import string\n",
         | 
| 36 | 
            +
                "import math\n",
         | 
| 37 | 
            +
                "from datetime import datetime\n",
         | 
| 38 | 
            +
                "\n",
         | 
| 39 | 
            +
                "import evaluate\n",
         | 
| 40 | 
            +
                "import torch\n",
         | 
| 41 | 
            +
                "import numpy as np\n",
         | 
| 42 | 
            +
                "from datasets import load_dataset\n",
         | 
| 43 | 
            +
                "from transformers import LlamaTokenizer\n",
         | 
| 44 | 
            +
                "from tqdm import tqdm\n",
         | 
| 45 | 
            +
                "\n",
         | 
| 46 | 
            +
                "from eval import *\n",
         | 
| 47 | 
            +
                "from superposed.llama.metrics import *\n",
         | 
| 48 | 
            +
                "from superposed.llama.generation import Llama\n",
         | 
| 49 | 
            +
                "from superposed.llama.superposed_generation import SuperposedLlama\n",
         | 
| 50 | 
            +
                "from superposed.llama.tokenizer import Tokenizer\n",
         | 
| 51 | 
            +
                "from superposed.ngrams.ngram_models import make_models"
         | 
| 52 | 
            +
               ]
         | 
| 53 | 
            +
              },
         | 
| 54 | 
            +
              {
         | 
| 55 | 
            +
               "cell_type": "markdown",
         | 
| 56 | 
            +
               "metadata": {},
         | 
| 57 | 
            +
               "source": [
         | 
| 58 | 
            +
                "# Setup"
         | 
| 59 | 
            +
               ]
         | 
| 60 | 
            +
              },
         | 
| 61 | 
            +
              {
         | 
| 62 | 
            +
               "cell_type": "code",
         | 
| 63 | 
            +
               "execution_count": 3,
         | 
| 64 | 
            +
               "metadata": {},
         | 
| 65 | 
            +
               "outputs": [
         | 
| 66 | 
            +
                {
         | 
| 67 | 
            +
                 "name": "stdout",
         | 
| 68 | 
            +
                 "output_type": "stream",
         | 
| 69 | 
            +
                 "text": [
         | 
| 70 | 
            +
                  "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
         | 
| 71 | 
            +
                 ]
         | 
| 72 | 
            +
                }
         | 
| 73 | 
            +
               ],
         | 
| 74 | 
            +
               "source": [
         | 
| 75 | 
            +
                "# Params\n",
         | 
| 76 | 
            +
                "param_file = \"../../params/p15_d3_mixed.json\"\n",
         | 
| 77 | 
            +
                "with open(param_file, \"r\") as f:\n",
         | 
| 78 | 
            +
                "    params = json.load(f)\n",
         | 
| 79 | 
            +
                "    print(f\"Parameters: {params}\")\n",
         | 
| 80 | 
            +
                "alpha = params[\"alpha\"]\n",
         | 
| 81 | 
            +
                "temp = params[\"temp\"]\n",
         | 
| 82 | 
            +
                "n_drafts = params[\"n_drafts\"]\n",
         | 
| 83 | 
            +
                "prompt_len = params[\"prompt_len\"]\n",
         | 
| 84 | 
            +
                "n_token_sample = params[\"n_token_sample\"]\n",
         | 
| 85 | 
            +
                "i_weights = params[\"i_weights\"]\n",
         | 
| 86 | 
            +
                "i_length = params[\"i_length\"]"
         | 
| 87 | 
            +
               ]
         | 
| 88 | 
            +
              },
         | 
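For reference, the printed output above implies that `params/p15_d3_mixed.json` has the following shape (reproduced here as a Python literal; only `alpha`, `temp`, `n_drafts`, `prompt_len`, `n_token_sample`, `i_weights`, and `i_length` are read by this notebook):

    # Shape of params/p15_d3_mixed.json, as printed by the cell above
    params = {
        "alpha": 0.54,
        "temp": 0.06,
        "n_drafts": 3,
        "prompt_len": 15,
        "n_token_sample": 9,
        "n_token_consider": 32000,
        "mixing_method": "sample_new_weights_with_score",
        "smoothing": "geom",
        "sample_tokens": 0,
        "sample_beams": 0,
        "i_weights": [0.01, 0.04, 0.15, 0.18, 0.12],
        "i_length": [1, 2, 3, 4, 5],
    }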
| 89 | 
            +
              {
         | 
| 90 | 
            +
               "cell_type": "code",
         | 
| 91 | 
            +
               "execution_count": 5,
         | 
| 92 | 
            +
               "metadata": {
         | 
| 93 | 
            +
                "scrolled": true
         | 
| 94 | 
            +
               },
         | 
| 95 | 
            +
               "outputs": [
         | 
| 96 | 
            +
                {
         | 
| 97 | 
            +
                 "name": "stdout",
         | 
| 98 | 
            +
                 "output_type": "stream",
         | 
| 99 | 
            +
                 "text": [
         | 
| 100 | 
            +
                  "Making bigram...\n",
         | 
| 101 | 
            +
                  "1310800\n",
         | 
| 102 | 
            +
                  "Making trigram...\n",
         | 
| 103 | 
            +
                  "671088728\n",
         | 
| 104 | 
            +
                  "Making fourgram...\n",
         | 
| 105 | 
            +
                  "2684354648\n",
         | 
| 106 | 
            +
                  "Making fivegram...\n",
         | 
| 107 | 
            +
                  "5368709200\n",
         | 
| 108 | 
            +
                  "Making sixgram...\n",
         | 
| 109 | 
            +
                  "5368709200\n"
         | 
| 110 | 
            +
                 ]
         | 
| 111 | 
            +
                }
         | 
| 112 | 
            +
               ],
         | 
| 113 | 
            +
               "source": [
         | 
| 114 | 
            +
                "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
         | 
| 115 | 
            +
               ]
         | 
| 116 | 
            +
              },
         | 
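The `i_weights`/`i_length` pairs from the params file are later passed to `evaluate_mixed_losses` together with these n-gram models. As a rough, self-contained illustration of length-interpolated n-gram scoring (the dict-based tables below are a toy stand-in, not the API of `superposed.ngrams`):

    # Toy illustration of weighted n-gram interpolation over several context lengths.
    # Each entry maps a context length n to a {(context_tuple, token): prob} table.
    def interpolated_score(token, context, tables, i_weights, i_length):
        total = 0.0
        for w, n in zip(i_weights, i_length):
            ctx = tuple(context[-n:])  # last n tokens of the context
            total += w * tables.get(n, {}).get((ctx, token), 0.0)
        return total

    tables = {1: {((5,), 7): 0.5}, 2: {((3, 5), 7): 0.25}}
    print(interpolated_score(7, [3, 5], tables, [0.01, 0.04], [1, 2]))  # 0.015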
| 117 | 
            +
              {
         | 
| 118 | 
            +
               "cell_type": "code",
         | 
| 119 | 
            +
               "execution_count": 10,
         | 
| 120 | 
            +
               "metadata": {},
         | 
| 121 | 
            +
               "outputs": [],
         | 
| 122 | 
            +
               "source": [
         | 
| 123 | 
            +
                "sup_device = torch.device(\"cuda:0\")\n",
         | 
| 124 | 
            +
                "reg_device = torch.device(\"cuda:1\")"
         | 
| 125 | 
            +
               ]
         | 
| 126 | 
            +
              },
         | 
| 127 | 
            +
              {
         | 
| 128 | 
            +
               "cell_type": "code",
         | 
| 129 | 
            +
               "execution_count": 11,
         | 
| 130 | 
            +
               "metadata": {},
         | 
| 131 | 
            +
               "outputs": [
         | 
| 132 | 
            +
                {
         | 
| 133 | 
            +
                 "name": "stdout",
         | 
| 134 | 
            +
                 "output_type": "stream",
         | 
| 135 | 
            +
                 "text": [
         | 
| 136 | 
            +
                  "> initializing model parallel with size 1\n",
         | 
| 137 | 
            +
                  "> initializing ddp with size 1\n",
         | 
| 138 | 
            +
                  "> initializing pipeline with size 1\n"
         | 
| 139 | 
            +
                 ]
         | 
| 140 | 
            +
                },
         | 
| 141 | 
            +
                {
         | 
| 142 | 
            +
                 "name": "stderr",
         | 
| 143 | 
            +
                 "output_type": "stream",
         | 
| 144 | 
            +
                 "text": [
         | 
| 145 | 
            +
                  "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
         | 
| 146 | 
            +
                  "  _C._set_default_tensor_type(t)\n"
         | 
| 147 | 
            +
                 ]
         | 
| 148 | 
            +
                },
         | 
| 149 | 
            +
                {
         | 
| 150 | 
            +
                 "name": "stdout",
         | 
| 151 | 
            +
                 "output_type": "stream",
         | 
| 152 | 
            +
                 "text": [
         | 
| 153 | 
            +
                  "Loaded in 22.07 seconds\n",
         | 
| 154 | 
            +
                  "cuda:0\n"
         | 
| 155 | 
            +
                 ]
         | 
| 156 | 
            +
                }
         | 
| 157 | 
            +
               ],
         | 
| 158 | 
            +
               "source": [
         | 
| 159 | 
            +
                "weight_path = \"../../7B/\"\n",
         | 
| 160 | 
            +
                "sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
         | 
| 161 | 
            +
                "                                 tokenizer_path=f'{weight_path}/tokenizer.model', \n",
         | 
| 162 | 
            +
                "                                 max_seq_len=1000, \n",
         | 
| 163 | 
            +
                "                                 max_batch_size=16,\n",
         | 
| 164 | 
            +
                "                                 device=sup_device,\n",
         | 
| 165 | 
            +
                "                                 model_parallel_size=1)"
         | 
| 166 | 
            +
               ]
         | 
| 167 | 
            +
              },
         | 
| 168 | 
            +
              {
         | 
| 169 | 
            +
               "cell_type": "code",
         | 
| 170 | 
            +
               "execution_count": 12,
         | 
| 171 | 
            +
               "metadata": {},
         | 
| 172 | 
            +
               "outputs": [
         | 
| 173 | 
            +
                {
         | 
| 174 | 
            +
                 "name": "stdout",
         | 
| 175 | 
            +
                 "output_type": "stream",
         | 
| 176 | 
            +
                 "text": [
         | 
| 177 | 
            +
                  "0\n",
         | 
| 178 | 
            +
                  "Loaded in 22.76 seconds\n"
         | 
| 179 | 
            +
                 ]
         | 
| 180 | 
            +
                }
         | 
| 181 | 
            +
               ],
         | 
| 182 | 
            +
               "source": [
         | 
| 183 | 
            +
                "reg_model = Llama.build(ckpt_dir=weight_path, \n",
         | 
| 184 | 
            +
                "                    tokenizer_path=f'{weight_path}/tokenizer.model', \n",
         | 
| 185 | 
            +
                "                    max_seq_len=1000, \n",
         | 
| 186 | 
            +
                "                    max_batch_size=16,\n",
         | 
| 187 | 
            +
                "                    device=reg_device,\n",
         | 
| 188 | 
            +
                "                    model_parallel_size=1)"
         | 
| 189 | 
            +
               ]
         | 
| 190 | 
            +
              },
         | 
| 191 | 
            +
              {
         | 
| 192 | 
            +
               "cell_type": "code",
         | 
| 193 | 
            +
               "execution_count": 18,
         | 
| 194 | 
            +
               "metadata": {},
         | 
| 195 | 
            +
               "outputs": [],
         | 
| 196 | 
            +
               "source": [
         | 
| 197 | 
            +
                "tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
         | 
| 198 | 
            +
               ]
         | 
| 199 | 
            +
              },
         | 
| 200 | 
            +
              {
         | 
| 201 | 
            +
               "cell_type": "markdown",
         | 
| 202 | 
            +
               "metadata": {},
         | 
| 203 | 
            +
               "source": [
         | 
| 204 | 
            +
                "# Evaluation"
         | 
| 205 | 
            +
               ]
         | 
| 206 | 
            +
              },
         | 
| 207 | 
            +
              {
         | 
| 208 | 
            +
               "cell_type": "code",
         | 
| 209 | 
            +
               "execution_count": 13,
         | 
| 210 | 
            +
               "metadata": {},
         | 
| 211 | 
            +
               "outputs": [
         | 
| 212 | 
            +
                {
         | 
| 213 | 
            +
                 "name": "stdout",
         | 
| 214 | 
            +
                 "output_type": "stream",
         | 
| 215 | 
            +
                 "text": [
         | 
| 216 | 
            +
                  "Length: 7993\n"
         | 
| 217 | 
            +
                 ]
         | 
| 218 | 
            +
                }
         | 
| 219 | 
            +
               ],
         | 
| 220 | 
            +
               "source": [
         | 
| 221 | 
            +
                "trivia_path = \"../../../datasets/qa/wikipedia-dev.json\"\n",
         | 
| 222 | 
            +
                "with open(trivia_path, \"r\") as f:\n",
         | 
| 223 | 
            +
                "    triviaqa = json.load(f)[\"Data\"]\n",
         | 
| 224 | 
            +
                "print(f\"Length: {len(triviaqa)}\")"
         | 
| 225 | 
            +
               ]
         | 
| 226 | 
            +
              },
         | 
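Each record in `triviaqa` is a dict; the only fields this notebook touches are `QuestionId`, `Question`, and `Answer.Aliases`. A minimal sketch of one record's shape (the question text is illustrative; the official release contains additional fields omitted here):

    # Minimal shape of one TriviaQA record as used below
    sample = {
        "QuestionId": "tc_1",
        "Question": "Which city hosted the 1936 Summer Olympics?",
        "Answer": {"Aliases": ["Berlin"]},
    }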
| 227 | 
            +
              {
         | 
| 228 | 
            +
               "cell_type": "code",
         | 
| 229 | 
            +
               "execution_count": 14,
         | 
| 230 | 
            +
               "metadata": {},
         | 
| 231 | 
            +
               "outputs": [],
         | 
| 232 | 
            +
               "source": [
         | 
| 233 | 
            +
                "torch.set_default_dtype(torch.float32)"
         | 
| 234 | 
            +
               ]
         | 
| 235 | 
            +
              },
         | 
| 236 | 
            +
              {
         | 
| 237 | 
            +
               "cell_type": "code",
         | 
| 238 | 
            +
               "execution_count": 15,
         | 
| 239 | 
            +
               "metadata": {},
         | 
| 240 | 
            +
               "outputs": [],
         | 
| 241 | 
            +
               "source": [
         | 
| 242 | 
            +
                "model_types = [\"superposed\", \"regular\"]\n",
         | 
| 243 | 
            +
                "model_type = model_types[0]"
         | 
| 244 | 
            +
               ]
         | 
| 245 | 
            +
              },
         | 
| 246 | 
            +
              {
         | 
| 247 | 
            +
               "cell_type": "code",
         | 
| 248 | 
            +
               "execution_count": 16,
         | 
| 249 | 
            +
               "metadata": {},
         | 
| 250 | 
            +
               "outputs": [],
         | 
| 251 | 
            +
               "source": [
         | 
| 252 | 
            +
                "# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/triviaqa/default.yaml\n",
         | 
| 253 | 
            +
                "def evaluate_trivia(model_type, question, max_gen_len):\n",
         | 
| 254 | 
            +
                "    question = \"Question: \" + question + \"\\nAnswer:\"\n",
         | 
| 255 | 
            +
                "    text_len = len(question) # for truncating\n",
         | 
| 256 | 
            +
                "    prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
         | 
| 257 | 
            +
                "    if model_type == \"regular\":\n",
         | 
| 258 | 
            +
                "        input = [question for _ in range(n_drafts)]\n",
         | 
| 259 | 
            +
                "        sequences, _ = evaluate_nucleus_losses(data=input,\n",
         | 
| 260 | 
            +
                "                                               model=reg_model,\n",
         | 
| 261 | 
            +
                "                                               tokenizer=tokenizer,\n",
         | 
| 262 | 
            +
                "                                               prompt_len=prompt_len,\n",
         | 
| 263 | 
            +
                "                                               max_gen_len=max_gen_len,\n",
         | 
| 264 | 
            +
                "                                               temp=0.6, # Set to 0 for greedy\n",
         | 
| 265 | 
            +
                "                                               bsz=8,\n",
         | 
| 266 | 
            +
                "                                               marker=False)\n",
         | 
| 267 | 
            +
                "        n_pd, seq_len = sequences.shape\n",
         | 
| 268 | 
            +
                "    elif model_type == \"superposed\":\n",
         | 
| 269 | 
            +
                "        sequences, _ = evaluate_mixed_losses(data=[question],\n",
         | 
| 270 | 
            +
                "                                                   model=sup_model,\n",
         | 
| 271 | 
            +
                "                                                   tokenizer=tokenizer,\n",
         | 
| 272 | 
            +
                "                                                   prompt_len=prompt_len,\n",
         | 
| 273 | 
            +
                "                                                   max_gen_len=max_gen_len,\n",
         | 
| 274 | 
            +
                "                                                   alpha=alpha,\n",
         | 
| 275 | 
            +
                "                                                   temp=temp,\n",
         | 
| 276 | 
            +
                "                                                   n_drafts=n_drafts,\n",
         | 
| 277 | 
            +
                "                                                   n_token_sample=n_token_sample,\n",
         | 
| 278 | 
            +
                "                                                   smoothing=None, # greedy\n",
         | 
| 279 | 
            +
                "                                                   bsz=8,\n",
         | 
| 280 | 
            +
                "                                                   i_weights=i_weights,\n",
         | 
| 281 | 
            +
                "                                                   i_length=i_length,\n",
         | 
| 282 | 
            +
                "                                                   ngrams=ngrams,\n",
         | 
| 283 | 
            +
                "                                                   marker=False)\n",
         | 
| 284 | 
            +
                "        n_p, n_d, seq_len = sequences.shape\n",
         | 
| 285 | 
            +
                "    # Process results\n",
         | 
| 286 | 
            +
                "    sequences = sequences.reshape(-1, seq_len).tolist()\n",
         | 
| 287 | 
            +
                "    for d_idx in range(len(sequences)):\n",
         | 
| 288 | 
            +
                "        draft = sequences[d_idx]\n",
         | 
| 289 | 
            +
                "        if -1 in draft:\n",
         | 
| 290 | 
            +
                "            draft = draft[:draft.index(-1)]\n",
         | 
| 291 | 
            +
                "        sequences[d_idx] = draft\n",
         | 
| 292 | 
            +
                "    decoded_seq = tokenizer.decode(sequences)\n",
         | 
| 293 | 
            +
                "    answers = []\n",
         | 
| 294 | 
            +
                "    for s in decoded_seq:\n",
         | 
| 295 | 
            +
                "        # print(s)\n",
         | 
| 296 | 
            +
                "        answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
         | 
| 297 | 
            +
                "    return answers\n",
         | 
| 298 | 
            +
                "            "
         | 
| 299 | 
            +
               ]
         | 
| 300 | 
            +
              },
         | 
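A minimal usage sketch for the helper above (the question text is illustrative; `model_type`, `n_drafts`, the tokenizer, and the loaded models all come from earlier cells):

    # Hypothetical single-question call; returns one short answer string per draft
    answers = evaluate_trivia(model_type, "What is the capital of France?", max_gen_len=8)
    print(answers)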
| 301 | 
            +
              {
         | 
| 302 | 
            +
               "cell_type": "code",
         | 
| 303 | 
            +
               "execution_count": null,
         | 
| 304 | 
            +
               "metadata": {},
         | 
| 305 | 
            +
               "outputs": [],
         | 
| 306 | 
            +
               "source": [
         | 
| 307 | 
            +
                "questions = {}\n",
         | 
| 308 | 
            +
                "predictions = {}\n",
         | 
| 309 | 
            +
                "print(f\"Precision from 1 to {n_drafts}\")\n",
         | 
| 310 | 
            +
                "for sample in tqdm(triviaqa):\n",
         | 
| 311 | 
            +
                "    # Adaptively select generation length\n",
         | 
| 312 | 
            +
                "    longest = 0\n",
         | 
| 313 | 
            +
                "    shortest = 1000\n",
         | 
| 314 | 
            +
                "    total = 0\n",
         | 
| 315 | 
            +
                "    for answer in sample[\"Answer\"][\"Aliases\"]:\n",
         | 
| 316 | 
            +
                "        tmp = tokenizer.encode([answer], False, False)[0]\n",
         | 
| 317 | 
            +
                "        if len(tmp) > longest:\n",
         | 
| 318 | 
            +
                "            longest = len(tmp)\n",
         | 
| 319 | 
            +
                "        if len(tmp) < shortest:\n",
         | 
| 320 | 
            +
                "            shortest = len(tmp)\n",
         | 
| 321 | 
            +
                "        total += len(tmp)\n",
         | 
| 322 | 
            +
                "    # Evaluation code\n",
         | 
| 323 | 
            +
                "    id = sample[\"QuestionId\"]\n",
         | 
| 324 | 
            +
                "    question = sample[\"Question\"]\n",
         | 
| 325 | 
            +
                "    answer = evaluate_trivia(model_type, question, max_gen_len=longest + 3)\n",
         | 
| 326 | 
            +
                "    predictions[id] = answer\n",
         | 
| 327 | 
            +
                "    questions[id] = question"
         | 
| 328 | 
            +
               ]
         | 
| 329 | 
            +
              },
         | 
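The loop above sizes `max_gen_len` per question: it tokenizes every answer alias, takes the longest, and adds a 3-token margin (`shortest` and `total` are computed but unused). The same computation as a standalone sketch, assuming the `tokenizer` built earlier:

    # Adaptive generation length: longest tokenized alias plus a small margin
    def adaptive_gen_len(aliases, tokenizer, margin=3):
        longest = max(len(tokenizer.encode([a], False, False)[0]) for a in aliases)
        return longest + margin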
| 330 | 
            +
              {
         | 
| 331 | 
            +
               "cell_type": "code",
         | 
| 332 | 
            +
               "execution_count": null,
         | 
| 333 | 
            +
               "metadata": {},
         | 
| 334 | 
            +
               "outputs": [],
         | 
| 335 | 
            +
               "source": [
         | 
| 336 | 
            +
                "# Save precisions\n",
         | 
| 337 | 
            +
                "precisions = {}\n",
         | 
| 338 | 
            +
                "for i in range(1, n_drafts+1):\n",
         | 
| 339 | 
            +
                "    prec = str(i)\n",
         | 
| 340 | 
            +
                "    responses = {k: v[:i] for k, v in predictions.items()}\n",
         | 
| 341 | 
            +
                "    precisions[prec] = responses"
         | 
| 342 | 
            +
               ]
         | 
| 343 | 
            +
              },
         | 
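`precisions[str(i)]` simply keeps the first i drafts per question, so precision@i can be scored later. A toy illustration with invented data:

    # Keeping the top-i drafts per question for precision@i scoring
    predictions = {"q1": ["Berlin", "Munich", "Hamburg"]}
    precisions = {str(i): {k: v[:i] for k, v in predictions.items()} for i in range(1, 4)}
    print(precisions["2"])  # {'q1': ['Berlin', 'Munich']}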
| 344 | 
            +
              {
         | 
| 345 | 
            +
               "cell_type": "code",
         | 
| 346 | 
            +
               "execution_count": null,
         | 
| 347 | 
            +
               "metadata": {},
         | 
| 348 | 
            +
               "outputs": [],
         | 
| 349 | 
            +
               "source": [
         | 
| 350 | 
            +
                "# Print some results\n",
         | 
| 351 | 
            +
                "counter = 0\n",
         | 
| 352 | 
            +
                "for k in predictions:\n",
         | 
| 353 | 
            +
                "    if counter >= 10:\n",
         | 
| 354 | 
            +
                "        break\n",
         | 
| 355 | 
            +
                "    print(questions[k])\n",
         | 
| 356 | 
            +
                "    print(predictions[k])\n",
         | 
| 357 | 
            +
                "    counter += 1\n",
         | 
| 358 | 
            +
                "    print(\"================\")"
         | 
| 359 | 
            +
               ]
         | 
| 360 | 
            +
              },
         | 
| 361 | 
            +
              {
         | 
| 362 | 
            +
               "cell_type": "code",
         | 
| 363 | 
            +
               "execution_count": null,
         | 
| 364 | 
            +
               "metadata": {},
         | 
| 365 | 
            +
               "outputs": [],
         | 
| 366 | 
            +
               "source": [
         | 
| 367 | 
            +
                "# Save results\n",
         | 
| 368 | 
            +
                "os.makedirs(\"../../trivia/\", exist_ok=True)\n",
         | 
| 369 | 
            +
                "for prec in range(1, n_drafts+1):\n",
         | 
| 370 | 
            +
                "    out_path = f\"../nucleus_extra/trivia_extra/ngram_4trivia_{model_type}_{prec}_4.json\"\n",
         | 
| 371 | 
            +
                "    with open(out_path, \"w\") as f:\n",
         | 
| 372 | 
            +
                "        json.dump(precisions[str(prec)], f, indent=4)"
         | 
| 373 | 
            +
               ]
         | 
| 374 | 
            +
              },
         | 
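The saved predictions can later be scored against each record's `Answer.Aliases`. A minimal exact-match sketch (the normalization here is simplified relative to the lm-evaluation-harness recipe linked above):

    import string

    def normalize(s):
        # Lowercase, strip punctuation, collapse whitespace (simplified)
        s = s.lower().translate(str.maketrans("", "", string.punctuation))
        return " ".join(s.split())

    def exact_match(prediction, aliases):
        return any(normalize(prediction) == normalize(a) for a in aliases)

    print(exact_match("Berlin!", ["Berlin"]))  # True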
| 375 | 
            +
              {
         | 
| 376 | 
            +
               "cell_type": "code",
         | 
| 377 | 
            +
               "execution_count": null,
         | 
| 378 | 
            +
               "metadata": {},
         | 
| 379 | 
            +
               "outputs": [],
         | 
| 380 | 
            +
               "source": []
         | 
| 381 | 
            +
              }
         | 
| 382 | 
            +
             ],
         | 
| 383 | 
            +
             "metadata": {
         | 
| 384 | 
            +
              "kernelspec": {
         | 
| 385 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 386 | 
            +
               "language": "python",
         | 
| 387 | 
            +
               "name": "python3"
         | 
| 388 | 
            +
              },
         | 
| 389 | 
            +
              "language_info": {
         | 
| 390 | 
            +
               "codemirror_mode": {
         | 
| 391 | 
            +
                "name": "ipython",
         | 
| 392 | 
            +
                "version": 3
         | 
| 393 | 
            +
               },
         | 
| 394 | 
            +
               "file_extension": ".py",
         | 
| 395 | 
            +
               "mimetype": "text/x-python",
         | 
| 396 | 
            +
               "name": "python",
         | 
| 397 | 
            +
               "nbconvert_exporter": "python",
         | 
| 398 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 399 | 
            +
               "version": "3.11.5"
         | 
| 400 | 
            +
              }
         | 
| 401 | 
            +
             },
         | 
| 402 | 
            +
             "nbformat": 4,
         | 
| 403 | 
            +
             "nbformat_minor": 4
         | 
| 404 | 
            +
            }
         |