PyTorch Internals, cuRAND, and numerical instability

Random sampling

Lately I've been implementing random samplers for a number of distributions in PyTorch, on both CPU and CUDA. This topic is near and dear to my heart, since it has caused me a lot of trouble multiple times. Once this PR is merged, I'll post an explanation/notebook of why this is important.

Here's a brief summary of the motivation:

  1. We want to sample from distributions like \(\operatorname{Beta}(a, b)\). However, until recently PyTorch could only sample from a few basic distributions (Uniform, Normal, Exponential, etc.). That is a problem because most fast sampling algorithms for more complex distributions work either via rejection sampling (or variants like adaptive rejection sampling, ARS) or via inverse transform sampling. Rejection sampling is awkward to parallelize in pure PyTorch, because you need a fiddly masking scheme to keep resampling only the rejected entries. Inverse transform sampling is hard because the inverse CDF is often hard to compute.

  2. Failing that, we can fall back to NumPy. After all, PyTorch integrates seamlessly with NumPy, which has long had excellent support for distributions (more on this later). However, sampling in NumPy requires an expensive CPU-GPU copy, which was actually significant in our models. In our work, the baseline used a Beta distribution, so it would have been unfair to saddle it with this large performance hit in our comparisons.

  3. Finally, failing that, we can write C/CUDA code to sample, and link it against PyTorch. That's exactly what we did. The downside is that CUDA random number generation is a little tricky, and NVIDIA's cuRAND library only provides samplers for a handful of distributions (such as uniform, normal, log-normal, and Poisson). Also, since I'm a Makefile novice, it took me forever to get it to compile on Odyssey, and it promptly broke when I tried to use it in a different environment.
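To make the masking idea from point 1 concrete, here is a minimal sketch in C++ rather than PyTorch. A plain loop over a boolean mask stands in for the vectorized tensor operations you would write in practice. It rejection-samples \(\operatorname{Beta}(a, b)\) from a Uniform(0, 1) proposal, assuming \(a, b \geq 1\) so the density is bounded by its value at the mode. The names `beta_pdf` and `beta_rejection_masked` are made up for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Beta(a, b) density; assumes a >= 1 and b >= 1 so it is bounded on [0, 1].
double beta_pdf(double x, double a, double b) {
  double log_norm = std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b);
  return std::pow(x, a - 1.0) * std::pow(1.0 - x, b - 1.0) / std::exp(log_norm);
}

// Draws n Beta(a, b) samples "in parallel": each round proposes for every
// still-pending slot at once, then keeps only the accepted ones (the mask).
std::vector<double> beta_rejection_masked(std::size_t n, double a, double b,
                                          std::mt19937_64 &rng) {
  assert(a >= 1.0 && b >= 1.0);
  // The density peaks at its mode; use that height as the envelope M.
  double mode = (a + b > 2.0) ? (a - 1.0) / (a + b - 2.0) : 0.5;
  double M = beta_pdf(mode, a, b);

  std::uniform_real_distribution<double> unif(0.0, 1.0);
  std::vector<double> out(n, 0.0);
  std::vector<bool> pending(n, true);  // the mask: true = no sample accepted yet
  std::size_t remaining = n;

  while (remaining > 0) {
    for (std::size_t i = 0; i < n; ++i) {
      if (!pending[i]) continue;         // masked out: already has a sample
      double x = unif(rng);              // proposal from Uniform(0, 1)
      double u = unif(rng);
      if (u * M <= beta_pdf(x, a, b)) {  // accept with probability pdf(x) / M
        out[i] = x;
        pending[i] = false;
        --remaining;
      }
    }
  }
  return out;
}
```

Each pass of the `while` loop corresponds to one batched round of proposals. Only entries whose mask is still set get new proposals, and that bookkeeping is what makes the approach clumsy to express with whole-tensor operations.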

So my goal lately has been to port some of this knowledge to PyTorch proper. That way, other researchers can get random \(\operatorname{Beta}(a, b)\) samples, fast, without having to jump through all these hoops.

PyTorch internals

PyTorch as a project is pretty complex, but it can be surprisingly easy to contribute to if you know where to look. Unfortunately, the documentation on internals is sparse [1], and two things make it hard to navigate: there's a mixture of C/C++/CUDA/Python code throughout, and it's all glued together with a lot of codegen.

Why is this necessary? PyTorch is a Python library that communicates with C/C++ code (for fast CPU operations) and CUDA (for fast GPU operations). Since it supports many data types, writing all of that code by hand would be tedious: all of

THFloatTensor * add(THFloatTensor *a, THFloatTensor *b);
THDoubleTensor * add(THDoubleTensor *a, THDoubleTensor *b);
THCFloatTensor * add(THCFloatTensor *a, THCFloatTensor *b);

probably have the same implementation. Imagine repeating that 15 times! So not only are the FFI interfaces generated, but the function signatures and implementations are too.
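As a rough illustration of that per-type generation, here is a toy version of the macro trick. My understanding is that TH does something similar by re-including generic source files with different type macros defined. This sketch squeezes the idea into a single macro. Every name in it (`FloatTensor`, `TENSOR_`, `DEFINE_ADD`) is hypothetical and is not TH's actual API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-type tensor structs, standing in for THFloatTensor etc.
struct FloatTensor { std::vector<float> data; };
struct DoubleTensor { std::vector<double> data; };

// Token pasting builds a type-specific name: TENSOR_(Float, add) -> FloatTensor_add.
#define TENSOR_(T, NAME) T##Tensor_##NAME

// One generic body, written once and stamped out for each element type.
#define DEFINE_ADD(T)                                        \
  void TENSOR_(T, add)(T##Tensor *out, const T##Tensor *a,   \
                       const T##Tensor *b) {                 \
    assert(a->data.size() == b->data.size());                \
    out->data.resize(a->data.size());                        \
    for (std::size_t i = 0; i < a->data.size(); ++i)         \
      out->data[i] = a->data[i] + b->data[i];                \
  }

DEFINE_ADD(Float)   // defines FloatTensor_add(FloatTensor*, const FloatTensor*, const FloatTensor*)
DEFINE_ADD(Double)  // defines DoubleTensor_add(...)
```

Adding a new element type means one more `DEFINE_ADD(...)` line. Debugging it, however, means staring at macro expansions, which is part of why this style is painful.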

Very recently, ATen has made the story somewhat simpler by leveraging C++11 and namespacing to eliminate macros [2].
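For contrast, here is the same operation written once as a C++ template inside a namespace. It shows the general flavor of replacing macro machinery with language features. It is a simplification assumed for illustration: ATen's real types dispatch at runtime on the tensor's type and are far more involved. The `sketch` namespace and `Tensor` struct are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

// Hypothetical tensor parameterized by its element type.
template <typename T>
struct Tensor {
  std::vector<T> data;
};

// Written once; the compiler instantiates add<float>, add<double>, ... on use.
template <typename T>
Tensor<T> add(const Tensor<T> &a, const Tensor<T> &b) {
  assert(a.data.size() == b.data.size());
  Tensor<T> out;
  out.data.resize(a.data.size());
  for (std::size_t i = 0; i < a.data.size(); ++i) {
    out.data[i] = a.data[i] + b.data[i];
  }
  return out;
}

}  // namespace sketch
```

Both the compiler's type checking and the namespace do work that the macro version handled through naming conventions.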

Here are a few notes I found useful while trying to understand how the build works:

  1. There are two different codegen systems: cwrap, which generates Python interfaces to some underlying code, and .yaml, which generates an interface from Variable to ATen. The torch/csrc/generic/**/*.cwrap files generate Python interfaces and versions of the THTensor_(...) methods for each type, which are dispatched based on the type used. You can jump into that via here.

    For the .yaml files, ATen builds its own interface via this file and outputs Declarations.yaml. A separate codegen step then reads Declarations.yaml and writes the corresponding Python interface, using gen_variable_type and the derivatives.yaml file. derivatives.yaml also specifies the gradient of each operation.

  2. While building, all the information in is very helpful for keeping iteration time down. Also helpful: rewrite build_deps inside to build just your component (e.g. ATen). Sometimes the build gets screwed up, and running python clean is the remedy.

  3. The ATen codegen generates the glue that dispatches the correct function based on types. After building, you can find the generated files in torch/lib/build/aten/src/ATen/ATen/. If you want to mess with the generation, find the spot where the corresponding code is generated and modify options to do what you need. Note that changing just one code path requires modifying many of the codegen points, so look for all of them (Functions.h, CPU[Type]Type.h, etc.).

Mostly I figured this out by running the build, then using ag -G [something] [term] and find . -name "[glob]". If you're poking around, they'll likely be useful to you as well. NOTE: by default, ag and rg ignore the files listed in your .gitignore. This includes generated build files!

A story about RNG

Recently I was implementing a Poisson sampler using what is essentially rejection sampling, and found that it didn't work. Here's the code:

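#include <cmath>
#include <cstdint>
#include <curand_kernel.h>

__device__ int64_t sample_poisson(double lambda, curandStateMtgp32 *states) {
  if (lambda < 10) {
    // Knuth's method: count how many uniforms can be multiplied together
    // before the running product drops to exp(-lambda) or below
    double enlam = std::exp(-lambda);
    int64_t X = 0;
    double prod = 1.0;
    double U = 0;
    while (1) {
      U = curand_uniform_double(&states[blockIdx.x]);  // one MTGP32 state shared by the block
      prod *= U;
      if (prod > enlam) {
        X += 1;
      } else {
        return X;
      }
    }
  }
  ... // more special case code for values of lambda
}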

In particular, if a thread didn't exit after its first or second sample, it would never exit the while loop. I spent a while debugging, and realized that even though the values returned by curand_uniform_double were uniformly distributed in isolation, adding the rejection loop made it repeat values. The generator state itself was fine, since it produced uniform doubles when called on its own. PyTorch uses an MTGP32-based generator, so I eventually looked in the docs and found this line:

"At a given point in the code, all threads in the block, or none of them, must call this function."

So what was happening is that threads that returned early stopped calling the function, which is undefined behavior. This makes rejection sampling hard! However, there's a solution: an alternative call, curand_mtgp32_single_specific, which takes a generator state, an index, and the total number of threads that call it. As long as each calling thread passes a distinct index and the indices cover 0 through count - 1, this gives uniformly distributed floats as expected. We do need to be a bit careful about synchronization, though, because of warp divergence.

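#include <cstdint>
#include <curand_kernel.h>

// Every thread in the block must call this (it uses __syncthreads());
// num_threads is how many of them actually want a sample.
__device__ int64_t sample_poisson(double lambda, curandStateMtgp32 *states, int num_threads) {
  __shared__ int thread_count;  // threads still drawing this round
  if (threadIdx.x == 0) thread_count = num_threads;
  __syncthreads();

  int64_t X = 0;
  int idx = threadIdx.x;  // this thread's slot among the live threads
  bool alive = threadIdx.x < num_threads;
  float enlam = expf(-static_cast<float>(lambda));
  float prod = 1.0f;

  while (true) {
    int n = thread_count;  // identical in every thread after the last sync
    if (n == 0) break;
    if (alive) {
      float U = curand_mtgp32_single_specific(&states[blockIdx.x], idx, n);
      prod *= U;
      if (prod > enlam) X += 1; else alive = false;
    }
    __syncthreads();  // everyone has read n and drawn
    if (threadIdx.x == 0) thread_count = 0;
    __syncthreads();  // reset is visible before recounting
    if (alive) idx = atomicAdd(&thread_count, 1);  // counts 'living' threads, giving each a unique idx
    __syncthreads();  // new count is visible to all threads
  }
  return X;
}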
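Debugging this on a GPU is painful, so here is a CPU model of the same bookkeeping that can be run and checked directly. It rests on a loud assumption: a single `std::mt19937` stands in for the block's MTGP32 state. Each round it emits exactly as many uniforms as there are live threads, and the live thread holding index `i` consumes the `i`-th value, mirroring how I understand `curand_mtgp32_single_specific`. The sequential counter `next_count` plays the role of the `atomicAdd`. `lockstep_poisson` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// CPU model of a block of "threads" sharing ONE generator. Each round the
// generator emits exactly thread_count values; live thread t reads draws[idx[t]].
std::vector<int64_t> lockstep_poisson(double lambda, int num_threads,
                                      std::mt19937 &gen) {
  assert(lambda > 0.0 && num_threads > 0);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  const double enlam = std::exp(-lambda);

  std::vector<int64_t> X(num_threads, 0);
  std::vector<double> prod(num_threads, 1.0);
  std::vector<int> idx(num_threads);  // -1 marks a finished thread
  for (int t = 0; t < num_threads; ++t) idx[t] = t;
  int thread_count = num_threads;     // live threads this round

  while (thread_count != 0) {
    // The shared generator produces exactly thread_count values this round,
    // in (0, 1] like cuRAND's uniforms.
    std::vector<double> draws(thread_count);
    for (double &u : draws) u = 1.0 - unif(gen);

    int next_count = 0;  // stands in for the atomicAdd counter
    for (int t = 0; t < num_threads; ++t) {
      if (idx[t] < 0) continue;    // finished threads draw nothing
      prod[t] *= draws[idx[t]];    // Knuth's multiplication step
      if (prod[t] > enlam) {
        X[t] += 1;
        idx[t] = next_count++;     // compact, unique index for the next round
      } else {
        idx[t] = -1;               // done: X[t] is this thread's sample
      }
    }
    thread_count = next_count;
  }
  return X;
}
```

Because the indices handed out each round are exactly 0 through `next_count - 1`, every value the shared generator produces is consumed by exactly one live thread. That is the invariant the GPU version needs to maintain.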

While it's neat, it unfortunately isn't quite appropriate for PyTorch for a few reasons, so we'll look into other solutions. For the Poisson, at least, cuRAND provides curand_poisson, which implements it natively for us.

Some thoughts

One problem that bothered me for more than a week on the IBP project was that our implementation of Beta BBVI went haywire when I used my CUDA sampler. Following Finale's advice, I made some qq-plots, but couldn't see any real issues. The culprit: I was sampling using the identity

\[X \sim \operatorname{Gamma}(a, 1),\ Y \sim \operatorname{Gamma}(b, 1) \text{ independent} \implies \frac{X}{X + Y} \sim \operatorname{Beta}(a, b)\]

since, you know, that's what I learned in Stat 210. But! This is numerically unstable when both \(a, b \leq 1\). One failure mode: for small shape parameters the Gamma draws can underflow to zero, leaving \(0/0\). I found the solution while digging through NumPy's code here, which taught me to respect my elders, or at least to respect NumPy.
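Here is a minimal sketch of that kind of fix, in C++ with `<random>`. As far as I can tell, NumPy's legacy Beta sampler switches to Jöhnk's algorithm when both \(a, b \leq 1\) and uses the Gamma ratio otherwise. This function follows that split and includes a log-space fallback for when both powers underflow. It is my own sketch, not a copy of NumPy's code, and `sample_beta` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical Beta(a, b) sampler: Johnk's algorithm when a, b <= 1,
// the Gamma-ratio identity otherwise.
double sample_beta(double a, double b, std::mt19937_64 &rng) {
  assert(a > 0.0 && b > 0.0);
  if (a <= 1.0 && b <= 1.0) {
    std::uniform_real_distribution<double> unif(0.0, 1.0);
    while (true) {
      double U = unif(rng);
      double V = unif(rng);
      double X = std::pow(U, 1.0 / a);
      double Y = std::pow(V, 1.0 / b);
      double XpY = X + Y;
      // Accept only when X + Y <= 1; skip the (essentially impossible) U = V = 0.
      if (XpY <= 1.0 && U + V > 0.0) {
        if (XpY > 0.0) return X / XpY;
        // Both powers underflowed to 0: compute X / (X + Y) in log space.
        double logX = std::log(U) / a;
        double logY = std::log(V) / b;
        double logM = std::max(logX, logY);
        logX -= logM;
        logY -= logM;
        return std::exp(logX - std::log(std::exp(logX) + std::exp(logY)));
      }
      // Otherwise reject the pair and try again.
    }
  }
  // The textbook identity, used when at least one shape parameter exceeds 1.
  std::gamma_distribution<double> gamma_a(a, 1.0), gamma_b(b, 1.0);
  double x = gamma_a(rng);
  double y = gamma_b(rng);
  return x / (x + y);
}
```

Jöhnk's method never forms a ratio of two independently tiny Gamma draws. In the rare case where both powers underflow, it still recovers the ratio from logarithms instead of returning NaN.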

I wonder whether there's still active work on fast random number sampling. It's not something I'm directly interested in, but I'm curious about it.

Another fun story: later, when trying to calculate the log of the Beta function, I was on my guard and checked out the Cephes implementation, which is roughly 30 years old. At the top it says:

"Direct inquiries to 30 Frost Street, Cambridge, MA 02140"

which is about 2 blocks from where I live.
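On the log-Beta point, the standard trick is to never form \(\Gamma(a)\), \(\Gamma(b)\), or \(\Gamma(a+b)\) directly, and to combine log-Gamma values instead. \(\Gamma\) overflows a double once its argument passes roughly 171. Below is a minimal sketch using `std::lgamma`. It is not the Cephes routine, which handles more cases, and `log_beta` and `log_beta_naive` are illustrative names.

```cpp
#include <cassert>
#include <cmath>

// log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b), never forming Gamma itself.
double log_beta(double a, double b) {
  assert(a > 0.0 && b > 0.0);
  return std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b);
}

// Naive version for comparison: std::tgamma overflows to +inf for arguments
// above about 171, so for large a and b this computes log(inf / inf) = NaN.
double log_beta_naive(double a, double b) {
  return std::log(std::tgamma(a) * std::tgamma(b) / std::tgamma(a + b));
}
```

For example, `log_beta(200.0, 200.0)` is a finite (large negative) number, while `log_beta_naive(200.0, 200.0)` is NaN.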

  1. There are some other blog posts by the PyTorch folks here, definitely also worth checking out.
  2. Which are the devil. My operating systems course, as excellent as it was, was entirely in C and implemented arrays via macros.