Finding the top-p elements as used in Nucleus Sampling
29 Aug 2023
Here’s a quick report on some experiments with algorithms for finding the top-p elements of a distribution.
Why Nucleus Sampling?
The final stage of current LLMs needs to decide how to generate the next token, given a distribution over all tokens. Using the most probable token directly (greedy search) is the simplest solution but does not lead to interesting results. At the other end of the spectrum, sampling from the full distribution is risky: a significant share of the probability mass may be spread over a large number of low-probability tokens, and accidentally selecting an unlikely token may lead the further text generation astray.[1]
Many potential improvements have been proposed. A simple and widely supported one is top-k sampling, where the k most probable tokens are selected and the probability mass is redistributed over those k tokens. It is easy to see when this can go wrong: in a common scenario, a single token has a very high probability and all other tokens are similarly unlikely, yet top-k still admits k-1 of those unlikely tokens and even boosts their probability. Or, in the opposite case, a somewhat flat distribution over many tokens is cut off at a seemingly arbitrary point.
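To make the top-k idea concrete, here is a minimal Scala sketch, assuming the probabilities are given as an Array[Float] indexed by token id. The name sampleTopK and the use of scala.util.Random are illustrative assumptions, not code from llama2.scala. Instead of renormalizing the k kept probabilities explicitly, it draws a uniform point in [0, mass) of the kept tokens, which is equivalent.

```scala
import scala.util.Random

// Hypothetical helper (not from llama2.scala): top-k sampling over token probabilities.
def sampleTopK(probs: Array[Float], k: Int, rng: Random): Int = {
  require(k >= 1 && probs.nonEmpty, "need at least one token and k >= 1")
  // Keep the indices (token ids) of the k most probable tokens.
  val top = Array.range(0, probs.length).sortWith((a, b) => probs(a) > probs(b)).take(k)
  val mass = top.map(i => probs(i)).sum
  // Drawing uniformly from [0, mass) is equivalent to renormalizing the kept probabilities.
  val r = rng.nextFloat() * mass
  var cumulative = 0.0f
  var j = 0
  // Walk the kept tokens until the cumulative mass passes r; the last token absorbs rounding leftovers.
  while (j < top.length - 1 && cumulative + probs(top(j)) <= r) {
    cumulative += probs(top(j))
    j += 1
  }
  top(j)
}
```

With a peaked distribution and k = 2, this can only ever return the two most probable tokens, no matter how small the second one is compared to the first.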
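A more sophisticated approach is top-p sampling, also known as Nucleus Sampling. Here, the most probable tokens are selected in descending order until their cumulative probability mass reaches a threshold p, and the probability mass is then redistributed over the selected tokens.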
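Written as formulas (a sketch of the usual definition in illustrative notation), with the tokens ordered so that $P(x_{(1)}) \ge P(x_{(2)}) \ge \dots \ge P(x_{(n)})$, the selected set $V^{(p)}$ and the redistributed distribution $P'$ are:

$$k = \min\Big\{ j : \sum_{i=1}^{j} P(x_{(i)}) \ge p \Big\}, \qquad V^{(p)} = \{x_{(1)}, \dots, x_{(k)}\}$$

$$P'(x) = \begin{cases} \dfrac{P(x)}{\sum_{x' \in V^{(p)}} P(x')} & \text{if } x \in V^{(p)} \\[1ex] 0 & \text{otherwise} \end{cases}$$

Unlike top-k, the number of selected tokens $k$ adapts to the shape of the distribution.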
While working on llama2.scala, I ran into the practical problem of how to select the top-p elements out of the 32000 potential tokens (of the commonly used vocabulary) in a performant way. This is only relevant for small models, where sampling time might become a significant part of the total generation time.
I implemented a few algorithms and compared them.
The naive approach is to sort the tokens by their probability and then select the top-p elements while keeping track of the cumulative probability. This is the baseline to beat. It is simple to implement with the sorting functionality of your runtime of choice.
Since we need to keep track of the original indices of the elements (i.e. the identity of the token in the vocabulary), we can sort the list of indices and look up the original probabilities when comparing. Alternatively, the probabilities and indices can be kept together, which gives better locality at the cost of a more expensive swap operation.
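A minimal Scala sketch of this baseline, sorting the indices and looking up the probabilities in the comparison. The name topPSorted is hypothetical, and the benchmarked sorting variant may differ in detail.

```scala
// Hypothetical baseline (not the benchmarked code): sort all token ids by probability,
// then take tokens until the cumulative probability reaches p.
def topPSorted(probs: Array[Float], p: Float): Array[Int] = {
  // Sort the indices rather than the probabilities, so token identities are preserved.
  val sortedIdx = Array.range(0, probs.length).sortWith((a, b) => probs(a) > probs(b))
  val selected = Array.newBuilder[Int]
  var cumulative = 0.0f
  var k = 0
  while (k < sortedIdx.length && cumulative < p) {
    selected += sortedIdx(k)
    cumulative += probs(sortedIdx(k))
    k += 1
  }
  selected.result()
}
```

For probabilities (0.1, 0.5, 0.05, 0.3, 0.05) and p = 0.7, this selects tokens 1 and 3, whose cumulative probability of 0.8 is the first to reach 0.7.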
Drawbacks of this approach:
- In the common case, only a few elements will be selected, so sorting the whole array is wasteful.
- The sorting algorithm does not take some properties of the distribution into account:
  - The distribution is likely to be skewed (we expect roughly an inverse power law distribution).
  - The distribution sums up to 1.
Filter out some elements and then sort
The first improvement is to filter out elements that cannot possibly be selected and then sort the remaining ones as in the naive approach.
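To figure out a conservative bound, consider the edge case in which only one element is selected and it has a probability of exactly $p$. In that case, all the remaining $n-1$ elements have to share the remaining probability mass of $1-p$, so on average each of them has $\frac{1-p}{n-1}$. More generally, whenever another element is selected, the elements selected before it carry less than $p$, so the elements not selected yet share more than $1-p$, and the next selected element is the largest of them and therefore at least their average. Thus, all elements below $\frac{1-p}{n-1}$ can be filtered out immediately because they will never be selected. Note how the denominator only gets smaller when more elements are selected, and how the numerator only gets bigger when the elements selected so far carry less than $p$. Thus, any other scenario leads to an even higher value, so the given value is a conservative bound.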
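A minimal Scala sketch of the filter-then-sort idea, assuming the same Array[Float] input; topPFilterAndSort is a hypothetical name, not the benchmarked filterAndSort code. With n = 32000 and p = 0.9, the cutoff is 0.1 / 31999, roughly 3.1e-6, so in a skewed distribution most of the vocabulary is dropped before sorting.

```scala
// Hypothetical sketch of "filter, then sort" (not the benchmarked filterAndSort code).
def topPFilterAndSort(probs: Array[Float], p: Float): Array[Int] = {
  val n = probs.length
  if (n <= 1) return Array.range(0, n)
  // Conservative cutoff: elements below (1 - p) / (n - 1) can never be selected.
  val cutoff = (1.0f - p) / (n - 1)
  // Only the surviving candidates are sorted.
  val candidates = Array.range(0, n).filter(i => probs(i) >= cutoff)
  val sorted = candidates.sortWith((a, b) => probs(a) > probs(b))
  val selected = Array.newBuilder[Int]
  var cumulative = 0.0f
  var k = 0
  while (k < sorted.length && cumulative < p) {
    selected += sorted(k)
    cumulative += probs(sorted(k))
    k += 1
  }
  selected.result()
}
```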
Filter out some elements and then iteratively select the next element
Instead of using a full-blown sort, the idea here is that after filtering, we are left with so few elements that something like a selection sort becomes feasible (especially in the common case). In this algorithm, we keep track of the cumulative probability and repeatedly select the next largest element as long as the cumulative probability is still below p.
The runtime depends on the number of elements that will ultimately be selected: if k elements are selected, the filtered list is traversed k times, and in the worst case all remaining elements are selected. One way to speed up the iterations is to take into account that we know the probability mass that is still available in each iteration. If we find a new maximum element that is bigger than half of the remaining probability mass, no other remaining element can be bigger, because all of them together carry less than the other half. So we can stop this iteration early and go on with selecting the next element.
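A minimal Scala sketch of this filter-and-scan variant with the early exit. topPFilterAndScan is a hypothetical name, not the benchmarked code. As an assumption, the remaining mass is tracked for the not-yet-selected candidates; 1 - cumulative would also be a valid, slightly looser bound as long as the probabilities sum to 1.

```scala
// Hypothetical sketch of "filter, then repeatedly select the maximum" (selection-sort style).
def topPFilterAndScan(probs: Array[Float], p: Float): Array[Int] = {
  val n = probs.length
  if (n <= 1) return Array.range(0, n)
  val cutoff = (1.0f - p) / (n - 1)
  // cand(0 until end) holds the candidates that have not been selected yet.
  val cand = Array.range(0, n).filter(i => probs(i) >= cutoff)
  var end = cand.length
  // Probability mass of the not-yet-selected candidates.
  var remainingMass = cand.map(i => probs(i)).sum
  val selected = Array.newBuilder[Int]
  var cumulative = 0.0f
  while (end > 0 && cumulative < p) {
    val half = remainingMass / 2
    var best = 0
    var j = 1
    // Early exit: an element with more than half of the remaining mass must be the maximum.
    while (j < end && probs(cand(best)) <= half) {
      if (probs(cand(j)) > probs(cand(best))) best = j
      j += 1
    }
    val idx = cand(best)
    selected += idx
    cumulative += probs(idx)
    remainingMass -= probs(idx)
    // Remove the selected element by moving the last unselected candidate into its slot.
    end -= 1
    cand(best) = cand(end)
    cand(end) = idx
  }
  selected.result()
}
```

In a peaked distribution, the dominant token usually exceeds half of the remaining mass as soon as the scan reaches it, so the rest of that pass is skipped.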
Build a histogram of elements
Observing that selecting the top-p elements is equivalent to finding the probability of the last element that is still included in the selection, one idea is to build a histogram of the elements, so that whole buckets can be included or excluded after a single pass over the values.
After this first step, we know a more precise bound on the probability of the last element. We then only need to operate on the bucket that contains the last element to find the exact result.
A difficult question in this algorithm is how to set up the buckets. Even with top-p, the last element might already be part of a long, flat tail, so that regardless of the exact choice of buckets, the last element might end up in a big bucket. Another consideration is how to calculate the bucket of an element without spending too much computation time on it. A seemingly obvious choice would be a logarithmic scale, but calculating a logarithm is expensive. Another possibility is to extract the exponent of the floating-point number and use it as the bucket.
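A minimal Scala sketch of the exponent-based bucketing, assuming 128 buckets indexed by the negated binary exponent (obtained with java.lang.Math.getExponent), so that no logarithm has to be computed. The function name and bucket layout are assumptions; for brevity, it skips the initial filtering step and sorts the boundary bucket instead of using the selection scan from the previous section.

```scala
// Hypothetical sketch of the histogram idea, bucketing by binary exponent.
def topPHistogram(probs: Array[Float], p: Float): Array[Int] = {
  val numBuckets = 128
  // Bucket b holds values in [2^-b, 2^-(b-1)); bucket 0 holds 1.0, and zeros and
  // subnormals (getExponent returns -127 for them) land in the last bucket.
  def bucketOf(x: Float): Int =
    math.max(0, math.min(numBuckets - 1, -java.lang.Math.getExponent(x)))
  val buckets = probs.map(bucketOf)
  val mass = new Array[Float](numBuckets)
  for (i <- probs.indices) mass(buckets(i)) += probs(i)
  // Walk buckets from large to small probabilities until the next bucket would reach p.
  var cumulative = 0.0f
  var boundary = 0
  while (boundary < numBuckets - 1 && cumulative + mass(boundary) < p) {
    cumulative += mass(boundary)
    boundary += 1
  }
  val selected = Array.newBuilder[Int]
  // All elements in buckets above the boundary bucket are selected outright.
  for (i <- probs.indices if buckets(i) < boundary) selected += i
  // Only the boundary bucket needs an exact ordering; sorted here for brevity.
  val inBoundary = probs.indices.filter(i => buckets(i) == boundary)
    .sortWith((a, b) => probs(a) > probs(b))
  var k = 0
  while (k < inBoundary.length && cumulative < p) {
    selected += inBoundary(k)
    cumulative += probs(inBoundary(k))
    k += 1
  }
  selected.result()
}
```

The bucket walk touches only 128 counters, but, as noted above, a flat tail can still put many elements into the boundary bucket.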
Use a quick select algorithm
The quick select algorithm is a well-known algorithm to find the k-th smallest element in an array. It is a variant of the quick sort algorithm that only recurses into the part of the array that contains the k-th smallest element (and avoids sorting as much as possible).
I found it hard to implement: it’s easy to get the indices slightly wrong, and choosing the right pivot is difficult. Also, care needs to be taken that the algorithm actually terminates. Depending on the choice of pivot, it can happen that the algorithm does not make progress (I am not yet sure whether this is inherent or a bug in my implementation). The usual remedy seems to be choosing the pivot randomly, so that eventually a pivot is chosen that makes progress (there is always one that works).
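The benchmarked quick select code still has bugs and is not shown here. As a hedged illustration only, here is one way quick select could be adapted to top-p in Scala: instead of a target rank k, the loop tracks the probability mass still needed, partitions the undecided range around a random pivot, and continues only in the part that contains the cut. It uses three-way partitioning (greater than, equal to, less than the pivot). Since the group equal to the pivot is never empty and always leaves the undecided range, every iteration makes progress even with many duplicate probabilities, and the random pivot only affects the expected running time. The name topPQuickSelect and the fixed default seed are assumptions.

```scala
import scala.util.Random

// Hypothetical sketch: quick select adapted to top-p (not the author's unfinished implementation).
// Returns the selected token ids in no particular order.
def topPQuickSelect(probs: Array[Float], p: Float, rng: Random = new Random(42)): Array[Int] = {
  val idx = Array.range(0, probs.length)
  def swap(a: Int, b: Int): Unit = { val t = idx(a); idx(a) = idx(b); idx(b) = t }
  val selected = Array.newBuilder[Int]
  var lo = 0           // idx(lo until hi) is the still undecided range
  var hi = idx.length
  var target = p       // probability mass that still has to be selected
  var done = false
  while (!done && lo < hi) {
    val pivot = probs(idx(lo + rng.nextInt(hi - lo)))
    // Three-way partition: [lo, gt) > pivot, [gt, lt) == pivot, [lt, hi) < pivot.
    var gt = lo
    var i = lo
    var lt = hi
    while (i < lt) {
      val v = probs(idx(i))
      if (v > pivot) { swap(gt, i); gt += 1; i += 1 }
      else if (v < pivot) { lt -= 1; swap(i, lt) }
      else i += 1
    }
    var greaterMass = 0.0f
    for (j <- lo until gt) greaterMass += probs(idx(j))
    if (greaterMass >= target) {
      // The cut lies among the larger elements: continue there only.
      hi = gt
    } else {
      // All larger elements are selected, then equal ones until the target is met.
      for (j <- lo until gt) selected += idx(j)
      target -= greaterMass
      var e = gt
      while (e < lt && target > 0) {
        selected += idx(e)
        target -= probs(idx(e))
        e += 1
      }
      // If the target is still not met, continue with the smaller elements.
      if (target > 0) lo = lt else done = true
    }
  }
  selected.result()
}
```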
Running bench/jmh:run -f1 -wi 2 TopP* at commit 7f0774b3 on my machine gives the following results (throughput in operations per second, higher is better):
Benchmark              Mode   Cnt      Score        Error   Units
sorting                thrpt  5      122.276  ±    1.498   ops/s
filterAndSort          thrpt  5     2417.850  ±  135.264   ops/s
filterAndFastSort      thrpt  5    17899.529  ±  600.464   ops/s
filterAndScan          thrpt  5    23479.921  ±  895.475   ops/s
quick-find-max-top-p   thrpt  5     5294.213  ±  302.202   ops/s
histogram              thrpt  5    19456.664  ±  210.921   ops/s
quick select           not benchmarked (the implementation still has bugs)
- sorting: naive, idiomatic Scala sorting
- filterAndSort: filter out some elements and then sort, idiomatic Scala
- filterAndFastSort: filter out some elements and then sort, avoiding expensive Scala collections
- filterAndScan: filter out some elements and then iteratively select the next element
- quick-find-max-top-p: tries to use filtering and selecting in each step, iterating over all elements every time
- histogram: build a histogram of elements, then use the previous algorithm on the last bucket
The results are an average over 100 different distributions (generated by running the
- Filtering alone does most of the work: filterAndSort is roughly 20 times faster than sorting. Due to its simplicity, this is what I contributed to llama2.c.
- Doing filtering and sorting in Scala requires a bit of work to avoid the overhead of Scala collections (filterAndFastSort is more than 7 times faster than filterAndSort).
- Avoiding a full-blown sort and replacing it with a selection sort gives a little extra performance (about 30% for filterAndScan over filterAndFastSort).
- Building a histogram before doing the selection sort steps does not seem to be worth it (histogram is slower than filterAndScan). In theory, it should further reduce the number of elements to consider for the selection sort, but the additional complexity is not well amortized.
- In general, a first filtering step is worth it, maybe just because it reduces the amount of data by such a degree that further iterations over all elements are much cheaper than before.
- The results are an average over 100 distributions. In many cases, only a handful of elements is selected (for top-p = 0.9). A heuristic could be developed to detect cases with very few selected elements (e.g. by counting the elements above a certain threshold while filtering) and use a fast path for those cases.
- More work is needed to get the quick select algorithm right.
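The fast-path heuristic from the list could look roughly like this minimal Scala sketch. HeavyThreshold = 0.05 is a hypothetical, untuned value, and topPWithFastPath is a hypothetical name. The idea: elements at or above the threshold are the largest ones, so if their mass alone reaches p, the cut must lie among them, and only this small group (at most 1 / HeavyThreshold elements, since the probabilities sum to 1) needs to be sorted. For simplicity, it still collects the regular filtered candidates in the same pass.

```scala
// Hypothetical fast path: HeavyThreshold is an illustrative value, not a tuned one.
val HeavyThreshold = 0.05f

def topPWithFastPath(probs: Array[Float], p: Float): Array[Int] = {
  val n = probs.length
  val cutoff = if (n > 1) (1.0f - p) / (n - 1) else 0.0f
  val heavy = Array.newBuilder[Int]      // at most 1 / HeavyThreshold entries
  val candidates = Array.newBuilder[Int] // everything that survives the usual cutoff
  var heavyMass = 0.0f
  for (i <- 0 until n) {
    val x = probs(i)
    if (x >= HeavyThreshold) { heavy += i; heavyMass += x }
    if (x >= cutoff) candidates += i
  }
  // If the heavy elements alone reach p, the cut lies among them: sort only those.
  val pool = if (heavyMass >= p) heavy.result() else candidates.result()
  val sorted = pool.sortWith((a, b) => probs(a) > probs(b))
  val selected = Array.newBuilder[Int]
  var cumulative = 0.0f
  var k = 0
  while (k < sorted.length && cumulative < p) {
    selected += sorted(k)
    cumulative += probs(sorted(k))
    k += 1
  }
  selected.result()
}
```

For a flat distribution, no element reaches the threshold and the function falls back to sorting the filtered candidates.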
You can find the implementations and benchmarks in the
This was a fun experiment doing some algorithm engineering and benchmarking. I hope you enjoyed it as well!
[1] This is a consequence of autoregressive LLMs “just predicting the next word/token”. This is often stated as a seemingly intuitive drawback of how LLMs are trained and how they generate text. To me, it is not all that clear why predicting just the next token is obviously weak. That said, this main architectural property of course has consequences. As I see it, a (2023 generation) LLM is a very complex state machine that can solve a (maybe surprisingly) wide range of tasks given enough depth/width/parameters. The state is built up by calculating key/value pairs for each token in each layer (and each of the multiple heads) of the transformer. The whole machine is instructed and clocked token by token through the sequence. In generation mode, the machine generates its own next instruction autoregressively. So, the choice of the next token is very important, as the machine cannot take back its choice. The only thing it can do is try to weasel itself out of a bad choice by generating more text. Note how this kind of backtracking by adding apologies is only applicable to some language generation situations. For example, in a chat, the machine can apologize and use further user input to get the conversation meaningfully back on track. However, when generating code or generating under a stricter regime, this might not be feasible, and a bad choice will likely lead to incoherence, hallucinations, or outright garbage.