BobChuang commited on
Commit
6f053bf
·
verified ·
1 Parent(s): acb7b01

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +17 -17
README.md CHANGED
@@ -33,7 +33,6 @@ import numpy as np
33
  import pickle
34
 
35
  def greedy_policy(Qtable, state):
36
- # Exploitation: take the action with the highest state, action value
37
  action = np.argmax(Qtable[state, :])
38
  return action
39
 
@@ -50,23 +49,24 @@ def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.nda
50
  for episode in tqdm(range(n_eval_episodes)):
51
  if seed:
52
  state, info = env.reset(seed=seed[episode])
53
- else:
54
- state, info = env.reset()
55
- step = 0
56
- truncated = False
57
- terminated = False
58
- total_rewards_ep = 0
59
 
60
- for step in range(max_steps):
61
- # Take the action (index) that have the maximum expected future reward given that state
62
- action = greedy_policy(Q, state)
63
- new_state, reward, terminated, truncated, info = env.step(action)
64
- total_rewards_ep += reward
 
 
 
 
 
 
 
 
 
65
 
66
- if terminated or truncated:
67
- break
68
- state = new_state
69
- episode_rewards.append(total_rewards_ep)
70
  mean_reward = np.mean(episode_rewards)
71
  std_reward = np.std(episode_rewards)
72
 
@@ -74,7 +74,7 @@ def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.nda
74
 
75
  if __name__ == "__main__":
76
  file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
77
- with open(file_path, 'rb') as f:
78
  model = pickle.load(f)
79
 
80
  env = gym.make(model["env_id"], render_mode="rgb_array")
 
33
  import pickle
34
 
35
  def greedy_policy(Qtable, state):
 
36
  action = np.argmax(Qtable[state, :])
37
  return action
38
 
 
49
  for episode in tqdm(range(n_eval_episodes)):
50
  if seed:
51
  state, info = env.reset(seed=seed[episode])
52
+ else:
53
+ state, info = env.reset()
 
 
 
 
54
 
55
+ truncated = False
56
+ terminated = False
57
+ total_rewards_ep = 0
58
+
59
+ for step in range(max_steps):
60
+ action = greedy_policy(Q, state)
61
+ new_state, reward, terminated, truncated, info = env.step(action)
62
+ total_rewards_ep += reward
63
+
64
+ if terminated or truncated:
65
+ break
66
+ state = new_state
67
+
68
+ episode_rewards.append(total_rewards_ep)
69
 
 
 
 
 
70
  mean_reward = np.mean(episode_rewards)
71
  std_reward = np.std(episode_rewards)
72
 
 
74
 
75
  if __name__ == "__main__":
76
  file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
77
+ with open(file_path, "rb") as f:
78
  model = pickle.load(f)
79
 
80
  env = gym.make(model["env_id"], render_mode="rgb_array")