This repository has been archived by the owner on Aug 10, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
48 lines (43 loc) · 1.59 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
from game import Game
import pygame
from deep_q_network import Agent
def poll_quit_event():
    """Drain pygame's event queue and report whether the window was closed.

    Prints a message for every ``pygame.QUIT`` event found. The whole queue
    is drained even after a quit is seen, as the original loop did.

    Returns:
        True if at least one ``pygame.QUIT`` event was in the queue.
    """
    requested = False
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            print("QUITING THE GAME")
            requested = True
    return requested


def play_episode(game, agent, episode, observation_space, quit_check=poll_quit_event):
    """Play one episode, training ``agent`` online via experience replay.

    Each step asks the agent for an action, advances the game, stores the
    transition in the agent's memory and runs one replay update.

    Args:
        game: object providing ``init_elements()`` and ``frame_step(action)``,
            each returning ``(readings, reward, done)``.
        agent: object providing ``get_action``, ``save_to_memory``,
            ``experience_replay`` and an ``exploration_rate`` attribute.
        episode: episode index, forwarded to ``agent.save_to_memory``.
        observation_space: number of readings; observations are reshaped to
            ``(1, observation_space)`` before use.
        quit_check: zero-argument callable polled after every non-terminal
            step; a truthy result aborts the episode.

    Returns:
        ``(total_reward, quit_requested)``. ``quit_requested`` is True when the
        episode was aborted by ``quit_check`` rather than ended by the game.
    """
    # The ``done`` flag from init_elements() is ignored, matching the original
    # loop, which always reset it to False before stepping.
    readings, initial_reward, _ = game.init_elements()
    # The original loop also counted the reward reported by init_elements().
    # NOTE(review): presumably this is 0 — confirm in game.py.
    total_reward = initial_reward
    readings = np.reshape(readings, (1, observation_space))
    while True:
        action = agent.get_action(readings)
        next_readings, reward, done = game.frame_step(action)
        next_readings = np.reshape(next_readings, (1, observation_space))
        # Accumulate immediately so the terminal transition's reward is part
        # of the episode total (previously it was dropped).
        total_reward += reward
        agent.save_to_memory(readings, action, reward, next_readings, done, episode)
        readings = next_readings
        agent.experience_replay()
        if done:
            print("Episode : ", episode,
                  " Total reward : ", total_reward,
                  " with exploration rate : ", agent.exploration_rate)
            return total_reward, False
        if quit_check():
            return total_reward, True


def main(episodes=300, observation_space=13, action_space=2,
         save_file_name="dqn_with_er", checkpoint_every=50):
    """Train a DQN agent on ``Game`` and persist the model and reward history.

    Args:
        episodes: maximum number of episodes to play.
        observation_space: size of one observation vector.
        action_space: number of discrete actions.
        save_file_name: name passed to ``agent.save`` for checkpoints.
        checkpoint_every: checkpoint whenever ``episode % checkpoint_every == 0``.

    Closing the pygame window stops training early. The agent is still saved
    before exit.
    """
    game = Game()
    agent = Agent(observation_space, action_space)
    total_rewards = []
    for episode in range(episodes):
        total_reward, quit_requested = play_episode(
            game, agent, episode, observation_space)
        # Only episodes the game actually finished are recorded.
        if not quit_requested:
            total_rewards.append(total_reward)
        if episode % checkpoint_every == 0:
            agent.save(save_file_name)
        if quit_requested:
            break
    # NOTE(review): `episodes` is the configured count and can exceed
    # len(total_rewards) after an early quit — confirm save_collected_data
    # tolerates that.
    agent.save_collected_data(episodes, total_rewards)
    agent.save(save_file_name)
    pygame.quit()


if __name__ == "__main__":
    main()