"""A small "catch" game rendered as stacked grayscale frames.

A ball falls down a square grid, bouncing off the side walls, while a
five-cell paddle moves left or right to catch it.
"""
import random

import numpy as np

try:
    from scipy.misc import imresize
except ImportError:
    # ``scipy.misc.imresize`` is gone from current SciPy releases. Provide a
    # small NumPy stand-in so the module stays importable.
    def imresize(arr, size):
        """Byte-scale ``arr`` to uint8 and resize it to ``size``.

        Minimal replacement for the removed ``scipy.misc.imresize``: values
        are scaled linearly so the minimum maps to 0 and the maximum to 255,
        then the image is resampled with nearest-neighbour lookup (the SciPy
        function interpolated bilinearly by default).

        Args:
            arr: 2-D array-like image.
            size: ``(height, width)`` of the result.

        Returns:
            A ``uint8`` array of shape ``size``.
        """
        data = np.asarray(arr, dtype=float)
        lowest = data.min()
        span = (data.max() - lowest) or 1.0  # constant image: avoid 0-division
        scaled = ((data - lowest) * (255.0 / span)).clip(0, 255) + 0.5
        pixels = scaled.astype(np.uint8)
        out_h, out_w = size
        rows = np.arange(out_h) * data.shape[0] // out_h
        cols = np.arange(out_w) * data.shape[1] // out_w
        return pixels[np.ix_(rows, cols)]


class CatchEnv3:
    """Catch game on a 21x21 grid, observed as a stack of resized frames.

    Actions are ``0`` (paddle left), ``1`` (paddle right) and ``2`` (no-op).
    The episode ends when the ball reaches the paddle row; the reward is 1 if
    the ball lands on the paddle at that moment and 0 otherwise.
    """

    def __init__(self):
        self.size = 21  # side length of the square game grid
        self.image = np.zeros((self.size, self.size))
        self.state = []  # most recent rendered frames, oldest first
        self.fps = 4  # number of frames stacked into one observation
        self.output_shape = (84, 84)  # (height, width) of a rendered frame

    def reset_random(self):
        """Start a new episode with a random paddle, ball column and drift.

        Returns:
            The first observation, produced by taking one no-op step.
        """
        self.image.fill(0)
        # Drop frames from the previous episode so they cannot leak into the
        # first observation of this one.
        self.state = []
        self.pos = np.random.randint(2, self.size - 2)
        self.vx = np.random.randint(5) - 2
        self.vy = 1
        self.ballx, self.bally = np.random.randint(self.size), 4
        self.image[self.bally, self.ballx] = 1
        self.image[-5, self.pos - 2:self.pos + 3] = 1
        return self.step(2)[0]

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: ``0`` moves the paddle left, ``1`` moves it right and
                ``2`` leaves it in place.

        Returns:
            A tuple ``(observation, reward, terminal)``. ``observation``
            stacks the last ``self.fps`` frames along the last axis, each
            frame flipped vertically.

        Raises:
            KeyError: If ``action`` is not 0, 1 or 2.
        """
        if action not in (0, 1, 2):
            raise KeyError(action)
        if action == 0 and self.pos > 3:
            self.pos -= 2
        elif action == 1 and self.pos < 17:
            self.pos += 2

        # Erase the ball, move it, and reflect it off the side walls.
        self.image[self.bally, self.ballx] = 0
        self.ballx += self.vx
        self.bally += self.vy
        last_col = self.size - 1
        if self.ballx > last_col:
            self.ballx = 2 * last_col - self.ballx
            self.vx = -self.vx
        elif self.ballx < 0:
            self.ballx = -self.ballx
            self.vx = -self.vx
        self.image[self.bally, self.ballx] = 1

        # Redraw the paddle row from scratch (this also wipes a ball that has
        # just arrived on that row).
        self.image[-5].fill(0)
        self.image[-5, self.pos - 2:self.pos + 3] = 1

        terminal = self.bally == self.size - 1 - 4
        caught = self.pos - 2 <= self.ballx <= self.pos + 2
        reward = int(caught) if terminal else 0

        # Flip each frame exactly once, when it is rendered. Re-flipping the
        # whole stack on every step would leave older frames upside down
        # relative to newer ones.
        frame = imresize(self.image, self.output_shape)[::-1, :]
        # An empty stack is filled with copies of the first frame.
        copies = max(1, self.fps - len(self.state))
        self.state.extend([frame] * copies)
        self.state = self.state[-self.fps:]
        return np.transpose(self.state, [1, 2, 0]), reward, terminal

    def get_num_actions(self):
        """Return the number of distinct actions accepted by ``step``."""
        return 3

    def reset(self):
        """Start a new episode; see ``reset_random``."""
        return self.reset_random()

    def state_shape(self):
        """Return ``(fps,) + output_shape``.

        Note that this is frames-first, whereas ``step`` returns observations
        with the frame axis last.
        """
        return (self.fps,) + self.output_shape


def test():
    """Play one episode moving right every step and save each observation.

    Writes ``image_<i>.jpg`` files into the current working directory.
    """
    # Imported here so that only this demo, not the environment, needs
    # matplotlib.
    import matplotlib.pyplot as plt

    env = CatchEnv3()
    i = 0
    for ep in range(1):
        env.reset()
        state, reward, terminal = env.step(1)
        while not terminal:
            state, reward, terminal = env.step(1)
            state = np.squeeze(state)
            plt.imsave('image_'+str(i)+'.jpg', state)
            i += 1


# This is a demo that writes files, not a unit test; keep pytest from
# collecting it because of its name.
test.__test__ = False


if __name__ == "__main__":
    test()