File size: 6,651 Bytes
75e8fca
 
 
a954356
cdc20d7
75e8fca
453a88d
75e8fca
 
6210f88
 
4074ba1
75e8fca
453a88d
 
 
 
 
 
6210f88
1f8cdd1
 
 
a2dad16
453a88d
9ad661a
4710436
453a88d
 
0856426
 
453a88d
 
 
 
d20c7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4710436
453a88d
75e8fca
6210f88
 
453a88d
75e8fca
 
 
453a88d
 
fb3bd09
ebc2265
75e8fca
ebc2265
453a88d
 
 
 
 
 
4710436
453a88d
9ad661a
4710436
453a88d
 
 
54ea669
75e8fca
453a88d
8614d5a
 
6210f88
453a88d
815a27e
6210f88
36d0f40
fa98bc1
453a88d
 
 
 
 
 
 
 
 
 
fa98bc1
 
453a88d
 
fa98bc1
0856426
 
453a88d
 
 
fb3bd09
453a88d
 
 
 
 
 
 
 
1f8cdd1
453a88d
cdc20d7
 
 
453a88d
fa98bc1
6210f88
453a88d
 
 
bf9dd88
a2dad16
 
 
 
453a88d
815a27e
8614d5a
a2dad16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import asyncio
import json
import os
import time
from concurrent.futures import wait

from gradio_client import Client

HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token, read from the environment
falcon_userid_threadid_dictionary = {}  # thread.id -> user.id that owns that thread
threadid_conversation = {}  # thread.id -> path of the latest chat-history JSON file
# Instructions are for Falcon-chat and can be found here:  https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
instructions = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins."
falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)  # gradio client for the Falcon space

# Discord IDs, switched between test and production environments.
# NOTE(review): os.getenv("TEST_ENV", False) is truthy for ANY set value (even
# the string "False"), unlike the == "True" comparisons used in the functions
# below — confirm this asymmetry is intended.
BOT_USER_ID = (
    1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
)
FALCON_CHANNEL_ID = (
    1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729
)

async def wait(job):
    """Suspend until *job* reports completion, polling its done() flag every 0.2 s.

    Yields control back to the event loop between polls so other coroutines
    keep running while the gradio job finishes.
    """
    while job.done() is False:
        await asyncio.sleep(0.2)

def falcon_initial_generation(prompt, instructions, thread):
    """Run the first Falcon generation for a newly created Discord thread.

    Solves two problems at once; 1) The Slash command + job.submit interaction,
    and 2) the need for job.submit in order to locate the full generated text.

    This function is synchronous and blocking by design: try_falcon dispatches
    it to an executor thread via run_in_executor, so sleeping here is safe.

    Args:
        prompt: The user's initial prompt text.
        instructions: The Falcon system prompt passed to the space.
        thread: The Discord thread object; its .id keys threadid_conversation.

    Returns:
        The generated reply text, truncated to 1300 characters when longer
        (to stay inside Discord's message limits).
    """
    global threadid_conversation

    # Fresh, empty chat history from the Falcon space (fn_index=5).
    chathistory = falcon_client.predict(fn_index=5)
    temperature = 0.8
    p_nucleus_sampling = 0.9

    job = falcon_client.submit(
        prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1
    )
    # Poll with a short sleep instead of busy-spinning: the original
    # `while job.done() is False: pass` pinned a CPU core for the whole
    # generation. A 0.2 s sleep matches the async wait() helper above.
    while not job.done():
        time.sleep(0.2)
    if os.environ.get("TEST_ENV") == "True":
        print("falcon text gen job done")
    file_paths = job.outputs()
    print(file_paths)
    # The last output is the most recent chat-history JSON file written by the space.
    full_generation = file_paths[-1]
    print(full_generation)
    with open(full_generation, "r") as file:
        data = json.load(file)
        print(data)
    output_text = data[-1][-1]  # last [user, bot] pair -> the bot's reply
    threadid_conversation[thread.id] = full_generation
    if len(output_text) > 1300:
        output_text = (
            output_text[:1300]
            + "...\nTruncating response to 2000 characters due to discord api limits."
        )
    if os.environ.get("TEST_ENV") == "True":
        print(output_text)
    return output_text


async def try_falcon(interaction, prompt):
    """Start a Falcon conversation: create a Discord thread and post the first reply."""
    try:
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        global threadid_conversation

        # Guard clauses: ignore the bot's own interactions and any channel
        # other than the dedicated Falcon channel.
        if interaction.user.id == BOT_USER_ID:
            return
        if interaction.channel.id != FALCON_CHANNEL_ID:
            return

        if os.environ.get("TEST_ENV") == "True":
            print("Safetychecks passed for try_falcon")
        await interaction.response.send_message("Working on it!")
        channel = interaction.channel
        message = await channel.send("Creating thread...")
        thread = await message.create_thread(
            name=prompt, auto_archive_duration=60
        )  # interaction.user
        await thread.send(
            "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
        )

        if os.environ.get("TEST_ENV") == "True":
            print("Running falcon_initial_generation...")
        # Run the blocking gradio call off the event loop.
        event_loop = asyncio.get_running_loop()
        output_text = await event_loop.run_in_executor(
            None, falcon_initial_generation, prompt, instructions, thread
        )
        falcon_userid_threadid_dictionary[thread.id] = interaction.user.id

        await thread.send(output_text)
    except Exception as e:
        print(f"try_falcon Error: {e}")


async def continue_falcon(message):
    """Continue an existing Falcon thread using its stored chat-history file."""
    try:
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        global instructions
        global threadid_conversation

        # Guard clauses: skip bot authors, channels that are not tracked
        # threads, and messages from anyone other than the thread's owner.
        if message.author.bot:
            return
        if message.channel.id not in falcon_userid_threadid_dictionary:
            return
        if falcon_userid_threadid_dictionary[message.channel.id] != message.author.id:
            return

        if os.environ.get("TEST_ENV") == "True":
            print("Safetychecks passed for continue_falcon")
        await message.add_reaction("🔁")

        prompt = message.content
        chathistory = threadid_conversation[message.channel.id]
        temperature = 0.8
        p_nucleus_sampling = 0.9

        if os.environ.get("TEST_ENV") == "True":
            print("Running falcon_client.submit")
        job = falcon_client.submit(
            prompt,
            chathistory,
            instructions,
            temperature,
            p_nucleus_sampling,
            fn_index=1,
        )
        await wait(job)
        if os.environ.get("TEST_ENV") == "True":
            print("Continue_falcon job done")
        file_paths = job.outputs()
        full_generation = file_paths[-1]
        with open(full_generation, "r") as file:
            data = json.load(file)
            output_text = data[-1][-1]
        # Overwrite the stored history with the newest file for this thread.
        threadid_conversation[message.channel.id] = full_generation
        if len(output_text) > 1300:
            output_text = (
                output_text[:1300]
                + "...\nTruncating response to 2000 characters due to discord api limits."
            )
        await message.reply(output_text)
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")