Files
2025-04-02 21:44:17 -07:00

31 lines
861 B
Python

import numpy as np
def prepare_shape_strides_dtype(shape, strides, dtype, order):
    """Normalize the (shape, strides, dtype) triple used to describe an array.

    Parameters
    ----------
    shape : int or sequence of int
        A scalar integer (Python ``int`` or NumPy integer scalar) is wrapped
        into a 1-tuple; any other value is returned as given.
    strides : int, sequence of int, or falsy
        A scalar integer is wrapped into a 1-tuple. A falsy value (e.g.
        ``None`` or an empty sequence) is replaced by contiguous strides
        computed from ``shape``, ``dtype`` and ``order``.
    dtype : anything accepted by ``np.dtype``
        Converted with ``np.dtype``.
    order : str
        ``'C'`` or ``'F'``; only consulted when strides must be computed.

    Returns
    -------
    tuple
        ``(shape, strides, dtype)`` with ``dtype`` a ``np.dtype`` instance.

    Raises
    ------
    ValueError
        From ``_fill_stride_by_order`` when strides must be computed for a
        non-0-d shape and ``order`` is not ``'C'``/``'F'``.
    """
    dtype = np.dtype(dtype)
    # NumPy integer scalars (e.g. an element of ``arr.shape`` after arithmetic,
    # or ``np.int64(5)``) are not ``int`` subclasses, so a plain
    # ``isinstance(x, int)`` check would skip wrapping them and break later.
    if isinstance(shape, (int, np.integer)):
        shape = (int(shape),)
    if isinstance(strides, (int, np.integer)):
        strides = (int(strides),)
    elif not strides:
        # Same truthiness rule as ``strides or ...``: None/empty -> derive.
        strides = _fill_stride_by_order(shape, dtype, order)
    return shape, strides, dtype
def _fill_stride_by_order(shape, dtype, order):
nd = len(shape)
if nd == 0:
return ()
strides = [0] * nd
if order == 'C':
strides[-1] = dtype.itemsize
for d in reversed(range(nd - 1)):
strides[d] = strides[d + 1] * shape[d + 1]
elif order == 'F':
strides[0] = dtype.itemsize
for d in range(1, nd):
strides[d] = strides[d - 1] * shape[d - 1]
else:
raise ValueError('must be either C/F order')
return tuple(strides)