Eladlev committed on
Commit
4311882
·
verified ·
1 Parent(s): 6babbf8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -249
app.py CHANGED
@@ -1,262 +1,164 @@
1
- """
2
- Entrypoint for Gradio, see https://gradio.app/
3
- """
4
-
5
- import platform
6
- import asyncio
7
- import base64
8
- import os
9
- from datetime import datetime
10
- from enum import StrEnum
11
- from functools import partial
12
- from pathlib import Path
13
- from typing import cast, Dict
14
-
15
  import gradio as gr
16
- from anthropic import APIResponse
 
 
 
17
  from anthropic.types import TextBlock
18
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
19
- from anthropic.types.tool_use_block import ToolUseBlock
20
-
21
- from computer_use_demo.loop import (
22
- PROVIDER_TO_DEFAULT_MODEL_NAME,
23
- APIProvider,
24
- sampling_loop,
25
- sampling_loop_sync,
26
- )
27
-
28
- from computer_use_demo.tools import ToolResult
29
-
30
-
31
- CONFIG_DIR = Path("~/.anthropic").expanduser()
32
- API_KEY_FILE = CONFIG_DIR / "api_key"
33
-
34
- WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
35
-
36
-
37
- class Sender(StrEnum):
38
- USER = "user"
39
- BOT = "assistant"
40
- TOOL = "tool"
41
-
42
-
43
- def setup_state(state):
44
- if "messages" not in state:
45
- state["messages"] = []
46
- if "api_key" not in state:
47
- # Try to load API key from file first, then environment
48
- state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
49
- if not state["api_key"]:
50
- print("API key not found. Please set it in the environment or storage.")
51
- if "provider" not in state:
52
- state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
53
- if "provider_radio" not in state:
54
- state["provider_radio"] = state["provider"]
55
- if "model" not in state:
56
- _reset_model(state)
57
- if "auth_validated" not in state:
58
- state["auth_validated"] = False
59
- if "responses" not in state:
60
- state["responses"] = {}
61
- if "tools" not in state:
62
- state["tools"] = {}
63
- if "only_n_most_recent_images" not in state:
64
- state["only_n_most_recent_images"] = 3 # 10
65
- if "custom_system_prompt" not in state:
66
- state["custom_system_prompt"] = load_from_storage("system_prompt") or ""
67
- # remove if want to use default system prompt
68
- device_os_name = "Windows" if platform.platform == "Windows" else "Mac" if platform.platform == "Darwin" else "Linux"
69
- state["custom_system_prompt"] += f"\n\nNOTE: you are operating a {device_os_name} machine"
70
- if "hide_images" not in state:
71
- state["hide_images"] = False
72
-
73
-
74
- def _reset_model(state):
75
- state["model"] = PROVIDER_TO_DEFAULT_MODEL_NAME[cast(APIProvider, state["provider"])]
76
-
77
-
78
- async def main(state):
79
- """Render loop for Gradio"""
80
- setup_state(state)
81
- return "Setup completed"
82
-
83
-
84
- def validate_auth(provider: APIProvider, api_key: str | None):
85
- if provider == APIProvider.ANTHROPIC:
86
- if not api_key:
87
- return "Enter your Anthropic API key to continue."
88
- if provider == APIProvider.BEDROCK:
89
- import boto3
90
-
91
- if not boto3.Session().get_credentials():
92
- return "You must have AWS credentials set up to use the Bedrock API."
93
- if provider == APIProvider.VERTEX:
94
- import google.auth
95
- from google.auth.exceptions import DefaultCredentialsError
96
-
97
- if not os.environ.get("CLOUD_ML_REGION"):
98
- return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
99
- try:
100
- google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
101
- except DefaultCredentialsError:
102
- return "Your google cloud credentials are not set up correctly."
103
-
104
-
105
- def load_from_storage(filename: str) -> str | None:
106
- """Load data from a file in the storage directory."""
107
- try:
108
- file_path = CONFIG_DIR / filename
109
- if file_path.exists():
110
- data = file_path.read_text().strip()
111
- if data:
112
- return data
113
- except Exception as e:
114
- print(f"Debug: Error loading {filename}: {e}")
115
- return None
116
-
117
-
118
- def save_to_storage(filename: str, data: str) -> None:
119
- """Save data to a file in the storage directory."""
120
- try:
121
- CONFIG_DIR.mkdir(parents=True, exist_ok=True)
122
- file_path = CONFIG_DIR / filename
123
- file_path.write_text(data)
124
- # Ensure only user can read/write the file
125
- file_path.chmod(0o600)
126
- except Exception as e:
127
- print(f"Debug: Error saving {filename}: {e}")
128
-
129
-
130
- def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
131
- response_id = datetime.now().isoformat()
132
- response_state[response_id] = response
133
-
134
-
135
- def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
136
- tool_state[tool_id] = tool_output
137
-
138
-
139
- def _render_message(sender: Sender, message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, state):
140
- is_tool_result = not isinstance(message, str) and (
141
- isinstance(message, ToolResult)
142
- or message.__class__.__name__ == "ToolResult"
143
- or message.__class__.__name__ == "CLIResult"
144
- )
145
- if not message or (
146
- is_tool_result
147
- and state["hide_images"]
148
- and not hasattr(message, "error")
149
- and not hasattr(message, "output")
150
- ):
151
- return
152
- if is_tool_result:
153
- message = cast(ToolResult, message)
154
- if message.output:
155
- return message.output
156
- if message.error:
157
- return f"Error: {message.error}"
158
- if message.base64_image and not state["hide_images"]:
159
- return base64.b64decode(message.base64_image)
160
- elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
161
- return message.text
162
- elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
163
- return f"Tool Use: {message.name}\nInput: {message.input}"
164
- else:
165
- return message
166
- # open new tab, open google sheets inside, then create a new blank spreadsheet
167
-
168
- def process_input(user_input, state):
169
- # Ensure the state is properly initialized
170
- setup_state(state)
171
-
172
- # Append the user input to the messages in the state
173
- state["messages"].append(
174
- {
175
- "role": Sender.USER,
176
- "content": [TextBlock(type="text", text=user_input)],
177
- }
178
  )
 
 
 
 
 
 
