Processing a large tensor by chunking

dlutils.block_process_2d(data, func, block_size=512, overlap=32, intermediate_as_double=False)

Applies function func to the given 4-dimensional tensor by chunking it along 2 dimensions and processing each chunk separately.

Input data is split into chunks across the last two dimensions. If input data is laid out as NCHW, chunking occurs in HW space. The chunk size and the amount of overlap can be changed through block_size and overlap; overlap tells how much adjacent chunks should overlap.
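One way to lay out overlapping chunks along H and W is sketched below. It assumes chunks of exactly block_size elements whose start offsets advance by block_size - overlap, with the last chunk shifted back so it ends at the border. chunk_starts and chunk_slices_2d are hypothetical helpers, not part of dlutils, and dlutils may size or place its chunks differently.

```python
from itertools import product


def chunk_starts(length, block_size, overlap):
    """Hypothetical helper: start offsets of overlapping chunks along one axis.

    Consecutive chunks advance by stride = block_size - overlap; the last
    chunk is shifted back so that it ends exactly at `length` (it may then
    overlap its predecessor by more than `overlap`).
    """
    if length <= block_size:
        return [0]  # a single chunk covers the whole axis
    stride = block_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than block_size")
    starts = list(range(0, length - block_size, stride))
    starts.append(length - block_size)
    return starts


def chunk_slices_2d(height, width, block_size, overlap):
    """Hypothetical helper: yield (row_slice, col_slice) pairs covering an H x W plane."""
    for y, x in product(chunk_starts(height, block_size, overlap),
                        chunk_starts(width, block_size, overlap)):
        yield (slice(y, min(y + block_size, height)),
               slice(x, min(x + block_size, width)))
```

For an NCHW tensor, each pair can be used as data[:, :, rows, cols], so every chunk keeps all N and C entries and only H and W are cut.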

In overlapping regions, the results of neighbouring chunks are blended by linear/bilinear interpolation.
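A minimal sketch of such blending follows, assuming chunk outputs are combined as a weighted average whose weights ramp linearly across each overlap. Taking the outer product of two 1-D ramps gives bilinear weights where four chunks meet at a corner. NumPy arrays stand in for tensors, the function names are hypothetical, and dlutils' exact weighting may differ.

```python
import numpy as np


def ramp_weights(size, overlap, rise, fall):
    """Hypothetical 1-D window: ramps up over the first `overlap` samples if
    `rise`, ramps down over the last `overlap` samples if `fall`, else 1."""
    w = np.ones(size, dtype=np.float64)
    if overlap > 0:
        ramp = np.arange(1, overlap + 1) / (overlap + 1)  # values strictly in (0, 1)
        if rise:
            w[:overlap] = ramp
        if fall:
            w[-overlap:] = ramp[::-1]
    return w


def window_2d(h, w, overlap, top, bottom, left, right):
    """Outer product of two linear ramps: bilinear weights at the corners."""
    return np.outer(ramp_weights(h, overlap, top, bottom),
                    ramp_weights(w, overlap, left, right))


def blend_1d(chunks, starts, length, overlap):
    """Weighted average of 1-D chunk outputs placed at offsets `starts`."""
    acc = np.zeros(length)
    wsum = np.zeros(length)
    for i, (y, s) in enumerate(zip(chunks, starts)):
        w = ramp_weights(len(y), overlap, rise=i > 0, fall=i < len(chunks) - 1)
        acc[s:s + len(y)] += w * y
        wsum[s:s + len(y)] += w
    return acc / wsum


# Two chunks of length 6 overlapping by 2: a chunk of zeros, then a chunk of ones.
# The overlap cross-fades linearly: 0, 0, 0, 0, 1/3, 2/3, 1, 1, 1, 1
blended = blend_1d([np.zeros(6), np.ones(6)], [0, 4], 10, 2)
```

With this choice of ramp, the falling and rising weights in an overlap of exactly `overlap` samples sum to 1, and dividing by the accumulated weight keeps the result correct where the last chunk overlaps more.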

Note

The tensor data is expected to have a floating-point dtype (either float or double).
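Blending multiplies chunk outputs by fractional weights, which is presumably why a floating-point dtype is expected: an integer tensor could not hold the blended values. The sketch below uses NumPy in place of torch, and as_float_input is a hypothetical helper. With torch tensors, the corresponding check and conversion are x.is_floating_point() and x.float().

```python
import numpy as np


def as_float_input(a):
    """Hypothetical helper: return `a` unchanged if it is float32/float64,
    otherwise convert it to float32."""
    if a.dtype in (np.float32, np.float64):
        return a
    return a.astype(np.float32)


# Why it matters: blending two chunk values with weights 1/3 and 2/3
# produces a fraction that an integer dtype cannot store.
weights = np.array([1 / 3, 2 / 3])
values = np.array([0, 1], dtype=np.int64)
blended = (weights * values).sum()  # float64 scalar, 2/3
truncated = np.int64(blended)       # 0: the fractional part is lost
```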

Parameters
  • data (torch.Tensor) – Input data to be processed by func.

  • func (Callable) – Function applied to each chunk of data

  • block_size (int) – Size of each chunk

  • overlap (int) – How much adjacent chunks overlap

  • intermediate_as_double (bool) – Use double precision for the intermediate representation; improves accuracy
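The benefit of intermediate_as_double can be illustrated in general terms: rounding error builds up when many contributions are accumulated in single precision. The sketch below (NumPy, with a hypothetical accumulate helper) sums the same values one at a time in float32 and in float64. It shows the effect of accumulator precision in general, not how dlutils uses the double intermediate internally.

```python
import numpy as np


def accumulate(values, dtype):
    """Hypothetical helper: sum `values` sequentially, keeping the running total in `dtype`."""
    # np.cumsum adds element by element, so every partial sum is rounded to `dtype`.
    return np.cumsum(np.asarray(values, dtype=dtype), dtype=dtype)[-1]


contributions = [0.1] * 100_000  # exact total: 10000.0
err32 = abs(float(accumulate(contributions, np.float32)) - 10000.0)
err64 = abs(float(accumulate(contributions, np.float64)) - 10000.0)
```

The float64 error stays many orders of magnitude below the float32 one, at the cost of twice the memory for the intermediate buffers.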

Returns

Result of applying func to data chunk by chunk, with overlapping regions blended

Return type

torch.Tensor or list[torch.Tensor]

Example

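import torch
import dlutils

def f(x):
    # f receives one chunk at a time; check it is no larger than 64x64 in HW
    assert x.shape[2] <= 64
    assert x.shape[3] <= 64
    return x * x + x * x

x = torch.ones(3, 3, 512, 512, dtype=torch.float32)
# Split HW into chunks with block_size=32 and overlap=8, then apply f to each
r = dlutils.block_process_2d(x, f, block_size=32, overlap=8)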
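For a pointwise function such as f above, chunked processing with blending should reproduce applying the function to the whole tensor. The NumPy sketch below is a simplified, hypothetical re-implementation (block_process_2d_sketch is not the dlutils function). It assumes func preserves the chunk shape, that chunks of block_size advance by block_size - overlap, and that overlaps are combined by a linearly ramped weighted average. It can be used to check that intuition on the example above.

```python
import numpy as np


def _starts(length, block, overlap):
    # Chunk offsets along one axis; the last chunk is shifted to end at the border.
    if length <= block:
        return [0]
    stride = block - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than block_size")
    return list(range(0, length - block, stride)) + [length - block]


def _ramp(size, overlap, rise, fall):
    # 1-D linear ramps on the sides that touch a neighbouring chunk.
    w = np.ones(size)
    if overlap > 0:
        r = np.arange(1, overlap + 1) / (overlap + 1)
        if rise:
            w[:overlap] = r
        if fall:
            w[-overlap:] = r[::-1]
    return w


def block_process_2d_sketch(data, func, block_size=512, overlap=32):
    """Hypothetical re-implementation for NCHW arrays; func must keep the chunk shape."""
    h, w = data.shape[-2:]
    bh, bw = min(block_size, h), min(block_size, w)
    ys, xs = _starts(h, block_size, overlap), _starts(w, block_size, overlap)
    acc = np.zeros(data.shape, dtype=np.float64)  # weighted sum of chunk outputs
    wsum = np.zeros((h, w), dtype=np.float64)     # sum of weights per pixel
    for i, y in enumerate(ys):
        for j, x in enumerate(xs):
            # Bilinear window: outer product of the vertical and horizontal ramps.
            win = np.outer(_ramp(bh, overlap, i > 0, i < len(ys) - 1),
                           _ramp(bw, overlap, j > 0, j < len(xs) - 1))
            out = func(data[:, :, y:y + bh, x:x + bw])
            acc[:, :, y:y + bh, x:x + bw] += win * out
            wsum[y:y + bh, x:x + bw] += win
    return (acc / wsum).astype(data.dtype)


def f(x):
    assert x.shape[2] <= 64
    assert x.shape[3] <= 64
    return x * x + x * x


x = np.ones((3, 3, 512, 512), dtype=np.float32)
r = block_process_2d_sketch(x, f, block_size=32, overlap=8)
# f itself cannot run on the full 512x512 tensor (its assertions would fail),
# so compare against the same formula evaluated directly.
assert np.allclose(r, x * x + x * x)
```

Because each pixel's weights are normalized by their sum, a pointwise func gives the same values whether or not a pixel lies in an overlap; for functions with spatial context (convolutions, filters), the overlap is what hides seams at chunk borders.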