import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    # Build a feedforward neural network.
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)
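# A quick illustration of the builder above (a sketch; the sizes here are
# hypothetical, e.g. a 4-dimensional observation and 2 actions, not taken from
# the environment below): mlp alternates Linear layers with the chosen
# activation and puts output_activation (Identity) after the final layer.
example_net = mlp(sizes=[4, 32, 2])
# example_net is Sequential(Linear(4, 32), Tanh(), Linear(32, 2), Identity())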
# make environment, check spaces, get obs / act dims
env = gym.make(env_name)
assert isinstance(env.observation_space, Box), \
    "This example only works for envs with continuous state spaces."
assert isinstance(env.action_space, Discrete), \
    "This example only works for envs with discrete action spaces."

obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n
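# For example, with env_name = 'CartPole-v0' the observation space is a Box of
# shape (4,) and the action space is Discrete(2), so obs_dim = 4 and n_acts = 2.
# (CartPole is only an illustration; any Box/Discrete environment passes the
# asserts above.)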
# make core of policy network
logits_net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])
# make function to compute action distribution
def get_policy(obs):
    logits = logits_net(obs)
    return Categorical(logits=logits)
# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
    return get_policy(obs).sample().item()
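# Sketch of how the two helpers fit together (the names sample_obs and
# sample_act are illustrative, not part of the algorithm): get_policy maps an
# observation tensor to a Categorical over the n_acts logits, and get_action
# draws one integer action from that distribution.
sample_obs = torch.zeros(obs_dim, dtype=torch.float32)  # placeholder all-zeros observation
sample_act = get_action(sample_obs)                     # an int in {0, ..., n_acts - 1}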
# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    logp = get_policy(obs).log_prob(act)
    return -(logp * weights).mean()
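# Why this works (a sketch of the reasoning, no extra machinery): with
# weights[i] = R(tau) for the trajectory containing step i, the gradient of the
# expression above is a sample estimate of the policy gradient
#
#   \nabla_\theta J(\pi_\theta) \approx
#       (1 / |D|) \sum_{\tau \in D} \sum_t \nabla_\theta \log \pi_\theta(a_t | s_t) \, R(\tau),
#
# up to the minus sign (which turns gradient ascent into a descent step) and an
# overall scale from averaging over timesteps rather than trajectories. Note
# that this "loss" is not a performance measure: its value after an update says
# nothing about expected return.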
# make optimizer
optimizer = Adam(logits_net.parameters(), lr=lr)
# for training policy
def train_one_epoch():
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:
        # rendering
        if (not finished_rendering_this_epoch) and render:
            env.render()

        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)
        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len
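            # (A sketch of the remaining steps of one epoch, assuming the usual
            # structure: reset obs, done, and ep_rews for the next episode;
            # stop collecting once len(batch_obs) reaches the desired batch
            # size; then, outside the while loop, call compute_loss on the
            # batch tensors, backward(), and optimizer.step() to take a single
            # policy gradient update.)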