Update README.md
Browse files
README.md
CHANGED
@@ -33,7 +33,6 @@ import numpy as np
|
|
33 |
import pickle
|
34 |
|
35 |
def greedy_policy(Qtable, state):
|
36 |
-
# Exploitation: take the action with the highest state, action value
|
37 |
action = np.argmax(Qtable[state, :])
|
38 |
return action
|
39 |
|
@@ -50,23 +49,24 @@ def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.nda
|
|
50 |
for episode in tqdm(range(n_eval_episodes)):
|
51 |
if seed:
|
52 |
state, info = env.reset(seed=seed[episode])
|
53 |
-
|
54 |
-
|
55 |
-
step = 0
|
56 |
-
truncated = False
|
57 |
-
terminated = False
|
58 |
-
total_rewards_ep = 0
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
if terminated or truncated:
|
67 |
-
break
|
68 |
-
state = new_state
|
69 |
-
episode_rewards.append(total_rewards_ep)
|
70 |
mean_reward = np.mean(episode_rewards)
|
71 |
std_reward = np.std(episode_rewards)
|
72 |
|
@@ -74,7 +74,7 @@ def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.nda
|
|
74 |
|
75 |
if __name__ == "__main__":
|
76 |
file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
|
77 |
-
with open(file_path,
|
78 |
model = pickle.load(f)
|
79 |
|
80 |
env = gym.make(model["env_id"], render_mode="rgb_array")
|
|
|
33 |
import pickle
|
34 |
|
35 |
def greedy_policy(Qtable, state):
|
|
|
36 |
action = np.argmax(Qtable[state, :])
|
37 |
return action
|
38 |
|
|
|
49 |
for episode in tqdm(range(n_eval_episodes)):
|
50 |
if seed:
|
51 |
state, info = env.reset(seed=seed[episode])
|
52 |
+
else:
|
53 |
+
state, info = env.reset()
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
truncated = False
|
56 |
+
terminated = False
|
57 |
+
total_rewards_ep = 0
|
58 |
+
|
59 |
+
for step in range(max_steps):
|
60 |
+
action = greedy_policy(Q, state)
|
61 |
+
new_state, reward, terminated, truncated, info = env.step(action)
|
62 |
+
total_rewards_ep += reward
|
63 |
+
|
64 |
+
if terminated or truncated:
|
65 |
+
break
|
66 |
+
state = new_state
|
67 |
+
|
68 |
+
episode_rewards.append(total_rewards_ep)
|
69 |
|
|
|
|
|
|
|
|
|
70 |
mean_reward = np.mean(episode_rewards)
|
71 |
std_reward = np.std(episode_rewards)
|
72 |
|
|
|
74 |
|
75 |
if __name__ == "__main__":
|
76 |
file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
|
77 |
+
with open(file_path, "rb") as f:
|
78 |
model = pickle.load(f)
|
79 |
|
80 |
env = gym.make(model["env_id"], render_mode="rgb_array")
|