Update README.md
README.md
---
tags:
- Taxi-v3
- q-learning
- reinforcement-learning
- custom-implementation
model-index:
- name: q-Taxi-v1-5x5
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: Taxi-v3
      type: Taxi-v3
    metrics:
    - type: mean_reward
      value: 7.36 +/- 2.47
      name: mean_reward
      verified: false
---

# **Q-Learning** Agent playing **Taxi-v3**

This is a trained model of a **Q-Learning** agent playing **Taxi-v3**.

## Usage

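The snippet below downloads the `q-learning.pkl` checkpoint from the Hub, rebuilds the environment from the stored `env_id`, and evaluates the greedy policy over the evaluation seeds saved with the model:
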
```python
from huggingface_sb3 import load_from_hub
import gymnasium as gym
from tqdm import tqdm
import numpy as np
import pickle


def greedy_policy(Qtable, state):
    # Exploitation: take the action with the highest state-action value
    action = np.argmax(Qtable[state, :])
    return action


def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.ndarray, seed: list[int]):
    """
    Evaluate the agent for ``n_eval_episodes`` episodes and return the average reward and the std of reward.

    :param env: The evaluation environment
    :param max_steps: Maximum number of steps per episode
    :param n_eval_episodes: Number of episodes to evaluate the agent
    :param Q: The Q-table
    :param seed: The evaluation seed array (for Taxi-v3)
    """
    episode_rewards = []
    for episode in tqdm(range(n_eval_episodes)):
        if seed:
            state, info = env.reset(seed=seed[episode])
        else:
            state, info = env.reset()
        truncated = False
        terminated = False
        total_rewards_ep = 0

        for step in range(max_steps):
            # Take the action (index) that has the maximum expected future reward in this state
            action = greedy_policy(Q, state)
            new_state, reward, terminated, truncated, info = env.step(action)
            total_rewards_ep += reward

            if terminated or truncated:
                break
            state = new_state
        episode_rewards.append(total_rewards_ep)

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    return float(mean_reward), float(std_reward)


if __name__ == "__main__":
    file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
    with open(file_path, "rb") as f:
        model = pickle.load(f)

    env = gym.make(model["env_id"], render_mode="rgb_array")
    max_steps = model["max_steps"]
    n_eval_episodes = model["n_eval_episodes"]
    qtable = model["qtable"]
    eval_seed = model["eval_seed"]

    mean_reward, std_reward = evaluate_agent(env, max_steps, n_eval_episodes, qtable, eval_seed)
    print(f"\n{mean_reward = }, {std_reward = }")
```
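
The script prints the mean and standard deviation of the episode return over the evaluation seeds, which should land close to the `mean_reward` reported in the metadata above.

## Training (sketch)

This repository ships only the trained Q-table, not the training script. As a rough reference, a minimal tabular Q-learning loop for Taxi-v3 could look like the sketch below; every hyperparameter value here is an illustrative placeholder, not the setting used to produce `q-Taxi-v1-5x5`.

```python
import gymnasium as gym
import numpy as np

# Illustrative hyperparameters -- placeholders, not the values used for this checkpoint.
n_training_episodes = 25_000
learning_rate = 0.7   # alpha
gamma = 0.95          # discount factor
max_steps = 99        # cap on steps per training episode
max_epsilon = 1.0     # initial exploration rate
min_epsilon = 0.05    # final exploration rate
decay_rate = 0.005    # exponential decay of epsilon per episode

env = gym.make("Taxi-v3")
qtable = np.zeros((env.observation_space.n, env.action_space.n))

for episode in range(n_training_episodes):
    # Decay the exploration rate as training progresses
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
    state, info = env.reset()

    for step in range(max_steps):
        # Epsilon-greedy action selection
        if np.random.uniform(0, 1) > epsilon:
            action = int(np.argmax(qtable[state, :]))
        else:
            action = env.action_space.sample()

        new_state, reward, terminated, truncated, info = env.step(action)

        # Q-learning update: Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a))
        qtable[state, action] += learning_rate * (
            reward + gamma * np.max(qtable[new_state, :]) - qtable[state, action]
        )

        if terminated or truncated:
            break
        state = new_state
```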