179
 
180
- # Run the sampling loop synchronously and yield messages
181
- for message in sampling_loop(state):
182
- yield message
183
 
 
 
 
184
 
185
- def accumulate_messages(*args, **kwargs):
186
- """
187
- Wrapper function to accumulate messages from sampling_loop_sync.
188
- """
189
- accumulated_messages = []
190
-
191
- for message in sampling_loop_sync(*args, **kwargs):
192
- # Check if the message is already in the accumulated messages
193
- if message not in accumulated_messages:
194
- accumulated_messages.append(message)
195
- # Yield the accumulated messages as a list
196
- yield accumulated_messages
197
-
198
-
199
- def sampling_loop(state):
200
- # Ensure the API key is present
201
- if not state.get("api_key"):
202
- raise ValueError("API key is missing. Please set it in the environment or storage.")
203
-
204
- # Call the sampling loop and yield messages
205
- for message in accumulate_messages(
206
- system_prompt_suffix=state["custom_system_prompt"],
207
- model=state["model"],
208
- provider=state["provider"],
209
- messages=state["messages"],
210
- output_callback=partial(_render_message, Sender.BOT, state=state),
211
- tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
212
- api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
213
- api_key=state["api_key"],
214
- only_n_most_recent_images=state["only_n_most_recent_images"],
215
- ):
216
- yield message
217
 
 
218
 
219
- with gr.Blocks() as demo:
220
- state = gr.State({}) # Use Gradio's state management
221
 
222
- gr.Markdown("# Claude Computer Use Demo")
223
 
