File size: 6,651 Bytes
75e8fca a954356 cdc20d7 75e8fca 453a88d 75e8fca 6210f88 4074ba1 75e8fca 453a88d 6210f88 1f8cdd1 a2dad16 453a88d 9ad661a 4710436 453a88d 0856426 453a88d d20c7c3 4710436 453a88d 75e8fca 6210f88 453a88d 75e8fca 453a88d fb3bd09 ebc2265 75e8fca ebc2265 453a88d 4710436 453a88d 9ad661a 4710436 453a88d 54ea669 75e8fca 453a88d 8614d5a 6210f88 453a88d 815a27e 6210f88 36d0f40 fa98bc1 453a88d fa98bc1 453a88d fa98bc1 0856426 453a88d fb3bd09 453a88d 1f8cdd1 453a88d cdc20d7 453a88d fa98bc1 6210f88 453a88d bf9dd88 a2dad16 453a88d 815a27e 8614d5a a2dad16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import asyncio
import json
import os
import time
# NOTE(review): this `wait` is immediately shadowed by the module's own
# async `wait(job)` defined below — confirm this import is still needed.
from concurrent.futures import wait

from gradio_client import Client
# Hugging Face API token, read from the environment; passed to the Gradio client.
HF_TOKEN = os.getenv("HF_TOKEN")
# Maps Discord thread id -> the user id allowed to chat in that thread.
falcon_userid_threadid_dictionary = {}
# Maps Discord thread id -> filesystem path of the JSON chat-history file
# last produced by the Falcon space for that thread.
threadid_conversation = {}
# Instructions are for Falcon-chat and can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
instructions = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins."
# NOTE(review): constructing the Client at import time presumably contacts the
# Space to fetch its API config — confirm this is acceptable at module load.
falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)
# Discord ids; first value is used when TEST_ENV is set (any non-empty value),
# second in production. Presumably the bot's own user id — TODO confirm.
BOT_USER_ID = (
    1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
)
# Channel the bot listens on for /falcon interactions (test vs. production).
FALCON_CHANNEL_ID = (
    1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729
)
async def wait(job):
    """Asynchronously poll *job* until ``job.done()`` is true.

    Yields control back to the event loop for 0.2 s between polls so the
    bot stays responsive while a Gradio job runs.
    """
    while True:
        if job.done():
            break
        await asyncio.sleep(0.2)
def falcon_initial_generation(prompt, instructions, thread):
    """Run the first Falcon generation for a freshly created Discord thread.

    Solves two problems at once: 1) the Slash command + job.submit
    interaction, and 2) the need for job.submit in order to locate the full
    generated text.

    Args:
        prompt: The user's initial prompt string.
        instructions: System prompt forwarded to the Falcon space.
        thread: Discord thread object; its ``.id`` keys the saved history.

    Returns:
        The bot's reply text, truncated to fit Discord's message limits.
    """
    global threadid_conversation
    # fn_index=5 presumably returns a fresh/empty chat history — TODO confirm
    # against the Space's API config.
    chathistory = falcon_client.predict(fn_index=5)
    temperature = 0.8
    p_nucleus_sampling = 0.9
    job = falcon_client.submit(
        prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1
    )
    # This function is called via run_in_executor, so blocking here is fine —
    # but the previous `while job.done() is False: pass` busy-loop pinned a
    # CPU core for the whole generation. Sleep between polls instead.
    while not job.done():
        time.sleep(0.2)
    if os.environ.get("TEST_ENV") == "True":
        print("falcon text gen job done")
    file_paths = job.outputs()
    print(file_paths)
    # The last output is the path of a JSON file holding the conversation.
    full_generation = file_paths[-1]
    print(full_generation)
    with open(full_generation, "r") as file:
        data = json.load(file)
    print(data)
    # Last [user, bot] pair -> the bot's newest reply.
    output_text = data[-1][-1]
    # Remember the history file so continue_falcon can resume this thread.
    threadid_conversation[thread.id] = full_generation
    if len(output_text) > 1300:
        output_text = (
            output_text[:1300]
            + "...\nTruncating response to 2000 characters due to discord api limits."
        )
    if os.environ.get("TEST_ENV") == "True":
        print(output_text)
    return output_text
async def try_falcon(interaction, prompt):
    """Generates text based on a given prompt.

    Creates a dedicated thread under the interaction's channel, runs the
    blocking initial generation in an executor, and posts the reply there.
    """
    try:
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        # Guard clauses: ignore the bot's own interactions and other channels.
        if interaction.user.id == BOT_USER_ID:
            return
        if interaction.channel.id != FALCON_CHANNEL_ID:
            return
        if os.environ.get("TEST_ENV") == "True":
            print("Safetychecks passed for try_falcon")
        await interaction.response.send_message("Working on it!")
        channel = interaction.channel
        placeholder = await channel.send("Creating thread...")
        thread = await placeholder.create_thread(
            name=prompt, auto_archive_duration=60
        )  # interaction.user
        await thread.send(
            "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
        )
        if os.environ.get("TEST_ENV") == "True":
            print("Running falcon_initial_generation...")
        # The generation blocks on the Gradio job; keep it off the event loop.
        event_loop = asyncio.get_running_loop()
        output_text = await event_loop.run_in_executor(
            None, falcon_initial_generation, prompt, instructions, thread
        )
        falcon_userid_threadid_dictionary[thread.id] = interaction.user.id
        await thread.send(output_text)
    except Exception as e:
        print(f"try_falcon Error: {e}")
async def continue_falcon(message):
    """Continues a given conversation based on chathistory."""
    try:
        # Guard clauses: skip bot messages, unknown threads, and wrong users.
        if message.author.bot:
            return
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        thread_id = message.channel.id
        if thread_id not in falcon_userid_threadid_dictionary:
            return  # not a thread this bot created
        if falcon_userid_threadid_dictionary[thread_id] != message.author.id:
            return  # right thread, wrong user
        if os.environ.get("TEST_ENV") == "True":
            print("Safetychecks passed for continue_falcon")
        global threadid_conversation
        await message.add_reaction("🔁")
        if os.environ.get("TEST_ENV") == "True":
            print("Running falcon_client.submit")
        job = falcon_client.submit(
            message.content,
            threadid_conversation[thread_id],
            instructions,
            0.8,  # temperature
            0.9,  # p_nucleus_sampling
            fn_index=1,
        )
        await wait(job)
        if os.environ.get("TEST_ENV") == "True":
            print("Continue_falcon job done")
        # Last output is the JSON chat-history file written by the space.
        full_generation = job.outputs()[-1]
        with open(full_generation, "r") as file:
            conversation = json.load(file)
        output_text = conversation[-1][-1]
        threadid_conversation[thread_id] = full_generation  # overwrite the old file
        if len(output_text) > 1300:
            output_text = (
                output_text[:1300]
                + "...\nTruncating response to 2000 characters due to discord api limits."
            )
        await message.reply(output_text)
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")