Natural Gradient - UCL Computer Science
Natural Gradient
Daniel Worrall

I. PROBLEM SETTING

Say I have an objective to minimise, a scalar-valued function $f$. I can parametrise this function with respect to some coordinate system $\Theta$ as $f(\theta)$ and then find a solution to

$$\theta^* = \arg\min_{\theta} \left[ f(\theta) \right]. \tag{1}$$

Alternatively I could find another parameterisation of $f$ with respect to a different basis $\Phi$ as $f(\phi)$ and perform

$$\phi^* = \arg\min_{\phi} \left[ f(\phi) \right]. \tag{2}$$

For argument's sake, I'm going to restrict my analysis to the scenario where we can define a continuous, invertible function between parameter spaces $T : \Theta \to \Phi$, so in topology speak we say that $\Theta$ and $\Phi$ are topologically equivalent. I'm also going to focus on smooth, well-behaved $f$ only.

Now, say we cannot find an exact solution to our optimisation, so we perform an iterative technique to find the minimiser. For either basis, if we initialise at the same point $\phi^0 = T(\theta^0)$ and then run the algorithm, we will return a trajectory of points $\{\theta_0^\infty\} = \{\theta^0, \theta^1, ..., \theta^\infty\}$ and $\{\phi_0^\infty\} = \{\phi^0, \phi^1, ..., \phi^\infty\}$. We have already established that our start points are the same, $\phi^0 = T(\theta^0)$; we might also wish that the end points of the algorithm in either basis are the same, $\phi^\infty = T(\theta^\infty)$. This is necessarily true if we are performing true global optimisation, but generally we focus on local problems. Now there is a school of thought called the principle of covariance$^1$, which says that 'a consistent algorithm should give the same results independent of the units in which quantities are measured', i.e. every point of either trajectory should be the same:

$$\{\phi\} = T(\{\theta\}). \tag{3}$$

So we are aiming to develop a basis-independent algorithm to perform our optimisation. This is an attractive idea, because it offers a level of robustness with respect to how we decide to represent our data and, as we shall see, it is in a sense optimal, in that we are working with the natural parametrisation of the problem.

$^1$ I have taken the naming convention from MacKay, who took the original idea from Knuth. I believe the name comes from the fact that we are dealing with covariant gradients.

A. A generalisation

Let's define a new basis $N$. Suppose we estimate $\nu^*$ with successive samples of an iterative scheme where

$$\nu^{k+1} \leftarrow \nu^k + g(\nu_0^k, f) \tag{4}$$

where $\nu^k$ is our estimate of $\nu^*$ at iteration $k$, $\nu_0^k = \{\nu^0, \nu^1, ..., \nu^k\}$ is the current partial trajectory of samples and $g$ is some function we are yet to define. If we want a reasonable solution then we wish for $\nu^k$ to be bounded, so

$$\|\nu^k\| = \left\| \nu^0 + \sum_{i=1}^{k-1} g(\nu_0^i, f) \right\| < \infty. \tag{5}$$

Starting from a bounded point and running the algorithm to infinity, this implies

$$\left\| \sum_{i=1}^{\infty} g(\nu_0^i, f) \right\| \leq \sum_{i=1}^{\infty} \left\| g(\nu_0^i, f) \right\| < \infty. \tag{6}$$

There are many different trajectories satisfying this constraint, so we also note that eventually we wish for the algorithm to stop, so $\lim_{k\to\infty} \|g(\nu_0^k, f)\| = 0$. Note, however, that this condition alone isn't enough.

B. Local minimisation

Let's consider the simpler problem of descending on a unique local minimum. We aim to satisfy the sufficient conditions

$$\nabla_\nu f(\nu) = 0 \tag{7}$$

$$\nu^\top \nabla_\nu^2 f(\nu)\, \nu > 0. \tag{8}$$

1) Steepest descent: The simplest descent scheme is steepest descent (SD), which seeks to satisfy (7) only, so really there is no guarantee of even descending upon a minimum, just stationary points; nonetheless, it is widely used due to ease of implementation. The updates take the form

$$\nu^{k+1} \leftarrow \nu^k + \alpha \nabla_\nu f(\nu^k) \tag{9}$$

where $\alpha$ is small and negative. We thus see that this is similar to $\lim_{k\to\infty} \|g(\nu_0^k, f)\| = 0$, in that if we are slowly converging on our minimum (by correct choice of $\alpha$) then we expect $\lim_{k\to\infty} \|\nabla f(\nu^k)\| = 0$, so in this case

$$g(\nu_0^k, f) = g(\nu^k, f) = \alpha \nabla_\nu f(\nu^k). \tag{10}$$

SD is first-order Markov$^2$, in that the next move depends on the current state only. The problem with SD is that it is basis-dependent. To see this explicitly we use $\Theta$ and $\Phi$ again as our basis pair and define the function $f$ expressed in different bases as

$$f_\theta(\theta) = f_\theta\left(T^{-1}(\phi)\right) = f_\phi(\phi) = f_\phi\left(T(\theta)\right). \tag{11}$$

The steepest descent update in $\Phi$ is

$$\phi^{k+1} = \phi^k + \alpha \nabla_\phi f_\phi(\phi^k) \tag{12}$$

and the same update in $\Theta$ is

$$\theta^{k+1} = \theta^k + \alpha \nabla_\theta f_\theta(\theta^k). \tag{13}$$

Given $T(\theta^k) = \phi^k$, in order for $\{\phi\} = T(\{\theta\})$ we require $T(\theta^{k+1}) = \phi^{k+1}$ after the update. Is this the case? Immediately we see that the algorithm will only be globally invariant under linear transformations, because what we are really asking is to evaluate whether

$$T\left(\nabla_\theta f_\theta(\theta^k)\right) = \nabla_\phi f_\phi(\phi^k). \tag{14}$$

In reality though, we are only ever concerned with small volumes of parameter space, which we can approximate as linear. To make reading easier, we adopt the notation $g = \nabla_\theta f_\theta(\theta^k)$ and $g' = \nabla_\phi f_\phi(\phi^k)$, so

$$T(g) = [\nabla_\theta T(\theta^k)]\, g' \tag{15}$$

which isn't the $T(g) = g'$ we were hoping for. The only transformations which leave the SD algorithm invariant are (locally) volume-preserving rotations$^3$, i.e. orthonormal ones. This is a fairly poor set of transformations. Surely we can do better than this! The problem is that we have not actually chosen the direction of greatest reduction in $f$. Note this is different to the direction of steepest descent!

$^2$ My naming convention.

$^3$ We can also add reflections, but we want to impose the positive definiteness conditions, which aren't strictly part of SD.

C. Another point of view

The partial gradient direction $\nabla f$ paradoxically does not change $f$ the most. How can this be? The problem arises from the fact that we are assuming that we are computing everything in Euclidean space. In fact we need to consider the more general Riemannian space. In Euclidean space the distance between two points $x$ and $y$ is computed as the root of the square of their difference, i.e.

$$d(x, y) = \sqrt{(x - y)^\top (x - y)}. \tag{16}$$

If we consider differentials then this becomes

$$|dw|^2 = \sum_{i=1}^{n} (dw_i)^2. \tag{17}$$

If the coordinate system is non-orthonormal, however, then the squared length is given locally by

$$|dw|^2 = \sum_{i,j}^{n} g_{ij}(w)\, dw_i\, dw_j. \tag{18}$$

$g_{ij}$ evaluated at a particular point in space with respect to a given basis returns a matrix $G$ called the Riemannian metric tensor. When the parameter space is a curved manifold we have to resort to using this kind of approximation at each point in space. This is the Riemannian space. Now in the normal Euclidean setting $g_{ij} = \delta_{ij}$, the Kronecker delta, so the squared length reduces to the usual dot product form. This is the source of much of our confusion and why transformations that preserve volume and angle leave naive SD invariant.

So what is the steepest descent direction, factoring in the Riemannian metric? This can be found with a simple optimisation:

$$\arg\min_{a \,:\, d\nu = a} f(\nu + d\nu) - f(\nu) \quad \text{subject to} \quad |a|^2 = a^\top G a = 1. \tag{19}$$

This can be solved simply using Lagrange multipliers: to first order the objective is $\nabla f(\nu)^\top a$, and setting the derivative of $\nabla f(\nu)^\top a + \lambda (a^\top G a - 1)$ with respect to $a$ to zero gives $a \propto G^{-1} \nabla f(\nu)$. This yields the natural gradient

$$\tilde{\nabla} f(\nu) = G^{-1} \nabla f(\nu). \tag{20}$$

The natural gradient changes the nature of the partial gradient $\nabla f(\nu)$ such that it transforms in a different way when we change basis. Anyone with a background in differential geometry will recognise immediately that we are simply converting a covariant gradient into a contravariant gradient by index raising using the metric tensor.

D. How the natural gradient links with covariant optimisation

The natural gradient steepest descent method is a covariant optimisation on a local level, because transformations of the parameter space leave the algorithm invariant. Globally, we cannot define a mapping such that the partial gradient is invariant completely, but do we really need to do this? In reality, we are only going to concern ourselves with small patches of the parameter space because we want to keep our step sizes $\alpha$ small. We approximate these patches as Riemannian manifolds on which we can define a distance metric locally. So whilst we view transformations locally as linear mappings, we also need to remember to reevaluate what we mean by distance within each locality to maintain a coherent view of the world.
The natural gradient is the gradient of the function with respect to the parameters, with this redefinition of distance in mind.

Going back to the transformation example, if we were to use natural gradient descent we would get

$$\theta^{k+1} = \theta^k + \alpha G^{-1} g \tag{21}$$

which under transformation becomes, using $A^\top = [\nabla_\theta T(\theta^k)]$ to clean up notation,

\begin{align}
T(\theta^{k+1}) &= T(\theta^k) + \alpha T(G^{-1} g) \tag{22} \\
&= T(\theta^k) + \alpha (A^\top G A)^{-1} A^\top g \tag{23} \\
&= T(\theta^k) + \alpha A^{-1} G^{-1} g \tag{24} \\
&= \phi^k + \alpha G^{-1} g' \tag{25}
\end{align}
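
The invariance argument can be checked numerically for the simplest case, a linear map $T(\theta) = M\theta$. The sketch below is illustrative only and is not taken from the original note: the quadratic objective, the matrices `H`, `G_THETA`, `M_SKEW` and `Q_ROT`, and the helper `step_gaps` are all hypothetical choices. It also assumes conventions of its own rather than the $A$ notation above: the gradient transforms as $g' = M^{-\top} g$ by the chain rule, and the metric is carried over as $G' = M^{-\top} G M^{-1}$ so that squared lengths agree in both bases.

```python
import numpy as np

# Hypothetical quadratic objective f_theta(theta) = 0.5 * theta^T H theta - b^T theta.
H = np.array([[3.0, 0.5, 0.0],
              [0.5, 2.0, 0.3],
              [0.0, 0.3, 1.0]])
b = np.array([1.0, -2.0, 0.5])

# Assumed metric in the Theta basis: any symmetric positive-definite matrix will do.
G_THETA = np.array([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.0],
                    [0.0, 0.0, 1.5]])

THETA0 = np.array([1.0, 1.0, 1.0])


def grad_theta(theta):
    """Partial gradient of the quadratic objective in the Theta basis."""
    return H @ theta - b


def step_gaps(M, theta, alpha=-0.1):
    """Take one step in each basis and measure how far T(theta_next) is from phi_next.

    T is the linear map phi = M @ theta. Returns (sd_gap, ng_gap) for steepest
    descent and natural gradient descent respectively.
    """
    M_inv = np.linalg.inv(M)
    phi = M @ theta

    g = grad_theta(theta)              # gradient in Theta
    g_prime = M_inv.T @ g              # gradient in Phi, by the chain rule
    G_phi = M_inv.T @ G_THETA @ M_inv  # metric expressed in Phi

    # Steepest descent, eqs. (12) and (13), with alpha small and negative.
    sd_theta = theta + alpha * g
    sd_phi = phi + alpha * g_prime

    # Natural gradient descent: precondition each gradient by its own metric.
    ng_theta = theta + alpha * np.linalg.solve(G_THETA, g)
    ng_phi = phi + alpha * np.linalg.solve(G_phi, g_prime)

    sd_gap = np.linalg.norm(M @ sd_theta - sd_phi)
    ng_gap = np.linalg.norm(M @ ng_theta - ng_phi)
    return sd_gap, ng_gap


# An invertible but non-orthonormal change of basis (determinant 7).
M_SKEW = np.array([[2.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0],
                   [1.0, 0.0, 3.0]])

# A rotation about the third axis, i.e. an orthonormal change of basis.
c, s = np.cos(0.7), np.sin(0.7)
Q_ROT = np.array([[c, -s, 0.0],
                  [s, c, 0.0],
                  [0.0, 0.0, 1.0]])

if __name__ == "__main__":
    print("skewed basis   (SD gap, NG gap):", step_gaps(M_SKEW, THETA0))
    print("rotated basis  (SD gap, NG gap):", step_gaps(Q_ROT, THETA0))
```

Under these assumptions the steepest descent step should only agree across bases for the rotation, while the natural gradient step should agree for both maps up to floating-point error. A nonlinear $T$ would only satisfy this to first order in the step size, which is the local sense of covariance discussed in Section D.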