224
- if not os.getenv("HIDE_WARNING", False):
225
- gr.Markdown(WARNING_TEXT)
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  with gr.Row():
228
- provider = gr.Dropdown(
229
- label="API Provider",
230
- choices=[option.value for option in APIProvider],
231
- value="anthropic",
232
- interactive=True,
233
- )
234
- model = gr.Textbox(label="Model", value="claude-3-5-sonnet-20241022")
235
- api_key = gr.Textbox(
236
- label="Anthropic API Key",
237
- type="password",
238
- value="",
239
- interactive=True,
240
- )
241
- only_n_images = gr.Slider(
242
- label="Only send N most recent images",
243
- minimum=0,
244
- value=3, # 10
245
- interactive=True,
246
- )
247
- custom_prompt = gr.Textbox(
248
- label="Custom System Prompt Suffix",
249
- value="",
250
- interactive=True,
251
- )
252
- hide_images = gr.Checkbox(label="Hide screenshots", value=False)
253
-
254
- api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
255
- chat_input = gr.Textbox(label="Type a message to send to Claude...")
256
- # chat_output = gr.Textbox(label="Chat Output", interactive=False)
257
- chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True)
258
-
259
- # Pass state as an input to the function
260
- chat_input.submit(process_input, [chat_input, state], chatbot)
261
-
262
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import io
3
+ import os
4
+ from PIL import Image, ImageDraw
5
+ from anthropic import Anthropic
6
  from anthropic.types import TextBlock
7
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
8
import base64
from datetime import datetime

# Upper bound on tokens Claude may generate per reply.
max_tokens = 4096
# Model used for every request.
model = 'claude-3-5-sonnet-20241022'

# The prompt tells the model what "today" is. A hard-coded date goes stale the
# day after it is written, so it is computed when the module is imported.
# Day and year are formatted by hand because strftime's "%-d" is not portable.
_today = datetime.now()
_current_date = f"{_today:%A, %B} {_today.day}, {_today.year}"

# System prompt sent with every request.
system = f"""<SYSTEM_CAPABILITY>
* You are utilizing a Windows system with internet access.
* The current date is {_current_date}.
</SYSTEM_CAPABILITY>"""
15
+
16
def save_image_or_get_url(image, filename="processed_image.png"):
    """Save ``image`` under the local ``static`` directory and return its path.

    Parameters:
        image: Any object exposing ``save(path)`` (a ``PIL.Image`` in this app).
        filename (str): File name to write inside ``static``. An existing file
            with the same name is overwritten.

    Returns:
        str: The relative path ``static/<filename>`` the image was written to.
    """
    filepath = os.path.join("static", filename)
    # The directory is not guaranteed to exist (e.g. a fresh checkout), and
    # saving into a missing directory raises FileNotFoundError.
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    image.save(filepath)
    return filepath
20
+
21
def draw_circle_on_image(image, center, radius=30, outline="red", width=15):
    """
    Draw a circle outline on ``image`` and return the same image object.

    The image is modified in place; pass ``image.copy()`` if the original
    must stay untouched.

    Parameters:
    image (PIL.Image): The image to draw on.
    center (tuple): A tuple (x, y) representing the center of the circle.
    radius (int | float): The radius of the circle; must be positive.
    outline (str): Outline color passed to ``ImageDraw.ellipse``.
    width (int): Outline thickness in pixels passed to ``ImageDraw.ellipse``.

    Returns:
    PIL.Image: The image with the circle drawn (the object passed in).

    Raises:
    ValueError: If ``center`` is not a 2-tuple or ``radius`` is not a
        positive number.
    """
    if not isinstance(center, tuple) or len(center) != 2:
        raise ValueError("Center must be a tuple of two values (x, y).")
    if not isinstance(radius, (int, float)) or radius <= 0:
        raise ValueError("Radius must be a positive number.")

    center_x, center_y = center
    # Bounding box of the circle: top-left corner, then bottom-right corner.
    bbox = [
        center_x - radius, center_y - radius,
        center_x + radius, center_y + radius,
    ]

    ImageDraw.Draw(image).ellipse(bbox, outline=outline, width=width)
    return image
