Many years ago I re-posted a Stack Overflow answer with Python code for a terse prime sieve function that generates a potentially infinite sequence of prime numbers ("potentially" because it will run out of memory eventually). Since then, I've used this code many times - mostly because it's short and clear. In this post I will explain how this code works, where it comes from (I didn't come up with it), and some potential optimizations. If you want a teaser, here it is:
def gen_primes():
    """Generate an infinite sequence of prime numbers."""
    D = {}
    q = 2
    while True:
        if q not in D:
            D[q * q] = [q]
            yield q
        else:
            for p in D[q]:
                D.setdefault(p + q, []).append(p)
            del D[q]
        q += 1
The sieve of Eratosthenes
To understand what this code does, we should start with the basic Sieve of Eratosthenes; if you're familiar with it, feel free to skip this section.
The Sieve of Eratosthenes is a well-known algorithm from ancient Greek times for finding all the primes below a certain number reasonably efficiently using a tabular representation. This animation from Wikipedia explains it pretty well:
Starting with the first prime (2) it marks all its multiples until the requested limit. It then takes the next unmarked number, assumes it's a prime (because it is not a multiple of a smaller prime), and marks its multiples, and so on until all the multiples below the limit are marked. The remaining unmarked numbers are primes.
Here's a well-commented, basic Python implementation:
import math

def gen_primes_upto(n):
    """Generates a sequence of primes < n.

    Uses the full sieve of Eratosthenes with O(n) memory.
    """
    if n <= 2:
        return

    # Initialize table; True means "prime", initially assuming all numbers
    # are prime.
    table = [True] * n
    sqrtn = int(math.ceil(math.sqrt(n)))

    # Starting with 2, for each True (prime) number I in the table, mark all
    # its multiples as composite (starting with I*I, since earlier multiples
    # should have already been marked as multiples of smaller primes).
    # At the end of this process, the remaining True items in the table are
    # primes, and the False items are composites.
    for i in range(2, sqrtn):
        if table[i]:
            for j in range(i * i, n, i):
                table[j] = False

    # Yield all the primes in the table.
    yield 2
    for i in range(3, n, 2):
        if table[i]:
            yield i
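For example, collecting the generator's output into a list:

    >>> list(gen_primes_upto(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]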
When we want a list of all the primes below some known limit, gen_primes_upto is great, and performs fairly well. There are two issues with it, though:
- We have to know what the limit is ahead of time; this isn't always possible or convenient.
- Its memory usage is high - O(n). This can be significantly optimized, however; see the bonus section at the end of the post for details.
The infinite prime generator
Back to the infinite prime generator that's the focus of this post. Here is its code again, now with some comments:
def gen_primes():
    """Generate an infinite sequence of prime numbers."""
    # Maps composites to primes witnessing their compositeness.
    D = {}

    # The running integer that's checked for primeness.
    q = 2

    while True:
        if q not in D:
            # q is a new prime.
            # Yield it and mark its first multiple that isn't
            # already marked in previous iterations.
            D[q * q] = [q]
            yield q
        else:
            # q is composite. D[q] holds some of the primes that
            # divide it. Since we've reached q, we no longer
            # need it in the map, but we'll mark the next
            # multiples of its witnesses to prepare for larger
            # numbers.
            for p in D[q]:
                D.setdefault(p + q, []).append(p)
            del D[q]
        q += 1
The key to the algorithm is the map D. It holds all the primes encountered so far, but not as keys! Rather, they are stored as values, with the keys being the next composite number they divide. This lets the program avoid having to divide each number it encounters by all the primes known so far - it can simply look in the map. A number that's not in the map is a new prime, and the way the map updates is not unlike the sieve of Eratosthenes - when a composite is removed, we add the next composite multiple of the same prime(s). This is guaranteed to cover all the composite numbers, while prime numbers should never be keys in D.
I highly recommend instrumenting this function with some printouts and running through a sample invocation - it makes it easy to understand how the algorithm makes progress.
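Here's one way to do it - a throwaway variant (the name gen_primes_verbose is mine, just for exploration) that prints each decision as it's made:

def gen_primes_verbose():
    D = {}
    q = 2
    while True:
        if q not in D:
            print(f"{q} is prime; marking D[{q * q}]")
            D[q * q] = [q]
            yield q
        else:
            print(f"{q} is composite; witnesses: {D[q]}")
            for p in D[q]:
                D.setdefault(p + q, []).append(p)
            del D[q]
        q += 1

Running list(itertools.islice(gen_primes_verbose(), 5)) shows how each prime first gets registered at its square, and how witnesses hop from one composite key to the next.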
Compared to the full sieve gen_primes_upto, this function doesn't require us to know the limit ahead of time - it will keep producing prime numbers ad infinitum (but will run out of memory eventually). As for memory usage, the D map has all the primes in it somewhere, but each one appears only once. So its size is O(π(n)), where π(n) is the prime-counting function, the number of primes smaller than or equal to n. This can be approximated by O(n/ln(n)) [1].
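If you're curious, you can check this empirically by peeking into the suspended generator's frame (CPython-specific introspection; a quick experiment, not something to rely on in real code):

import itertools

g = gen_primes()
primes = list(itertools.islice(g, 1000))

# gi_frame.f_locals exposes the generator's local variables (CPython).
D = g.gi_frame.f_locals["D"]

# Every prime yielded so far appears exactly once as a value in D.
assert sum(len(wits) for wits in D.values()) == len(primes)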
I don't remember where I first saw this approach mentioned, but all the breadcrumbs lead to this ActiveState Recipe by David Eppstein from way back in 2002.
Optimizing the generator
I really like gen_primes; it's short, easy to understand and gives me as many primes as I need without forcing me to know what limit to use, and its memory usage is much more reasonable than the full-blown sieve of Eratosthenes. It is, however, also quite slow, over 5x slower than gen_primes_upto.
The aforementioned ActiveState Recipe thread has several optimization ideas; here's a version that incorporates ideas from Alex Martelli, Tim Hochberg and Wolfgang Beneicke:
import itertools

def gen_primes_opt():
    yield 2
    D = {}
    for q in itertools.count(3, step=2):
        p = D.pop(q, None)
        if not p:
            # q is a new prime; mark its square as composite.
            D[q * q] = q
            yield q
        else:
            # q is composite; move its witness p to the next
            # multiple that isn't already claimed in the map.
            x = q + p + p  # get odd multiples
            while x in D:
                x += p + p
            D[x] = p
The optimizations are:
- Instead of holding a list as the value of D, just have a single number. In cases where we need more than one witness to a composite, find the next multiple of the witness and assign that instead (this is the while x in D inner loop in the else clause). This is a bit like using linear probing in a hash table instead of having a list per bucket.
- Skip even numbers by starting with 2 and then proceeding from 3 in steps of 2.
- The loop assigning the next multiple of witnesses may land on even numbers (when p and q are both odd, q + p is even). So instead we jump directly to q + p + p, which is guaranteed to be odd; for example, with q = 9 and p = 3, q + p = 12 is even, while q + p + p = 15 is odd.
With these in place, the function is more than 3x faster than before, and is now within roughly 40% of the speed of gen_primes_upto, while remaining short and reasonably clear.
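For a rough comparison on your own machine, a quick timeit sketch like this works (take is a little helper of mine; the numbers above come from the benchmarks in the repository, not from this snippet):

import itertools
import timeit

def take(gen_fn, n):
    return list(itertools.islice(gen_fn(), n))

for gen_fn in [gen_primes, gen_primes_opt]:
    t = timeit.timeit(lambda: take(gen_fn, 100_000), number=1)
    print(f"{gen_fn.__name__}: {t:.2f}s")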
There are even fancier algorithms that use interesting mathematical tricks to do less work. Here's an approach by Will Ness and Tim Peters (yes, that Tim Peters) that's reportedly faster. It uses the wheels idea from this paper by Sorenson. Some additional details on this approach are available here. This algorithm is both faster and consumes less memory; on the other hand, it's no longer short and simple.
To be honest, it always feels a bit odd to me to painfully optimize Python code, when switching languages provides vastly bigger benefits. For example, I threw together the same algorithms using Go and its experimental iterator support; it's 3x faster than the Python version, with very little effort (even though the new Go iterators and yield functions are still in the proposal stage and aren't optimized). I can't try to rewrite it in C++ or Rust for now, due to the lack of generator support; the yield statement is what makes this code so nice and elegant, and alternative idioms are much less convenient.
Bonus: segmented sieve of Eratosthenes
The Wikipedia article on the sieve of Eratosthenes mentions a segmented approach, which is also described in the Sorenson paper in section 5.
The main insight is that we only need the primes up to √N to be able to sieve a table all the way to N. This results in a sieve that uses only O(√N) memory. Here's a commented Python implementation:
def gen_primes_upto_segmented(n):
    """Generates a sequence of primes < n.

    Uses the segmented sieve of Eratosthenes algorithm with O(√n) memory.
    """
    # Simplify boundary cases by hard-coding some small primes.
    if n < 11:
        for p in [2, 3, 5, 7]:
            if p < n:
                yield p
        return

    # We break the range [0..n) into segments of size √n.
    segsize = int(math.ceil(math.sqrt(n)))

    # Find the primes in the first segment by calling the basic sieve on that
    # segment (its memory usage will be O(√n)). We'll use these primes to
    # sieve all subsequent segments.
    baseprimes = list(gen_primes_upto(segsize))
    for bp in baseprimes:
        yield bp

    for segstart in range(segsize, n, segsize):
        # Create a new table of size √n for each segment; the old table
        # is thrown away, so the total memory use here is O(√n).
        # seg[i] represents the number segstart+i.
        seg = [True] * segsize
        for bp in baseprimes:
            # The first multiple of bp in this segment can be calculated using
            # modulo.
            first_multiple = (
                segstart if segstart % bp == 0 else segstart + bp - segstart % bp
            )
            # Mark all multiples of bp in the segment as composite. Since
            # segstart is a multiple of segsize, q % len(seg) == q - segstart.
            for q in range(first_multiple, segstart + segsize, bp):
                seg[q % len(seg)] = False

        # Sieving is done; yield all the primes in the segment (iterating only
        # over the odd numbers).
        start = 1 if segstart % 2 == 0 else 0
        for i in range(start, len(seg), 2):
            if seg[i]:
                if segstart + i >= n:
                    break
                yield segstart + i
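As a quick sanity check, the segmented version should produce exactly the same primes as the basic sieve:

assert list(gen_primes_upto_segmented(100_000)) == list(gen_primes_upto(100_000))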
Code
The full code for this post - along with tests and benchmarks - is available on GitHub.
[1] While this is a strong improvement over O(n) (e.g. for a billion primes, memory usage here is only 5% of the full sieve version), it still depends on the size of the input. In the unlikely event that you need to generate truly gigantic primes starting from 2, even the square-root-space solutions become infeasible. In this case, the whole approach should be changed; instead, one would just generate random huge numbers and use probabilistic primality testing to check for their primeness. This is what real libraries like Go's crypto/rand.Prime do.
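To give a flavor of that approach, here's a minimal sketch pairing random candidate generation with the Miller-Rabin test (the helper names are mine, and this illustrates the idea rather than what crypto/rand.Prime actually does; use a vetted library for anything real):

import random

def is_probably_prime(n, rounds=40):
    # Miller-Rabin: a composite n survives one round with probability <= 1/4.
    if n < 2:
        return False
    for p in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % p == 0:
            return n == p
    # Write n - 1 as d * 2^r with d odd.
    d, r = n - 1, 0
    while d % 2 == 0:
        d //= 2
        r += 1
    for _ in range(rounds):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True

def random_prime(bits):
    # Set the top bit (so the result really has `bits` bits) and the low
    # bit (so it's odd), then retry until a candidate passes the test.
    while True:
        candidate = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        if is_probably_prime(candidate):
            return candidate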