reinforcement learning (2/4): value function approximation

The methods we discussed in part 1 are limited when state spaces are large and/or continuous. Value function approximation addresses this by using functions to approximate the relationship between states and their value. But how can we find the parameters $\mathbf{w}$ of our value function $\hat{v}(s, \mathbf{w})$? Gradient descent works nicely here, which gives us tons of flexibility in how we model value functions.

stochastic gradient descent

Our goal is to learn the true value function $v_\pi(s)$ that relates states to values under our policy $\pi$. If our state space is large and/or continuous (as it is for most interesting problems), the best we can do is accurately approximate the true value function. If we parameterize our function approximator $\hat{v}(s, \mathbf{w})$ with weights $\mathbf{w}$, then the goal becomes finding $\mathbf{w}$ such that the approximation is as accurate as possible.

Gradient descent to the rescue! First we’ll define a loss function that penalizes deviations from the true value function. We’ll use the mean squared error between the estimated and true value function, averaged across states. But we care more about states that are visited frequently, so we will weight our average according to $u(s)$, which is the stationary distribution of states under the current policy:

\[\overline{VE}(\mathbf{w}) = \sum_{s \in \mathcal{S}} \underbrace{u(s)}_\text{state probability} \left[ \underbrace{v_\pi(s)}_\text{true value} - \underbrace{\hat{v}_\pi(s, \mathbf{w})}_\text{estimated value} \right]^2\]
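
To make the weighting concrete, here is a minimal sketch of this loss for a made-up three-state problem. The function name `value_error` and all of the numbers are assumptions for illustration only; in a real problem we would not have access to the true values.

```python
def value_error(state_probs, true_values, estimated_values):
    """Weighted squared value error: sum over s of u(s) * (v_pi(s) - v_hat(s))^2."""
    return sum(
        u * (v - v_hat) ** 2
        for u, v, v_hat in zip(state_probs, true_values, estimated_values)
    )


# Hypothetical three-state example (numbers invented for illustration).
u = [0.5, 0.25, 0.25]      # how often each state is visited under the policy
v_true = [1.0, 2.0, 4.0]   # true values v_pi(s)
v_est = [1.0, 2.0, 2.0]    # current estimates v_hat(s, w)

ve = value_error(u, v_true, v_est)          # 0.25 * (4.0 - 2.0)**2 = 1.0

# The same mistake in the most frequently visited state costs twice as much.
v_est_alt = [3.0, 2.0, 4.0]
ve_alt = value_error(u, v_true, v_est_alt)  # 0.5 * (1.0 - 3.0)**2 = 2.0
```

An error of the same size is penalized in proportion to how often we expect to be in that state.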

We can make our value estimate more accurate by nudging $\mathbf{w}$ in the direction opposite to the gradient of this loss. Because the states we visit are distributed according to $u(s)$, these updates will be correct on average[1]. Recall that the gradient (of a scalar-valued function) is a vector describing how the function’s output is affected by each input:

\[\nabla_\mathbf{w} f(\mathbf{w}) = \left( \frac{\partial f(\mathbf{w})}{\partial w_1}, \frac{\partial f(\mathbf{w})}{\partial w_2}, \dots, \frac{\partial f(\mathbf{w})}{\partial w_d} \right)^\top\]

We can now start building an algorithm. To update $\mathbf{w}$ we take the squared error at the single state $S_t$ we just visited and move $\mathbf{w}$ in the opposite direction of its gradient, scaled by a step size $\alpha > 0$ (omitting the subscript in $\nabla_\mathbf{w}$ for clarity):

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} - \alpha \frac{1}{2} \nabla \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right]^2\]

Then we take the derivative using the chain rule:

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} + \alpha \left[{v_\pi(S_{t})} - \hat{v}_\pi(S_t, \mathbf{w})\right] \nabla \hat{v}_\pi(S_t, \mathbf{w})\]
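
Spelled out, the chain rule step looks like this. The only assumption is that the true value $v_\pi(S_t)$ does not depend on $\mathbf{w}$, so its gradient is zero. The factor of $\frac{1}{2}$ is there purely to cancel the 2 that comes from differentiating the square:

\[\begin{align} -\alpha \frac{1}{2} \nabla \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right]^2 & = -\alpha \frac{1}{2} \cdot 2 \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right] \nabla \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right]\\ & = -\alpha \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right] \left(0 - \nabla \hat{v}_\pi(S_t, \mathbf{w})\right)\\ & = +\alpha \left[v_\pi(S_t) - \hat{v}_\pi(S_t, \mathbf{w})\right] \nabla \hat{v}_\pi(S_t, \mathbf{w}) \end{align}\]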

So far so good, but if we knew the true value function $v_\pi(s)$ we wouldn’t need to be studying any of these methods! Instead we use a target to approximate $v_\pi(s)$. For example, we don’t know the true value of state $S_t$, so in our update rule we could replace the true value with a sample return $G_t$. Using $G_t$ as the target (in blue below) gives us the Monte Carlo update rule:

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} + \alpha \left[\color{blue}{G_t} - \hat{v}_\pi(S_t, \mathbf{w})\right] \nabla \hat{v}_\pi(S_t, \mathbf{w})\]

These updates will be quite noisy, but actually have zero bias. Remember that the value is the expectation over returns. Therefore, if we always move where the returns lead us, we will be moving in the correct direction on average.
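
Here is a rough sketch of the Monte Carlo update in action. Everything about the setup is an assumption made for illustration: a five-state random walk that pays a reward of 1 only when the agent exits on the right, a linear value function with two hand-picked features, and arbitrary choices of step size and episode count. For this toy problem the true values happen to be $(s+1)/6$, so we can check how close we get.

```python
import random

N_STATES = 5    # non-terminal states 0..4 (hypothetical toy environment)
ALPHA = 0.005   # step size
GAMMA = 1.0     # undiscounted episodic task


def features(state):
    """Two features: a bias term and the state's position scaled to [0, 1]."""
    return [1.0, state / (N_STATES - 1)]


def v_hat(state, w):
    """Linear value estimate: dot product of features and weights."""
    return sum(x * w_i for x, w_i in zip(features(state), w))


def run_episode(rng):
    """Random walk from the middle state; reward 1 only when exiting on the right."""
    state, states, rewards = N_STATES // 2, [], []
    while 0 <= state < N_STATES:
        states.append(state)
        state += rng.choice([-1, 1])
        rewards.append(1.0 if state == N_STATES else 0.0)
    return states, rewards


def gradient_monte_carlo(n_episodes, seed=0):
    rng = random.Random(seed)
    w = [0.0, 0.0]
    for _ in range(n_episodes):
        states, rewards = run_episode(rng)
        g = 0.0
        # Walk backwards so g holds the return G_t that followed each state.
        for state, reward in zip(reversed(states), reversed(rewards)):
            g = reward + GAMMA * g
            error = g - v_hat(state, w)
            # For a linear v_hat, the gradient w.r.t. w is the feature vector.
            w = [w_i + ALPHA * error * x for w_i, x in zip(w, features(state))]
    return w


w = gradient_monte_carlo(10_000)
estimates = [v_hat(s, w) for s in range(N_STATES)]
true_values = [(s + 1) / (N_STATES + 1) for s in range(N_STATES)]
```

Each individual return is either 0 or 1, so single updates are noisy, but with a small step size the estimates should settle close to the true values.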

Using the return as a target can be annoying because (1) there is high variance[2] and (2) we have to collect entire returns, which are themselves functions of entire trajectories of experience. It would be nice to take just a single step in the environment and then update our value function online.

We can accomplish this with a nice trick. Recall from part 1 that we can recursively define returns:

\[\begin{align} G_t & = R_{t+1} + \gamma R_{t+2} + \gamma^2R_{t+3} + \gamma^3R_{t+4} + \dots\\ & = R_{t+1} + \gamma(R_{t+2} + \gamma R_{t+3} + \gamma^2R_{t+4} +\dots)\\ & = R_{t+1} + \gamma G_{t+1} \end{align}\]
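
The recursion also gives a convenient way to compute every return in an episode with one backward pass. This is a small sketch; the function name `discounted_returns` and the example rewards are made up, and `rewards[t]` is assumed to hold $R_{t+1}$.

```python
def discounted_returns(rewards, gamma):
    """Compute G_t for every step using G_t = R_{t+1} + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    g = 0.0  # the return after the final step is zero
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


# Hypothetical three-step episode.
returns = discounted_returns([1.0, 0.0, 2.0], gamma=0.5)
# G_2 = 2.0, G_1 = 0.0 + 0.5 * 2.0 = 1.0, G_0 = 1.0 + 0.5 * 1.0 = 1.5
```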

Using the fact that $\hat{v}(s, \mathbf{w})$ is an estimate of the expected return from state $s$, we will replace $G_{t+1}$ with $\hat{v}(S_{t+1}, \mathbf{w})$ and get:

\[G_t \approx R_{t+1} + \gamma \hat{v}(S_{t+1}, \mathbf{w})\]

Think of the latter term as a bootstrapped estimate of the tail end of the real return. In other words, we take one step in the world and see what really happens (i.e. we observe $R_{t+1}$), then we estimate what would have happened next using our value function estimate. This estimate will likely be more stable than the return itself, giving us the added benefit of reduced-variance updates[3]. Plugging this target (the blue stuff below) into the update rule gives us[4]:

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} + \alpha \left[\color{blue}{R_{t+1} + \gamma \hat{v}_\pi(S_{t+1}, \mathbf{w})} - \hat{v}_\pi(S_t, \mathbf{w})\right] \nabla \hat{v}_\pi(S_t, \mathbf{w})\]
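
A single bootstrapped update might look like the sketch below. It assumes a linear value function over feature vectors, and the function names and numbers are invented for illustration. The `terminal` flag is an extra assumption: when the episode ends there is no next state to bootstrap from, so the target is just the reward.

```python
def v_hat(x, w):
    """Linear value estimate for a state described by feature vector x."""
    return sum(x_i * w_i for x_i, w_i in zip(x, w))


def td0_update(w, x, reward, x_next, alpha, gamma, terminal=False):
    """One online update from a single observed transition."""
    # The bootstrapped target is treated as a constant: we do not
    # differentiate through v_hat(x_next, w).
    target = reward if terminal else reward + gamma * v_hat(x_next, w)
    td_error = target - v_hat(x, w)
    # For a linear v_hat, the gradient w.r.t. w is the feature vector x.
    return [w_i + alpha * td_error * x_i for w_i, x_i in zip(w, x)]


# Hypothetical transition between two states with one-hot features.
w = [1.0, 2.0]
new_w = td0_update(w, x=[1.0, 0.0], reward=1.0, x_next=[0.0, 1.0],
                   alpha=0.5, gamma=0.5)
# target = 1.0 + 0.5 * 2.0 = 2.0, error = 2.0 - 1.0 = 1.0, so new_w = [1.5, 2.0]
```

Unlike the Monte Carlo version, this update needs only one step of experience, so it can be applied while the episode is still running.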

function approximators

Now we have a couple of methods for updating our weights. Notice, however, that we still have the $\nabla \hat{v}_\pi(S_t, \mathbf{w})$ term on the right of our update equations. This means the function we select needs to be differentiable with respect to $\mathbf{w}$. This is not hugely limiting; simple linear functions and complex (but still differentiable) neural networks are all valid options.
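
As a worked example, take the simplest case: a linear function of some feature vector $\mathbf{x}(s)$ describing the state (the notation $\mathbf{x}(s)$ is introduced here just for illustration). The gradient is then the feature vector itself:

\[\hat{v}(s, \mathbf{w}) = \mathbf{w}^\top \mathbf{x}(s) = \sum_{i=1}^{d} w_i x_i(s) \quad \Rightarrow \quad \nabla \hat{v}(s, \mathbf{w}) = \mathbf{x}(s)\]

Plugging this into the Monte Carlo update, for instance, gives:

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} + \alpha \left[G_t - \mathbf{w_{t}}^\top \mathbf{x}(S_t)\right] \mathbf{x}(S_t)\]

Each weight moves in proportion to the prediction error and to how active its feature was in the visited state.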

There are also non-parametric approaches to function approximation. For example, a nearest neighbors approach could take the mean target for the $k$ visited states that are closest to the current state. Alternatively, we could take a mean across all visited states weighted by their proximity to the current state. Such kernel methods use a function $k(s,s')$ that assigns a weight to each previously visited state $s'$ when averaging their associated targets $g(s')$:

\[\hat{v}(s,\mathbf{w}) = \sum_{s' \in \mathcal{D}} k(s,s') g(s')\]
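
A minimal sketch of a kernel-based estimate, under a few assumptions of my own: states are single numbers, the kernel is a Gaussian, and the weights are divided by their sum so that they behave like a proper weighted average. The names and data below are hypothetical.

```python
import math


def gaussian_kernel(s, s_prime, bandwidth=1.0):
    """Weight that decays as s_prime gets farther from s."""
    return math.exp(-((s - s_prime) ** 2) / (2 * bandwidth ** 2))


def kernel_value_estimate(s, visited_states, targets, bandwidth=1.0):
    """Average the stored targets, weighted by proximity to the query state."""
    weights = [gaussian_kernel(s, sp, bandwidth) for sp in visited_states]
    total = sum(weights)  # normalize so the weights sum to 1
    return sum(k * g for k, g in zip(weights, targets)) / total


# Hypothetical memory of visited states and the targets observed there.
visited = [0.0, 1.0, 4.0]
targets = [1.0, 3.0, 10.0]

estimate = kernel_value_estimate(0.5, visited, targets)
```

The query state 0.5 sits halfway between the first two visited states, so the estimate is close to the average of their targets, with only a tiny pull from the distant state at 4.0.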

control with function approximation

Estimating value functions is nice, but the goal is still to find nice policies! To do this we can learn action-value functions rather than state-value functions. Recall that action-value functions encode the expected return given a state-action pair:

\[q_\pi(s,a) = \mathbb{E}[G_t | S_t=s, A_t=a]\]

If we can accurately estimate the action-value function, our policy can simply select the action in each state that maximizes the estimated return, e.g. $A_t = \arg\max_{a} \hat{q}(S_t, a, \mathbf{w})$. By analogy to the weight updates above, we can update the weights of the action-value function (for example by following a SARSA target) like this:

\[\mathbf{w_{t+1}} = \mathbf{w_{t}} + \alpha \left[\color{blue}{R_{t+1} + \gamma \hat{q}_\pi(S_{t+1}, A_{t+1}, \mathbf{w})} - \hat{q}_\pi(S_t, A_t, \mathbf{w})\right] \nabla \hat{q}_\pi(S_{t}, A_t, \mathbf{w})\]
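
Here is one way this could look in code. It is a sketch under assumptions: a linear action-value function with a separate weight vector per action, purely greedy action selection (in practice some exploration is usually added), and invented names and numbers.

```python
def q_hat(x, action, w):
    """Linear action value: one weight vector per action."""
    return sum(x_i * w_i for x_i, w_i in zip(x, w[action]))


def greedy_action(x, w):
    """Pick the action with the highest estimated value in this state."""
    return max(range(len(w)), key=lambda a: q_hat(x, a, w))


def sarsa_update(w, x, action, reward, x_next, next_action,
                 alpha, gamma, terminal=False):
    """One update from the transition (S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})."""
    target = reward if terminal else reward + gamma * q_hat(x_next, next_action, w)
    error = target - q_hat(x, action, w)
    new_w = [list(w_a) for w_a in w]
    # The gradient is x for the chosen action's weights and zero for the
    # others, so only the weights of the action we took are changed.
    new_w[action] = [w_i + alpha * error * x_i for w_i, x_i in zip(w[action], x)]
    return new_w


# Hypothetical problem with two actions and two one-hot state features.
w = [[0.0, 0.0], [1.0, 2.0]]
new_w = sarsa_update(w, x=[1.0, 0.0], action=0, reward=1.0,
                     x_next=[0.0, 1.0], next_action=1, alpha=0.5, gamma=0.5)
# target = 1.0 + 0.5 * 2.0 = 2.0, error = 2.0 - 0.0 = 2.0,
# so the weights for action 0 become [1.0, 0.0]
best = greedy_action([0.0, 1.0], new_w)  # action 1 has the higher estimate here
```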

up next

Part 3 will build upon our previous discussion of dynamic programming and Monte Carlo reinforcement learning algorithms. At first glance, dynamic programming and Monte Carlo appear to be qualitatively different approaches to reinforcement learning. Whereas dynamic programming is model-based and relies on bootstrapping, Monte Carlo is model-free and relies on sampling interactions with the environment. We will introduce temporal difference learning, a model-free approach that uses both bootstrapping and sampling to learn online.

  1. Our gradient updates won’t explicitly consider the stationary distribution of states. But because we will observe states according to this distribution, they will be correct on average. This is the magic behind stochastic gradient descent methods: individual updates may be a little off (high variance), but the updates will on average be correct (zero bias). 

  2. Targets based on returns can be high variance because individual returns, even when starting from the same state, can vary due to stochasticity in both the policy and the environment. 

  3. This is actually an example of so-called temporal difference learning, a hugely influential algorithm that will be covered in part 3.

  4. There’s a subtlety here: when we took the gradient above, the original target $v_\pi(s)$ did not depend on $\mathbf{w}$. But our estimated target, which contains $\hat{v}(S_{t+1}, \mathbf{w})$, does depend on $\mathbf{w}$. Because we ignore that dependence when differentiating, the update is no longer the true gradient of the loss. This approach is therefore called semi-gradient. It still works in practice, but at the expense of some theoretical guarantees.