This article is saying that it can be numerically unstable in certain situations, not that it's theoretically incorrect.
goosedragons · 38m ago
It can be both. A mistake in AD primitives can lead to theoretically incorrect derivatives. With the system I use, I have run into a few scenarios where edge cases aren't totally covered, leading to the wrong result.
I have run into numerical instability too.
froobius · 36m ago
> A mistake in AD primitives can lead to theoretically incorrect derivatives
Ok but that's true of any program. A mistake in the implementation of the program can lead to mistakes in the result of the program...
goosedragons · 30m ago
That's true! But it's also true that any program dealing with floats can run into numerical instability if care isn't taken to avoid it, no?
It's also not necessarily immediately obvious that the derivatives ARE wrong if the implementation is wrong.
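As a toy illustration of that point (hand-rolled dual numbers, not any particular AD system; `d_sqrt` and its edge-case handling are invented for the example): the primal value of sqrt(x*x) at x = 0 looks perfectly healthy, while the mechanically chain-ruled derivative silently comes out as NaN.

```python
import math

class Dual:                      # bare-bones forward-mode value: (val, der)
    def __init__(self, val, der):
        self.val, self.der = val, der
    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.val * other.der + self.der * other.val)

def d_sqrt(d):
    # Chain rule for sqrt: (sqrt(u))' = u' / (2*sqrt(u)).
    # At u = 0 with u' = 0 this is 0/0, a typical edge-case gap in a primitive.
    primal = math.sqrt(d.val)
    try:
        deriv = d.der / (2.0 * primal)
    except ZeroDivisionError:
        deriv = float("nan")     # many AD systems would propagate NaN here
    return Dual(primal, deriv)

x = Dual(0.0, 1.0)               # seed dx/dx = 1
y = d_sqrt(x * x)                # y = sqrt(x*x), i.e. |x|
print(y.val)                     # 0.0  -- the primal looks perfectly fine
print(y.der)                     # nan  -- the derivative is silently broken
```

Nothing in the function's output warns you; you only find out if you check the derivative against something else.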
omnicognate · 2h ago
Yeah, perhaps the actual title would be better: "The Numerical Analysis of Differentiable Simulation". (Rather than the subtitle, which is itself a poor rewording of the actual subtitle in the video.)
deckar01 · 1h ago
I have been using sympy while learning electron physics to automatically integrate linear charge densities. It works great symbolically, but often fails silently when the symbols are substituted with floats before integration: https://github.com/sympy/sympy/issues/27675
With floats getting smaller and smaller in ML, it's hard to imagine anyone failing to learn this as one of their early experiences in the field.
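Roughly, the two orderings look like this (hypothetical integrand, not the actual charge-density model; whether the float-first path silently degrades depends on the expression, and the issue linked above is a concrete case):

```python
import sympy as sp

x = sp.Symbol('x')
a = sp.Symbol('a', positive=True)
lam = 1 / sp.sqrt(x**2 + a**2)             # a line-charge-style kernel

# Symbolic first, numbers last: exact antiderivative, evaluated at the end.
exact = sp.integrate(lam, (x, 0, 1))
print(exact.subs(a, sp.Rational(1, 10**6)).evalf())

# Float first, integrate second: the same call now runs on a float-laden
# expression, where simplification can quietly fall back to approximate
# numerics instead of raising an error.
approx = sp.integrate(lam.subs(a, 1e-6), (x, 0, 1))
print(sp.N(approx))
```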
The focus should not be on the possibility of error, but on managing the error to within acceptable limits. There's an hour-long video there, and it's 3am, so I'm not sure how much of it covers management. Anyone familiar with it care to say?
Legend2440 · 1h ago
Floats are getting smaller and smaller only in neural networks - which works ok because NNs are designed for nice easy backprop. Normalization and skip connections really help avoid numerical instability.
The author is talking about ODEs and PDEs, which come from physical systems and don't have such nice properties.
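As a tiny, hypothetical taste of what "not nice" looks like (classic stiff decay, unrelated to any specific physical model): the exact solution decays monotonically to zero, yet explicit Euler with an innocent-looking step size blows up.

```python
import math

lam = -1000.0                        # stiff decay rate: y' = -1000*y
y_exact = math.exp(lam * 0.1)        # y(0.1), essentially zero (~3.7e-44)

y, dt = 1.0, 0.01                    # dt violates the |1 + lam*dt| < 1 bound
for _ in range(10):                  # integrate to t = 0.1
    y += dt * lam * y                # explicit Euler step

print(y_exact)                       # ~3.7e-44
print(y)                             # (-9)**10: huge and positive, diverged
```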
NooneAtAll3 · 2h ago
[video]?
ChrisRackauckas · 1h ago
Interesting to see this here! This example is one of the ones mentioned in the appendix of https://arxiv.org/abs/2406.09699, specifically "4.1.2.4 When AD is algorithmically correct but numerically wrong".
If people want a tl;dr, the main idea is that you can construct an ODE where the forward pass is trivial, i.e. the ODE solution going forwards is exact, but its derivative is "hard". An easy way to do this is to take, for example, `x' = x - y, y' = y - x` with initial conditions x(0)=y(0). If you start with the two values equal, the solution to the ODE is constant since `x' = y' = 0`. But the derivative of the ODE solution with respect to its initial condition is very non-zero: a small change away from equality and boom, the solution explodes. Write out the expression for dy(t)/dy(0) and what you get is a non-trivial ODE that has to be solved.
What happens in this case though is that automatic differentiation "runs the same code as the primal case". I.e., automatic differentiation walks through your code and differentiates each step, so the control flow always matches the control flow of the non-differentiated code, slapping the derivative parts onto each step. But here the primal case is trivial, so the ODE solver goes "this is easy, let's make dt as big as possible and step through this quickly". This means that `dt` is not error-controlled in the derivative (the derivative, being a non-trivial ODE, needs a smaller dt and smaller steps to get an accurate answer), so the derivative accumulates error from this large dt (or small number of steps). Via this construction you can make automatic differentiation give as much error as you want just by tweaking the parameters.
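To make that concrete, here is a deliberately crude sketch of the mechanism: a hand-rolled dual-number type pushed through a toy step-doubling Euler integrator (none of this is the actual solver code; `adaptive_integrate`, the tolerance, and the step-size policy are all made up for illustration). The primal `x' = x - y, y' = y - x` with x(0)=y(0) stays exactly constant, so the controller takes a handful of giant steps, and the dual part, which is really solving the variational ODE, comes out badly under-resolved.

```python
import math

class Dual:
    """Bare-bones forward-mode AD value: primal `val` plus derivative `der`."""
    def __init__(self, val, der):
        self.val, self.der = val, der
    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.der + other.der)
    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.der - other.der)
    def __mul__(self, other):
        other = _lift(other)
        return Dual(self.val * other.val,
                    self.val * other.der + self.der * other.val)
    __rmul__ = __mul__

def _lift(x):
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)

def f(u):
    x, y = u
    return [x - y, y - x]                      # x' = x - y, y' = y - x

def euler_step(u, dt):
    return [ui + dt * dui for ui, dui in zip(u, f(u))]

def adaptive_integrate(u, t_end, tol=1e-6):
    # Step-doubling error control on the *primal* values only: compare one full
    # Euler step with two half steps, accept if they agree to `tol`, grow dt.
    t, dt, steps = 0.0, 1e-3, 0
    while t < t_end:
        dt = min(dt, t_end - t)
        big = euler_step(u, dt)
        small = euler_step(euler_step(u, dt / 2), dt / 2)
        err = max(abs(a.val - b.val) for a, b in zip(big, small))
        if err <= tol:
            u, t, steps = small, t + dt, steps + 1
            dt *= 2.0                          # primal is trivial -> dt balloons
        else:
            dt *= 0.5
    return u, steps

T = 3.0
u0 = [Dual(1.0, 1.0), Dual(1.0, 0.0)]          # seed d/dx(0) through x
uT, n = adaptive_integrate(u0, T)
exact = 0.5 * (1.0 + math.exp(2 * T))          # dx(T)/dx(0) for this linear ODE
print("steps taken:", n)                       # only a handful of huge steps
print("AD sensitivity:", uT[0].der)            # a large underestimate...
print("exact sensitivity:", exact)             # ...of roughly 202.2
```

Because the primal error estimate is exactly zero here, dt just keeps doubling, and the propagated sensitivity ends up far below the exact value even though every individual derivative rule applied along the way was correct.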
Thus by this construction, automatic differentiation has no bound on the error it can give, and no, this isn't floating point error: this is strictly building a function whose AD result does not converge to the correct derivative. It just goes to show that automatic differentiation of a function that computes a nice error-controlled answer does not necessarily give a derivative with any sense of error control; that is a property that has to be proved and is not true in general. This is of course a bit of a contrived example to show that the error is unbounded, but it points to real issues that can show up in user code (in fact, this example was found because a user opened an issue with a related model).
One thing that's noted in here too is that the Julia differential equation solvers hook into the AD system to explicitly "not do forward-mode AD correctly", incorporating the derivative terms into the time-stepping adaptivity calculation so that the derivative is error controlled. The property you get is that these solvers take more steps when run in AD mode than outside of AD mode, and that is a requirement if you want to ensure the user's tolerance is respected. But that's explicitly "not correct" in terms of what forward-mode AD is, so calling forward-mode AD on the solver doesn't quite do forward-mode AD of the solver's code "correctly", in order to give a more precise solution. That is of course a choice; you could instead follow standard AD rules in such a code. The trade-off is between accuracy and complexity.
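Continuing the toy sketch from above (again, a sketch of the idea only, not what the Julia solvers literally do internally): if the same crude controller also measures its error estimate on the dual components, it is forced into many more, smaller steps, and the propagated sensitivity lands close to the exact value.

```python
# Reuses Dual, f, euler_step, T and the exact sensitivity from the sketch above.
# The only change is the acceptance test: the error estimate now also covers the
# dual (derivative) components, so the trivial primal can no longer coast
# through on giant steps while its sensitivity explodes underneath.
def adaptive_integrate_dual_aware(u, t_end, tol=1e-6):
    t, dt, steps = 0.0, 1e-3, 0
    while t < t_end:
        dt = min(dt, t_end - t)
        big = euler_step(u, dt)
        small = euler_step(euler_step(u, dt / 2), dt / 2)
        err = max(max(abs(a.val - b.val), abs(a.der - b.der))
                  for a, b in zip(big, small))
        scale = max(1.0, max(abs(c.der) for c in u))   # crude relative tolerance
        if err <= tol * scale:
            u, t, steps = small, t + dt, steps + 1
            dt *= 2.0
        else:
            dt *= 0.5
    return u, steps

uT, n = adaptive_integrate_dual_aware([Dual(1.0, 1.0), Dual(1.0, 0.0)], T)
print("steps taken:", n)              # thousands of steps instead of a handful
print("AD sensitivity:", uT[0].der)   # now close to the exact ~202.2
```

That's the toy analogue of "more steps in AD mode than outside of it": the derivative terms now participate in the accept/reject decision.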
dleary · 20m ago
Thank you for this good description.
ismailmaj · 1h ago
The way I explained it to myself in the past, when wondering why so many of the CUDA algorithms don't care much about numerical stability, is that in deep learning the error is a form of regularization (i.e. less overfitting of the data).
Nevermark · 17m ago
I am not quite sure what that means! :)
But there are reasons why deep learning training is very robust to moderately inaccurate gradients:
1. Locally, sigmoid and similar functions are the simplest, smoothest possible non-linearities to propagate gradients through.
2. Globally, outside of deep recurrent networks, there is no recursion, which keeps the total function smooth and well behaved.
3. While the perfect gradient indicates the ideal direction to adjust parameters for fastest improvement, all that is really needed to reduce error is to move each parameter in the direction of its gradient's sign, with a small enough step. That is a very low bar. (There's a toy sketch of this at the end of this comment.)
It's like telling an archer they just need to shoot an arrow in the direction of a target, but there is no need to hit it.
4. Finally, the perfect first-order gradient is only meaningful at one point of the optimization surface. Move away from that point, i.e. update the parameters at all, and the gradient changes quickly.
So we are in gradient-heuristic land even with "perfect" first-order gradients. The most perfectly calculated gradient isn't actually "accurate" in the way we might assume.
To actually get an accurate gradient over a parameter step would take fitting the local gradient with a second- or third-order polynomial, i.e. not just first but second and third order derivatives, at vastly greater computational cost.
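The toy sketch for point 3 (a hypothetical badly scaled quadratic, not a real network): throw away every gradient magnitude and keep only the signs, and a small enough step still drives the loss down.

```python
import numpy as np

rng = np.random.default_rng(0)
scales = np.array([1.0, 10.0, 100.0])          # badly scaled loss directions
w = rng.normal(size=3)
w0 = w.copy()

def loss(w):
    return 0.5 * np.sum(scales * w**2)

def grad(w):
    return scales * w                          # exact gradient of the quadratic

lr = 1e-2
for step in range(2000):
    # Keep only the signs: a maximally "inaccurate" gradient that still points
    # into the right half-space for each coordinate.
    w -= lr * np.sign(grad(w))

print("loss before:", float(loss(w0)))
print("loss after: ", float(loss(w)))          # much smaller; each weight ends
                                               # up within ~lr of its optimum
```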