lunarflu HF Staff commited on
Commit
453a88d
·
1 Parent(s): 5b302b6

[falcon.py] black + ruff

Browse files
Files changed (1) hide show
  1. falcon.py +69 -47
falcon.py CHANGED
@@ -1,108 +1,130 @@
1
- import discord
2
- from discord import app_commands
3
- import gradio as gr
4
  from gradio_client import Client
5
  import os
6
  import asyncio
7
  import json
8
  from concurrent.futures import wait
9
 
10
- HF_TOKEN = os.getenv('HF_TOKEN')
11
  falcon_userid_threadid_dictionary = {}
12
  threadid_conversation = {}
13
  # Instructions are for Falcon-chat and can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
14
  instructions = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins."
15
  falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)
16
 
17
- BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
18
- FALCON_CHANNEL_ID = 1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729
 
 
 
 
19
 
20
- def falcon_initial_generation(prompt, instructions, thread):
 
21
  """Solves two problems at once; 1) The Slash command + job.submit interaction, and 2) the need for job.submit in order to locate the full generated text"""
22
  global threadid_conversation
23
-
24
- chathistory = falcon_client.predict(fn_index=5)
25
  temperature = 0.8
26
  p_nucleus_sampling = 0.9
27
-
28
- job = falcon_client.submit(prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1)
 
 
29
  wait([job])
30
- if os.environ.get('TEST_ENV') == 'True':
31
  print("falcon text gen job done")
32
  file_paths = job.outputs()
33
- print(file_paths)
34
  full_generation = file_paths[-1]
35
- print(full_generation)
36
- with open(full_generation, 'r') as file:
37
  data = json.load(file)
38
- print(data)
39
- output_text = data[-1][-1]
40
  threadid_conversation[thread.id] = full_generation
41
- if os.environ.get('TEST_ENV') == 'True':
42
- print(output_text)
43
-
44
  return output_text
45
 
 
46
  async def try_falcon(interaction, prompt):
47
  """Generates text based on a given prompt"""
48
  try:
49
- global falcon_userid_threadid_dictionary # tracks userid-thread existence
50
  global threadid_conversation
51
 
52
  if interaction.user.id != BOT_USER_ID:
53
- if interaction.channel.id == FALCON_CHANNEL_ID:
54
- if os.environ.get('TEST_ENV') == 'True':
55
  print("Safetychecks passed for try_falcon")
56
  await interaction.response.send_message("Working on it!")
57
  channel = interaction.channel
58
  message = await channel.send("Creating thread...")
59
- thread = await message.create_thread(name=prompt, auto_archive_duration=60) # interaction.user
60
- await thread.send("[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]")
 
 
 
 
61
 
62
- if os.environ.get('TEST_ENV') == 'True':
63
  print("Running falcon_initial_generation...")
64
  loop = asyncio.get_running_loop()
65
- output_text = await loop.run_in_executor(None, falcon_initial_generation, prompt, instructions, thread)
 
 
66
  falcon_userid_threadid_dictionary[thread.id] = interaction.user.id
67
-
68
- await thread.send(output_text)
69
 
 
70
  except Exception as e:
71
  print(f"try_falcon Error: {e}")
72
 
 
73
  async def continue_falcon(message):
74
  """Continues a given conversation based on chathistory"""
75
  try:
76
  if not message.author.bot:
77
- global falcon_userid_threadid_dictionary # tracks userid-thread existence
78
- if message.channel.id in falcon_userid_threadid_dictionary: # is this a valid thread?
79
- if falcon_userid_threadid_dictionary[message.channel.id] == message.author.id: # more than that - is this specifically the right user for this thread?
80
- if os.environ.get('TEST_ENV') == 'True':
81
- print("Safetychecks passed for continue_falcon")
 
 
 
 
 
82
  global instructions
83
  global threadid_conversation
84
- await message.add_reaction('🔁')
85
-
86
  prompt = message.content
87
  chathistory = threadid_conversation[message.channel.id]
88
  temperature = 0.8
89
- p_nucleus_sampling = 0.9
90
-
91
- if os.environ.get('TEST_ENV') == 'True':
92
  print("Running falcon_client.submit")
93
- job = falcon_client.submit(prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1)
 
 
 
 
 
 
 
94
  wait([job])
95
- if os.environ.get('TEST_ENV') == 'True':
96
  print("Continue_falcon job done")
97
  file_paths = job.outputs()
98
  full_generation = file_paths[-1]
99
- with open(full_generation, 'r') as file:
100
  data = json.load(file)
101
  output_text = data[-1][-1]
102
-
103
- threadid_conversation[message.channel.id] = full_generation # overwrite the old file
104
- await message.reply(output_text)
105
-
106
  except Exception as e:
107
  print(f"continue_falcon Error: {e}")
108
- await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")
 
 
 
 
1
from gradio_client import Client
import os
import asyncio
import json
from concurrent.futures import wait

# Hugging Face API token, read from the environment (never hard-coded).
HF_TOKEN = os.getenv("HF_TOKEN")
# thread.id -> user.id of the user who opened that Falcon thread.
falcon_userid_threadid_dictionary = {}
# thread.id -> path of the JSON file holding that thread's full chat history.
threadid_conversation = {}
# Instructions are for Falcon-chat and can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
instructions = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins."
falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)

