Processing a large tensor by chunking
dlutils.block_process_2d(data, func, block_size=512, overlap=32, intermediate_as_double=False)

Applies the function func to the given 4-dimensional tensor by chunking it in 2 dimensions and processing each chunk separately.
Input data is split into chunks across its last two dimensions: if the input is laid out as NCHW, chunking occurs in the HW (spatial) plane. Both the chunk size (block_size) and the overlap between neighbouring chunks (overlap) can be adjusted; overlap controls how far adjacent chunks extend into each other.
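One way to picture the split along a single axis is sketched below. This is a minimal sketch under an assumed tiling scheme (chunks start block_size - overlap elements apart, and the last chunk is clipped at the edge); the helpers chunk_ranges and tiles_2d are hypothetical and not part of dlutils, whose actual tiling may differ. Tiles in the HW plane are then the Cartesian product of the ranges along H and along W.

```python
def chunk_ranges(length, block_size, overlap):
    """Return (start, stop) pairs that cover range(length).

    Hypothetical tiling (not the dlutils implementation): consecutive
    chunks start block_size - overlap elements apart, so neighbours
    share `overlap` elements; the last chunk is clipped at the edge.
    """
    if not 0 <= overlap < block_size:
        raise ValueError("need 0 <= overlap < block_size")
    step = block_size - overlap
    ranges = []
    start = 0
    while True:
        stop = min(start + block_size, length)
        ranges.append((start, stop))
        if stop >= length:
            break
        start += step
    return ranges


def tiles_2d(height, width, block_size, overlap):
    """Enumerate (h_range, w_range) tiles over the HW plane of an NCHW tensor."""
    return [
        (hr, wr)
        for hr in chunk_ranges(height, block_size, overlap)
        for wr in chunk_ranges(width, block_size, overlap)
    ]


print(chunk_ranges(10, block_size=4, overlap=1))
# [(0, 4), (3, 7), (6, 10)]
print(len(tiles_2d(512, 512, block_size=32, overlap=8)))
# 441 (21 chunks along H times 21 along W under this assumed scheme)
```

Under this scheme a larger overlap means a smaller stride and therefore more chunks, which is the usual cost of smoother seams.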
In the overlapping regions, the results of adjacent chunks are combined by linear/bilinear interpolation.
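The interpolation in overlapping regions can be illustrated in one dimension. The sketch below assumes weights that ramp linearly across the overlap and sum to one at every position; this is an assumption for illustration, not taken from the dlutils source, and ramp, blend_1d and bilinear_weights are hypothetical names. Because the two weights always sum to one, identical values in both chunks pass through unchanged, and a 2D (bilinear) weight map can be formed as the outer product of two 1D profiles.

```python
def ramp(overlap):
    """Weights rising linearly towards 1 across `overlap` positions (ends excluded)."""
    return [(i + 1) / (overlap + 1) for i in range(overlap)]


def blend_1d(left, right, overlap):
    """Join two chunk results whose last/first `overlap` elements cover the same positions.

    In the shared region the left chunk's weight falls while the right
    chunk's weight rises; the two weights sum to 1 at every position.
    """
    if overlap == 0:
        # No shared region: simply concatenate (left[:-0] would be empty).
        return list(left) + list(right)
    w = ramp(overlap)
    mixed = [
        (1 - wi) * a + wi * b
        for wi, a, b in zip(w, left[-overlap:], right[:overlap])
    ]
    return list(left[:-overlap]) + mixed + list(right[overlap:])


def bilinear_weights(wy, wx):
    """2D weight map as the outer product of 1D profiles along H (wy) and W (wx)."""
    return [[a * b for b in wx] for a in wy]


print([round(v, 4) for v in blend_1d([1.0] * 4, [3.0] * 4, overlap=2)])
# [1.0, 1.0, 1.6667, 2.3333, 3.0, 3.0]
```

Since a product of weights that each sum to one along their axis also sums to one, the bilinear map keeps the same pass-through property where four chunks meet at a corner.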
Note
The tensor data is expected to have a floating-point dtype (either float or double).
Parameters
Returns
    The result of applying func to the input data.
Return type
    torch.Tensor or list[torch.Tensor]
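Example

```python
import torch
import dlutils

def f(x):
    # Sanity checks: f should only ever receive small chunks,
    # never the full 512x512 HW plane.
    assert x.shape[2] <= 64
    assert x.shape[3] <= 64
    return x * x + x * x

x = torch.ones(3, 3, 512, 512, dtype=torch.float32)
r = dlutils.block_process_2d(x, f, block_size=32, overlap=8)
```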