arkadyark's picture
Upload . with huggingface_hub
9f4571e
diff --git a/unit-4/main.py b/unit-4/main.py
index 347c250..834b615 100644
--- a/unit-4/main.py
+++ b/unit-4/main.py
@@ -69,7 +69,7 @@ class CartpolePolicy(nn.Module):
class PixelcopterPolicy(nn.Module):
def __init__(self, s_size, a_size, h_size, device):
- super(Policy, self).__init__()
+ super(PixelcopterPolicy, self).__init__()
self.fc1 = nn.Linear(s_size, h_size)
self.fc2 = nn.Linear(h_size, h_size * 2)
self.fc3 = nn.Linear(h_size * 2, a_size)
@@ -170,8 +170,29 @@ def reinforce(policy, env, optimizer, n_training_episodes, max_t, gamma, print_e
return scores
-
-def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
+def record_video(env, policy, out_directory, fps=30):
+ """
+ Generate a replay video of the agent
+ :param env
+ :param policy: policy of our agent; its act(state) call supplies the action for each step
+ :param out_directory
+ :param fps: how many frames per second (with taxi-v3 and frozenlake-v1 we use 1)
+ """
+ images = []
+ done = False
+ state = env.reset()
+ img = env.render(mode="rgb_array")
+ images.append(img)
+ while not done:
+ # Ask the policy for the action to take in this state (the second value returned by act() is not used here)
+ action, _ = policy.act(state)
+ state, reward, done, info = env.step(action) # We directly put next_state = state for recording logic
+ img = env.render(mode="rgb_array")
+ images.append(img)
+ imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)
+
+
+def push_to_hub(repo_id, model, hparams, eval_env, video_fps=30):
"""
Evaluate, Generate a video and Upload a model to Hugging Face Hub.
This method does the complete pipeline:
@@ -182,7 +203,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
:param repo_id: repo_id: id of the model repository from the Hugging Face Hub
:param model: the pytorch model we want to save
- :param hyperparameters: training hyperparameters
+ :param hparams: training hyperparameters (indexed here by "max_t", "n_evaluation_episodes" and "env_id")
:param eval_env: evaluation environment
:param video_fps: how many frame per seconds to record our video replay
"""
@@ -202,15 +223,15 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
# Step 2: Save the model
torch.save(model, local_directory / "model.pt")
- # Step 3: Save the hyperparameters to JSON
- with open(local_directory / "hyperparameters.json", "w") as outfile:
- json.dump(hyperparameters, outfile)
+ # Step 3: Save the hparams to JSON
+ with open(local_directory / "hparams.json", "w") as outfile:
+ json.dump(hparams, outfile)
# Step 4: Evaluate the model and build JSON
mean_reward, std_reward = evaluate_agent(
eval_env,
- hyperparameters["max_t"],
- hyperparameters["n_evaluation_episodes"],
+ hparams["max_t"],
+ hparams["n_evaluation_episodes"],
model,
)
# Get datetime
@@ -218,9 +239,9 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
eval_form_datetime = eval_datetime.isoformat()
evaluate_data = {
- "env_id": hyperparameters["env_id"],
+ "env_id": hparams["env_id"],
"mean_reward": mean_reward,
- "n_evaluation_episodes": hyperparameters["n_evaluation_episodes"],
+ "n_evaluation_episodes": hparams["n_evaluation_episodes"],
"eval_datetime": eval_form_datetime,
}
@@ -229,7 +250,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
json.dump(evaluate_data, outfile)
# Step 5: Create the model card
- env_name = hyperparameters["env_id"]
+ env_name = hparams["env_id"]
metadata = {}
metadata["tags"] = [
@@ -256,8 +277,8 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
metadata = {**metadata, **eval}
model_card = f"""
- # **Reinforce** Agent playing **{env_id}**
- This is a trained model of a **Reinforce** agent playing **{env_id}** .
+ # **Reinforce** Agent playing **{env_name}**
+ This is a trained model of a **Reinforce** agent playing **{env_name}** .
To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction
"""
@@ -277,7 +298,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30):
# Step 6: Record a video
video_path = local_directory / "replay.mp4"
- record_video(env, model, video_path, video_fps)
+ record_video(eval_env, model, video_path, video_fps)
# Step 7. Push everything to the Hub
api.upload_folder(