yuntian-deng commited on
Commit
89b9813
·
1 Parent(s): 8169788

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +44 -6
main.py CHANGED
@@ -1,8 +1,11 @@
1
  from fastapi import FastAPI, WebSocket
2
  from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
- from typing import List
5
  import numpy as np
 
 
 
6
 
7
  app = FastAPI()
8
 
@@ -14,9 +17,38 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
14
  async def get():
15
  return HTMLResponse(open("static/index.html").read())
16
 
17
- # Simulate your diffusion model (placeholder)
18
- def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[str]) -> np.ndarray:
19
- return np.zeros((800, 600, 3), dtype=np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # WebSocket endpoint for continuous user interaction
22
  @app.websocket("/ws")
@@ -39,8 +71,14 @@ async def websocket_endpoint(websocket: WebSocket):
39
  next_frame = predict_next_frame(previous_frames, previous_actions)
40
  previous_frames.append(next_frame)
41
 
42
- # Send the generated frame back to the client (encoded as base64 or similar)
43
- await websocket.send_text("Next frame generated") # Replace with real image sending logic
 
 
 
 
 
 
44
 
45
  except Exception as e:
46
  print(f"Error: {e}")
 
1
  from fastapi import FastAPI, WebSocket
2
  from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
+ from typing import List, Tuple
5
  import numpy as np
6
+ from PIL import Image, ImageDraw
7
+ import base64
8
+ import io
9
 
10
  app = FastAPI()
11
 
 
17
  async def get():
18
  return HTMLResponse(open("static/index.html").read())
19
 
20
+ def generate_random_image(width: int, height: int) -> np.ndarray:
21
+ return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
22
+
23
+ def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
24
+ pil_image = Image.fromarray(image)
25
+ draw = ImageDraw.Draw(pil_image)
26
+
27
+ for i, (action_type, position) in enumerate(previous_actions):
28
+ color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
29
+ x, y = position
30
+ draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
31
+
32
+ if i > 0:
33
+ prev_x, prev_y = previous_actions[i-1][1]
34
+ draw.line([prev_x, prev_y, x, y], fill=color, width=1)
35
+
36
+ return np.array(pil_image)
37
+
38
+ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
39
+ width, height = 800, 600
40
+
41
+ if not previous_frames or previous_actions[-1][0] == "move":
42
+ # Generate a new random image when there's no previous frame or the mouse moves
43
+ new_frame = generate_random_image(width, height)
44
+ else:
45
+ # Use the last frame if it exists and the action is not a mouse move
46
+ new_frame = previous_frames[-1].copy()
47
+
48
+ # Draw the trace of previous actions
49
+ new_frame_with_trace = draw_trace(new_frame, previous_actions)
50
+
51
+ return new_frame_with_trace
52
 
53
  # WebSocket endpoint for continuous user interaction
54
  @app.websocket("/ws")
 
71
  next_frame = predict_next_frame(previous_frames, previous_actions)
72
  previous_frames.append(next_frame)
73
 
74
+ # Convert the numpy array to a base64 encoded image
75
+ img = Image.fromarray(next_frame)
76
+ buffered = io.BytesIO()
77
+ img.save(buffered, format="PNG")
78
+ img_str = base64.b64encode(buffered.getvalue()).decode()
79
+
80
+ # Send the generated frame back to the client
81
+ await websocket.send_json({"image": img_str})
82
 
83
  except Exception as e:
84
  print(f"Error: {e}")