# Discord IDs switch between the test and production deployments.
# FIX: compare with == "True" for consistency with every other TEST_ENV check
# in this file; the previous `os.getenv("TEST_ENV", False)` was truthy for ANY
# non-empty value (even "False"), which could silently select the test IDs.
BOT_USER_ID = (
    1086256910572986469 if os.getenv("TEST_ENV") == "True" else 1102236653545861151
)
FALCON_CHANNEL_ID = (
    1079459939405279232 if os.getenv("TEST_ENV") == "True" else 1119313248056004729
)
20
 
21
+
22
def falcon_initial_generation(prompt, instructions, thread):
    """Run the first Falcon generation for a newly created thread.

    Solves two problems at once; 1) The Slash command + job.submit interaction,
    and 2) the need for job.submit in order to locate the full generated text.

    Args:
        prompt: the user's opening message, sent as the first turn.
        instructions: system prompt forwarded to the Falcon Space.
        thread: Discord thread object; its ``.id`` keys ``threadid_conversation``.

    Returns:
        str: the assistant's reply (last element of the last turn in the
        generation file written by the Space).
    """
    global threadid_conversation

    # fn_index=5 on the Space returns a fresh, empty chat history.
    chathistory = falcon_client.predict(fn_index=5)
    temperature = 0.8
    p_nucleus_sampling = 0.9

    job = falcon_client.submit(
        prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1
    )
    wait([job])  # block this worker thread until the generation job finishes
    if os.environ.get("TEST_ENV") == "True":
        print("falcon text gen job done")
    file_paths = job.outputs()
    # The last output file contains the complete conversation history.
    full_generation = file_paths[-1]
    with open(full_generation, "r") as file:
        data = json.load(file)
    output_text = data[-1][-1]  # last turn, assistant side
    threadid_conversation[thread.id] = full_generation
    if os.environ.get("TEST_ENV") == "True":
        # FIX: debug dumps were printed unconditionally before, leaking full
        # conversation contents into production logs; gate them behind
        # TEST_ENV like every other diagnostic print in this file.
        print(file_paths)
        print(full_generation)
        print(data)
        print(output_text)
    return output_text
48
 
49
+
50
async def try_falcon(interaction, prompt):
    """Open a fresh Falcon conversation thread from a slash-command prompt."""
    try:
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        global threadid_conversation

        # Ignore the bot's own interactions; only react in the Falcon channel.
        if (
            interaction.user.id != BOT_USER_ID
            and interaction.channel.id == FALCON_CHANNEL_ID
        ):
            if os.environ.get("TEST_ENV") == "True":
                print("Safetychecks passed for try_falcon")
            await interaction.response.send_message("Working on it!")
            parent_channel = interaction.channel
            placeholder = await parent_channel.send("Creating thread...")
            thread = await placeholder.create_thread(
                name=prompt, auto_archive_duration=60
            )  # interaction.user
            await thread.send(
                "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
            )

            if os.environ.get("TEST_ENV") == "True":
                print("Running falcon_initial_generation...")
            event_loop = asyncio.get_running_loop()
            # The gradio call blocks, so push it onto the default executor
            # to keep the Discord event loop responsive.
            reply_text = await event_loop.run_in_executor(
                None, falcon_initial_generation, prompt, instructions, thread
            )
            # Register the thread only after generation succeeds, so
            # continue_falcon never sees a thread without a history file.
            falcon_userid_threadid_dictionary[thread.id] = interaction.user.id

            await thread.send(reply_text)
    except Exception as e:
        print(f"try_falcon Error: {e}")
81
 
82
+
83
async def continue_falcon(message):
    """Extend an existing Falcon thread with the user's newest message."""
    global falcon_userid_threadid_dictionary  # tracks userid-thread existence
    global instructions
    global threadid_conversation
    try:
        # Never respond to other bots (or ourselves).
        if message.author.bot:
            return
        # Valid thread AND the specific user who owns it — .get() returns
        # None for unknown threads, which can never equal a real author id.
        thread_id = message.channel.id
        if falcon_userid_threadid_dictionary.get(thread_id) != message.author.id:
            return

        if os.environ.get("TEST_ENV") == "True":
            print("Safetychecks passed for continue_falcon")
        await message.add_reaction("🔁")

        user_prompt = message.content
        history_file = threadid_conversation[thread_id]
        temperature = 0.8
        p_nucleus_sampling = 0.9

        if os.environ.get("TEST_ENV") == "True":
            print("Running falcon_client.submit")
        job = falcon_client.submit(
            user_prompt,
            history_file,
            instructions,
            temperature,
            p_nucleus_sampling,
            fn_index=1,
        )
        wait([job])  # block until the generation job completes
        if os.environ.get("TEST_ENV") == "True":
            print("Continue_falcon job done")
        generated_files = job.outputs()
        latest_history = generated_files[-1]
        with open(latest_history, "r") as file:
            data = json.load(file)
        reply_text = data[-1][-1]
        threadid_conversation[thread_id] = latest_history  # overwrite the old file
        await message.reply(reply_text)
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")