# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional, Sequence, Tuple, Union


class SeparatorStyle(Enum):
    """Different separator styles used when flattening a conversation to text."""
    TWO = auto()       # alternate between `sep` and `sep2` (vicuna v1)
    PLAIN = auto()     # bare messages joined by `sep`/`sep2`, no role tags
    LLAMA_2 = auto()   # [INST] ... [/INST] wrapping with a <<SYS>> system block
    LLAMA_3 = auto()   # delegated to tokenizer.apply_chat_template
    QWEN = auto()      # delegated to tokenizer.apply_chat_template


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str                 # system prompt prepended to the dialogue
    roles: Sequence[str]        # (user_role, assistant_role) display names
    # Each entry is [role, message]; message may be a tuple carrying a
    # non-text payload (e.g. speech) alongside the text in position 0.
    messages: List[List[Any]]
    offset: int                 # leading messages hidden from the chatbot UI
    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[Union[str, List[str]]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    skip_next: bool = False

    def get_prompt(self, tokenizer=None, system_prompt=None):
        """Render the conversation into a single prompt string.

        Args:
            tokenizer: object exposing ``apply_chat_template`` — required for
                the LLAMA_3 and QWEN styles, ignored by the others.
            system_prompt: optional override for ``self.system`` (only used by
                the chat-template styles, matching the previous behavior).

        Returns:
            The fully formatted prompt string.

        Raises:
            ValueError: if ``sep_style`` is unknown, or a chat-template style
                is requested without a tokenizer.
        """
        self.tokenizer = tokenizer
        messages = self.messages

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if isinstance(message, tuple):
                        message = message[0]
                    ret += role + ": " + message + seps[i % 2]
                else:
                    # Empty message: emit the bare role tag as a generation cue.
                    ret += role + ":"
        elif self.sep_style in (SeparatorStyle.LLAMA_3, SeparatorStyle.QWEN):
            # Both styles were duplicated verbatim; they simply hand the
            # messages to the tokenizer's chat template.
            return self._chat_template_prompt(system_prompt)
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            # NOTE(review): the <<SYS>>, <s> and </s> tokens below appeared as
            # "<>" / "" in the previous revision — HTML-tag stripping damage.
            # Restored per the Llama 2 chat prompt format.
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    if i == 0:
                        # The system prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # Drop the leading BOS-style separator before the first [INST].
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def _chat_template_prompt(self, system_prompt=None):
        """Build the prompt via ``tokenizer.apply_chat_template`` (LLAMA_3 / QWEN)."""
        if self.tokenizer is None:
            # Previously this crashed with an opaque AttributeError on None.
            raise ValueError(
                f"A tokenizer with apply_chat_template is required for {self.sep_style}."
            )
        if system_prompt is not None:
            self.system = system_prompt
        chat_template_messages = [{"role": "system", "content": self.system}]
        for role, message in self.messages:
            if message:
                if isinstance(message, tuple):
                    # Keep only the text part; drop the attached payload.
                    message = message[0]
                chat_template_messages.append({"role": role, "content": message})
        return self.tokenizer.apply_chat_template(
            chat_template_messages, tokenize=False, add_generation_prompt=True
        )

    def append_message(self, role, message):
        """Append one ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Convert the history past ``offset`` into gradio ``[user, bot]`` pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    # Drop the attached (e.g. speech) payload; keep the text.
                    msg, _payload = msg
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with a fresh messages list (payloads are shared)."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        """Serialize to a plain dict; tuple messages are reduced to their text.

        BUG FIX: the previous revision gated this on ``self.get_images()``,
        a method that does not exist in this file and always raised
        AttributeError. The intent — strip non-text payloads when any message
        carries one — is preserved with a direct tuple check.
        """
        if any(isinstance(y, tuple) for _, y in self.messages):
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if isinstance(y, tuple) else y] for x, y in self.messages
                ],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",  # restored: stripped to "" in the previous revision
)

conv_llama_2 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",    # restored: stripped to "" in the previous revision
    sep2="</s>",  # restored: stripped to "" in the previous revision
)

conv_llama_3 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="",
    sep2="<|eot_id|>",
)

conv_qwen = Conversation(
    system="""You are a helpful language and speech assistant. 
You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.""",
    roles=("user", "assistant"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.QWEN,
    sep="<|im_end|>",
)

conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=[],  # was a tuple; append_message requires a list
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="",
)

default_conversation = conv_llama_3
conv_templates = {
    "v1": conv_vicuna_v1,
    "plain": conv_plain,
    "llama_2": conv_llama_2,
    "llama_3": conv_llama_3,
    "qwen": conv_qwen,
}

if __name__ == "__main__":
    print(default_conversation.get_prompt())