51
+
52
+
53
def pil_image_to_base64(pil_image):
    """Return ``pil_image`` serialized as PNG and encoded as a Base64 text string."""
    # Serialize into memory rather than onto disk.
    png_buffer = io.BytesIO()
    pil_image.save(png_buffer, format="PNG")

    # getvalue() returns the whole buffer regardless of the stream position.
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
# Ask Claude (computer-use beta) about an uploaded screenshot
def chatbot_response(input_text, image, key, chat_history):
    """Send the screenshot and question to Claude and return the new history.

    The conversation is scripted so that the uploaded image looks like the
    result of a ``screenshot`` action of the ``computer`` tool. Text blocks in
    the reply become chat messages. Every tool call carrying a ``coordinate``
    is shown as a red circle drawn on the screenshot.

    Parameters:
        input_text (str): The user's question about the screenshot.
        image (PIL.Image | None): The uploaded screenshot. It is drawn on in
            place when Claude returns coordinates.
        key (str): Anthropic API key.
        chat_history (list): Existing ``[user, bot]`` pairs; not mutated.

    Returns:
        list: A new history list with the entries produced by this call
        appended.
    """
    if not key:
        return chat_history + [[input_text, "Please enter a valid key."]]
    if image is None:
        return chat_history + [[input_text, "Please upload an image."]]

    # Display size declared to the computer tool. The same values are used
    # below to map returned coordinates onto the real image, so they live in
    # one place.
    display_width_px, display_height_px = 1512, 982
    # Fabricated id tying the scripted screenshot request to its tool result.
    screenshot_tool_use_id = 'toolu_01PSTVtavFgmx6ctaiSvacCB'

    client = Anthropic(api_key=key)

    messages = [
        {'role': 'user',
         'content': [TextBlock(text=f'Look at my screenshot, {input_text}', type='text')]},
        {'role': 'assistant',
         'content': [
             BetaTextBlock(
                 text="I'll help you check your screen, but first I need to take a screenshot to see what you're looking at.",
                 type='text'),
             BetaToolUseBlock(id=screenshot_tool_use_id,
                              input={'action': 'screenshot'}, name='computer',
                              type='tool_use'),
         ]},
        {'role': 'user',
         'content': [{'type': 'tool_result',
                      'tool_use_id': screenshot_tool_use_id,
                      'is_error': False,
                      'content': [{'type': 'image',
                                   'source': {'type': 'base64',
                                              'media_type': 'image/png',
                                              'data': pil_image_to_base64(image)}}]}]},
    ]
    tools = [
        {'name': 'computer', 'type': 'computer_20241022',
         'display_width_px': display_width_px,
         'display_height_px': display_height_px,
         'display_number': None},
        {'type': 'bash_20241022', 'name': 'bash'},
        {'name': 'str_replace_editor', 'type': 'text_editor_20241022'},
    ]

    raw_response = client.beta.messages.with_raw_response.create(
        max_tokens=max_tokens,
        messages=messages,
        model=model,
        system=system,
        tools=tools,
        betas=["computer-use-2024-10-22"],
        temperature=0.0,
    )
    response = raw_response.parse()

    # True division is required here. Floor division gave a scale of 0 for
    # screenshots smaller than the declared display and dropped the fractional
    # part otherwise, so the circle landed in the wrong place.
    scale_x = image.width / display_width_px
    scale_y = image.height / display_height_px

    for block in response.content:
        if hasattr(block, 'text'):
            chat_history = chat_history + [[input_text, block.text]]

        if hasattr(block, 'input') and 'coordinate' in block.input:
            coordinate = block.input['coordinate']
            annotated = draw_circle_on_image(
                image, (coordinate[0] * scale_x, coordinate[1] * scale_y))
            image_path = save_image_or_get_url(annotated)
            # A one-element tuple is how the tuple-style Chatbot renders a file.
            chat_history = chat_history + [[None, (image_path,)]]

    return chat_history
136
+
137
+
138
+ # Create the Gradio interface
139
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", interactive=True)
        with gr.Column():
            chatbot = gr.Chatbot(label="Chatbot Interaction", height=400)

    with gr.Row():
        user_input = gr.Textbox(label="Type your message here", placeholder="Enter your message...")
        key_input = gr.Textbox(label="API Key", placeholder="Enter your key...", type="password")

    # Button to submit
    submit_button = gr.Button("Submit")

    # Per-session chat history
    chat_history = gr.State(value=[])

    def _respond_and_remember(text, image, key, history):
        """Run chatbot_response and return the result for both outputs."""
        updated = chatbot_response(text, image, key, history)
        return updated, updated

    # The State must be listed in outputs as well as inputs. Otherwise it
    # stays at its initial empty list and every click starts a new history.
    submit_button.click(
        fn=_respond_and_remember,
        inputs=[user_input, image_input, key_input, chat_history],
        outputs=[chatbot, chat_history],
    )

# Launch the app
demo.launch()