Batching Iterables in Python
…when one by one is just too slow.
Python’s iterables are designed so that it is easy to take items from them one at a time. Getting “batches” of items from them, in sub-iterables if you will, is slightly more complicated. Here are some of the ways I’ve found of doing this, in as type-friendly a manner as possible.
Before we start, let’s have a look at what we’re aiming for:
>>> list(range(17))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
>>> [list(batch) for batch in batches(range(17), 3)]
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16]]
Note that the iterable we get back doesn’t have to be a list, nor do the batches themselves. They just have to be iterables. Also, if the length of the iterable is not divisible by the batch size, then the last batch is going to have to be shorter.
Now that we’ve got our requirements, let’s look at some code.
Slicing: The Naïve Way.
List slicing is the immediate and obvious choice here. A sequence can be batched easily this way, for example:
def batches_with_slicing(iterable, batch_size):
    # Take consecutive slices of length batch_size; the last one may be shorter.
    return (iterable[start:start + batch_size]
            for start in range(0, len(iterable), batch_size))
This method, however, poses some serious problems:
- It needs the length of the iterable. That is impossible for infinite iterables, and iterators and generators don’t support len() at all; the only way to count their items is to consume them.
- The iterable must support slicing, which iterators and generators do not.
In fact, this method only really works properly for lists, strings and tuples, despite the fact that batching an iterable shouldn’t really depend on any of those capabilities.
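To make those limitations concrete, here is a small sketch that repeats the slicing function (with the slice written as start:start + batch_size) so the block runs on its own. The generator squares is just a hypothetical input for the demonstration.

```python
def batches_with_slicing(iterable, batch_size):
    # Only works for sequences that support both len() and slicing.
    return (iterable[start:start + batch_size]
            for start in range(0, len(iterable), batch_size))

# Lists and strings slice fine, and each batch keeps the source's type.
print(list(batches_with_slicing(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
print(list(batches_with_slicing("abcdefg", 3)))       # ['abc', 'def', 'g']

# A generator has no len(); the outermost range() of a generator expression
# is evaluated immediately, so the call fails before any batch is produced.
squares = (n * n for n in range(7))
try:
    batches_with_slicing(squares, 3)
except TypeError as error:
    print("TypeError:", error)
```

Notice that slicing a string gives string batches; that is a side effect of relying on the sequence’s own slice operation.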
We have to generalize our approach to arbitrary iterables, including infinite ones. Thankfully, we don’t have to go too far to get that kind of functionality.
Want to slice an iterator? Use islice.
The itertools module contains islice, which lets you slice any iterable and returns an iterator over the elements you sliced. Note that islice consumes the source iterator as it goes, so use itertools.tee to split off a copy first if you need the elements again.
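A quick sketch of both behaviours, using a hypothetical numbers iterator: successive islice calls on the same iterator pick up where the previous one stopped, and tee gives independent copies of whatever is left.

```python
from itertools import islice, tee

numbers = iter(range(10))

# islice stops after three elements and leaves the iterator positioned after them.
first_batch = list(islice(numbers, 3))
second_batch = list(islice(numbers, 3))
print(first_batch, second_batch)  # [0, 1, 2] [3, 4, 5]

# tee splits the remaining iterator into two independent copies.
# After calling tee, use only the copies, never the original iterator.
working, backup = tee(numbers)
peeked = list(islice(working, 2))
rest = list(backup)
print(peeked)  # [6, 7]
print(rest)    # [6, 7, 8, 9]
```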
The following code was posted on ActiveState as Cookbook Recipe #303279 (not by me).
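from itertools import islice, chain

def batches(iterable, batch_size):
    iterable = iter(iterable)
    while True:
        batch = islice(iterable, batch_size)
        try:
            first = next(batch)
        except StopIteration:
            # Source exhausted: end the generator cleanly (required since PEP 479).
            return
        yield chain([first], batch)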
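There are a couple of points of interest about that method:
- iterable must be converted into an iterator so that each new islice resumes where the previous one stopped. Calling islice repeatedly on a list would start from the beginning every time, yielding the first batch forever.
- Calling next(batch) is what detects the end of the source. Without it, the loop would yield empty iterators forever once the source is exhausted, because islice is perfectly content to keep serving empty iterators. Since that call consumes the first element of the batch, chain puts it back in front.
- The original recipe let the StopIteration raised by next(batch) escape to end the generator. Since Python 3.7 (PEP 479), an escaping StopIteration is turned into a RuntimeError, so on current versions it has to be caught and the generator ended with return.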
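The first two points are easy to see in isolation. This sketch (data and shared are hypothetical names) slices a list directly, then slices a single shared iterator, and finally keeps slicing after that iterator has run dry.

```python
from itertools import islice

data = [0, 1, 2, 3, 4, 5]

# Slicing the list itself: every islice starts again from the beginning.
from_list = [list(islice(data, 2)) for _ in range(3)]
print(from_list)  # [[0, 1], [0, 1], [0, 1]]

# Slicing one shared iterator: each islice resumes where the last one stopped.
shared = iter(data)
from_iterator = [list(islice(shared, 2)) for _ in range(3)]
print(from_iterator)  # [[0, 1], [2, 3], [4, 5]]

# Once the iterator is exhausted, islice keeps handing out empty batches,
# which is why the recipe probes each batch with next() before yielding it.
leftovers = [list(islice(shared, 2)) for _ in range(2)]
print(leftovers)  # [[], []]
```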
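This method, despite being straightforward, has a catch: every batch shares the same underlying iterator, so each batch must be fully consumed before the next one is requested. Any elements you skip leak into the start of the next batch. That leads us to our next approach…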
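Here is a sketch of that pitfall, using a copy of the islice recipe with size corrected to batch_size and the StopIteration caught, so the block is self-contained.

```python
from itertools import chain, islice

def batches(iterable, batch_size):
    # Copy of the islice recipe, with the fixes described above.
    iterable = iter(iterable)
    while True:
        batch = islice(iterable, batch_size)
        try:
            first = next(batch)
        except StopIteration:
            return
        yield chain([first], batch)

# Consuming every batch in order gives the expected result...
in_order = [list(b) for b in batches(range(7), 3)]
print(in_order)  # [[0, 1, 2], [3, 4, 5], [6]]

# ...but reading only the first element of each batch lets the rest leak
# forward, so every element starts a new "batch" instead of every third.
firsts = [next(b) for b in batches(range(7), 3)]
print(firsts)  # [0, 1, 2, 3, 4, 5, 6], not [0, 3, 6]
```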
The mystery that is groupby.
itertools.groupby is another esoteric tool of the itertools module. It walks through an iterable and computes a key for each element, using an optional key function (the element itself is the key if none is given). For each run of consecutive elements with equal keys, it yields a pair (key, sub_iterator), where sub_iterator lazily produces the elements of that run.
Two elements with the same key therefore only end up in the same sub-iterator if they are adjacent in the iterable: groupby groups consecutive runs, not all matching elements.
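A short sketch of that behaviour: the repeated 'a' at the end lands in its own group because it is not adjacent to the first run, and the key function in the second example (parity, a hypothetical choice) groups by n % 2 rather than by value.

```python
from itertools import groupby

# With no key function, groupby groups consecutive equal elements.
runs = [(key, list(group)) for key, group in groupby("aaabbca")]
print(runs)
# [('a', ['a', 'a', 'a']), ('b', ['b', 'b']), ('c', ['c']), ('a', ['a'])]

# With a key function, consecutive elements with equal keys are grouped.
parity = [(key, list(group))
          for key, group in groupby([1, 3, 5, 2, 4, 7], key=lambda n: n % 2)]
print(parity)  # [(1, [1, 3, 5]), (0, [2, 4]), (1, [7])]
```

Each group is turned into a list before moving on, because a group becomes unusable once groupby advances past it.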
The following code was posted as a comment to the above recipe:
from itertools import groupby

def batches(iterable, size):
    def ticker(x, s=size, a=[-1]):
        r = a[0] = a[0] + 1
        return r // s
    for k, g in groupby(iterable, ticker):
        yield g
Although this works perfectly, it takes some intuition to figure out exactly why.
The key is, of course, the key function: ticker. It computes the number of the batch that the current element belongs in and returns it. Adjacent elements in the same batch therefore get the same batch number, and groupby groups them together.
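The batch number of the element at position i is simply i // batch_size. The sketch below shows those keys for seven hypothetical items, then feeds them to groupby. For illustration only, it swaps the stateful ticker for enumerate, which supplies each position explicitly; it is not the recipe’s code.

```python
from itertools import groupby

batch_size = 3
items = "abcdefg"

# The batch number of the element at position i is i // batch_size.
keys = [position // batch_size for position in range(len(items))]
print(keys)  # [0, 0, 0, 1, 1, 1, 2]

# groupby turns each run of equal batch numbers into one batch.
batched = [[item for _, item in group]
           for _, group in groupby(enumerate(items),
                                   key=lambda pair: pair[0] // batch_size)]
print(batched)  # [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
```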
However, the precise working of ticker is a bit arcane. I’ve posted the code verbatim from the recipe, which is about 4 years old, so it may simply not have been updated since then. For example, size need not be passed in as the default parameter s, since ticker can already read size from the enclosing function.
The main point is that you can’t figure out how ticker works without knowing that default parameter values are evaluated once, when the function is defined, and the same object is reused on every call. a defaults to the list [-1] and is never passed explicitly, so every call mutates that same list: a[0] acts as a counter that persists across calls. Because ticker is defined inside batches, each call to batches creates a fresh function and therefore a fresh counter.
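To watch the mutable default at work, here is a standalone copy of ticker with the batch size hard-coded to 3 (an assumption for the demo; in the recipe it comes from size). Inspecting __defaults__ shows the default list itself has changed.

```python
def ticker(x, s=3, a=[-1]):
    # a is created once, when def runs, so a[0] survives between calls.
    r = a[0] = a[0] + 1
    return r // s

first_round = [ticker(item) for item in "abcdefg"]
print(first_round)          # [0, 0, 0, 1, 1, 1, 2]
print(ticker.__defaults__)  # (3, [6])

# A later call keeps counting from where the previous ones stopped.
next_key = ticker("h")
print(next_key)  # 2
```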
Now in my opinion, this is a remarkably tedious way of doing something very simple.
Removing the arcane.
All I did was simplify the ticker function into something easier to follow. The previous approach works perfectly fine, but I find this one easier to understand and to use.
import itertools

def batches(iterable, batch_size):
    '''Yields the given iterable split into batches of size batch_size.'''
    iterable = iter(iterable)
    counter = itertools.count()

    def ticker(_item):
        # Ignores the element; returns 0 for the first batch_size calls, then 1, ...
        return next(counter) // batch_size

    for _key, group in itertools.groupby(iterable, ticker):
        yield group
A very small change makes the code much more readable. ticker ignores the element it is given and returns batch_size 0s, then batch_size 1s, and so on. It is only called when groupby pulls the next element from the source, so it stops when the source runs out, and everything stays nice and lazy.
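In the spirit of the type-friendly goal from the introduction, here is a sketch of the same counter-plus-groupby version with type hints added. The annotations (the TypeVar T, Iterable[T] in, Iterator[Iterator[T]] out) are my own assumption about a sensible signature, not part of the original code. The usage line runs it on the infinite itertools.count() to show that laziness.

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def batches(iterable: Iterable[T], batch_size: int) -> Iterator[Iterator[T]]:
    '''Yields the given iterable split into lazy batches of size batch_size.'''
    counter = itertools.count()

    def ticker(_item: T) -> int:
        # 0 for the first batch_size elements, 1 for the next batch_size, ...
        return next(counter) // batch_size

    for _key, group in itertools.groupby(iterable, ticker):
        yield group

# Works on an infinite iterator, because elements are only read on demand.
first_three = [list(batch)
               for batch in itertools.islice(batches(itertools.count(), 4), 3)]
print(first_three)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```

As with any groupby-based approach, each batch should be used before asking for the next one, since a group is no longer visible once groupby advances. On Python 3.12 and later, the standard library also provides itertools.batched(iterable, n), which yields tuples rather than lazy iterators.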
Do you have any observations? Let me know.