Batch provider - for parallel batch data processing
dlutils.batch_provider(data, batch_size, processor=None, worker_count=1, queue_size=16, report_progress=True)

Return an object that produces a sequence of batches from input data.

Input data is split into batches of size batch_size, which are processed with the processor function.
Data is split and processed by separate threads and dumped into a queue, allowing continuous provision of data. The main purpose of this primitive is to provide an easy-to-use tool for parallel batch processing/generation in the background while the main thread runs the main algorithm. Batches are processed in parallel, allowing better utilization of CPU cores and disk, which may improve GPU utilization for DL tasks with a storage/IO bottleneck.

This primitive can be used in various ways. For small datasets, the input data list may contain the actual dataset, while the processor function does little to no data processing. For larger datasets, the data list may contain just filenames or keys, while the processor function reads the data from disk or a database.

There are many purposes the processor function can serve, depending on your use case (see the sketch after this list):

- Reading data from disk or a database.
- Data decoding, e.g. from JPEG.
- Augmenting data: flipping, rotating, adding noise, etc.
- Concatenation of data, stacking to a single ndarray, conversion to a tensor, uploading to GPU.
- Data generation.
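For the small-dataset case, where the data list holds the actual data points, a minimal sketch (the in-memory values, batch size, and worker count below are illustrative, not part of the API):

import numpy as np

import dlutils

# Small in-memory dataset: each list entry is already a complete data point.
data = list(range(1000))

def process(batch):
    # The processor receives a plain slice of `data` and may return any object;
    # here it just stacks the values into a single ndarray.
    return np.asarray(batch, dtype=np.float32)

# Several workers keep the queue filled while the main thread consumes batches.
for batch in dlutils.batch_provider(data, 64, process, worker_count=4):
    pass  # consume the ndarray batch here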
Note

Sequential order of batches is guaranteed only if the number of workers is 1 (the default); otherwise, batches may be supplied out of order.
- Parameters

  data (list) – Input data; each entry in the list should be a separate data point.

  batch_size (int) – Size of a batch. If the size of data is not divisible by batch_size, then the last batch will be smaller.

  processor (Callable[[list], Any], optional) – Function for processing batches. Receives a slice of the data list as input. Can return an object of any type. Defaults to None.

  worker_count (int, optional) – Number of workers; should be greater than or equal to one. To process data in parallel and fully load the CPU, worker_count should be close to the number of CPU cores. Defaults to one.

  queue_size (int, optional) – Maximum size of the queue, i.e. the number of batches to buffer. Should be larger than worker_count. Typically, one would want this to be as large as possible to amortize all disk IO and computational costs; the downside of a large value is increased RAM consumption. Defaults to 16.

  report_progress (bool, optional) – Print a progress bar similar to tqdm. You may still use tqdm if you set report_progress to False; to use tqdm, just do for x in tqdm(batch_provider(...)): .... Defaults to True.
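As a concrete illustration of report_progress, a minimal sketch that disables the built-in progress bar and wraps the provider with tqdm instead (the data and identity processor are placeholders):

from tqdm import tqdm

import dlutils

data = list(range(1000))            # any list of data points
process = lambda batch: batch       # identity processor, for illustration only

# Turn off the built-in progress reporting and let tqdm display progress instead.
for batch in tqdm(dlutils.batch_provider(data, 32, process, report_progress=False)):
    pass  # consume batch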
- Returns

  An object that produces a sequence of batches. The next() method of the iterator will return the object that was produced by the processor function.

- Return type

  Iterator
- Raises

  StopIteration – When all data has been iterated through. Stops the for loop.
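Because the provider is a plain iterator, it can also be consumed manually with next(); a minimal sketch of what the for loop does implicitly (the toy data and processor are illustrative):

import dlutils

data = list(range(100))
process = lambda batch: sum(batch)   # trivial processor, for illustration only
batches = dlutils.batch_provider(data, 16, process, report_progress=False)

while True:
    try:
        batch = next(batches)        # object produced by `process`
    except StopIteration:            # raised once all data has been iterated through
        break
    # ... use `batch` here ...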
Example

import numpy as np
import torch
import torch.nn.functional as F
from scipy import misc

import dlutils

def process(batch):
    # Read and stack the images referenced by the batch of (filename, label) pairs.
    images = [misc.imread(x[0]) for x in batch]
    images = np.asarray(images, dtype=np.float32)
    images = images.transpose((0, 3, 1, 2))  # NHWC -> NCHW
    labels = [x[1] for x in batch]
    labels = np.asarray(labels, np.int64)
    return torch.from_numpy(images) / 255.0, torch.from_numpy(labels)

data = [('some_list.jpg', 1), ('of_filenames.jpg', 2), ('etc.jpg', 4), ...]  # filenames and labels

batches = dlutils.batch_provider(data, 32, process)

for images, labels in batches:
    result = model(images)
    loss = F.nll_loss(result, labels)
    loss.backward()
    optimizer.step()