"""Unit tests for mpi4py ``MPI.Datatype`` queries, constructors and naming."""

import unittest
from mpi4py import MPI

# Predefined basic datatypes exercised by most tests below.
datatypes = (MPI.CHAR, MPI.SHORT, MPI.INT, MPI.LONG, MPI.FLOAT, MPI.DOUBLE)


class TestDatatype(unittest.TestCase):
    """Exercise extent/size queries and derived-datatype constructors."""

    def testGetExtent(self):
        for dtype in datatypes:
            lb, ext = dtype.Get_extent()
            # A basic datatype's type map is {(type, 0)}: lb is 0.
            self.assertEqual(lb, 0)
            self.assertTrue(ext > 0)

    def testGetSize(self):
        for dtype in datatypes:
            size = dtype.Get_size()
            self.assertTrue(size > 0)

    def testGetTrueExtent(self):
        for dtype in datatypes:
            try:
                lb, ext = dtype.Get_true_extent()
            except NotImplementedError:
                # Not available in this MPI build; nothing more to check.
                return
            self.assertEqual(lb, 0)
            self.assertTrue(ext > 0)

    def _create(self, factory, *args):
        """Build a datatype via ``factory(*args)``, commit it, then free it.

        Returns silently if the constructor raises NotImplementedError
        (i.e. it is unsupported by the underlying MPI build).
        """
        try:
            newtype = factory(*args)
        except NotImplementedError:
            return
        try:
            newtype.Commit()
        finally:
            # Release the derived type even if Commit() fails.
            newtype.Free()

    def _displacements(self, dtype, blocklengths):
        """Return one displacement per entry of *blocklengths*.

        The first displacement is 0; each following one is the previous
        displacement plus ``b * dtype.extent + 1``, where ``b`` is the
        length of the preceding block.
        """
        displacements = [0]
        for b in blocklengths[:-1]:
            displacements.append(displacements[-1] + b * dtype.extent + 1)
        return displacements

    def testDup(self):
        nulldtype = MPI.DATATYPE_NULL
        self.assertRaises(MPI.Exception, nulldtype.Dup)
        for dtype in datatypes:
            self._create(dtype.Dup)

    def testCreateContiguous(self):
        for dtype in datatypes:
            factory = dtype.Create_contiguous
            for count in range(5):
                self._create(factory, count)

    def testCreateVector(self):
        for dtype in datatypes:
            factory = dtype.Create_vector
            for count in range(5):
                for blocklength in range(5):
                    for stride in range(5):
                        self._create(factory, count, blocklength, stride)

    def testCreateHvector(self):
        for dtype in datatypes:
            factory = dtype.Create_hvector
            for count in range(5):
                for blocklength in range(5):
                    for stride in range(5):
                        self._create(factory, count, blocklength, stride)

    def testCreateIndexed(self):
        for dtype in datatypes:
            factory = dtype.Create_indexed
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = self._displacements(dtype, blocklengths)
                self._create(factory, blocklengths, displacements)
                # Also try a scalar block length (presumably applied to
                # every block -- TODO confirm against mpi4py's API).
                self._create(factory, block, displacements)

    def testCreateIndexedBlock(self):
        for dtype in datatypes:
            factory = dtype.Create_indexed_block
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = self._displacements(dtype, blocklengths)
                self._create(factory, block, displacements)

    def testCreateHindexed(self):
        for dtype in datatypes:
            factory = dtype.Create_hindexed
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = self._displacements(dtype, blocklengths)
                self._create(factory, blocklengths, displacements)
                # Also try a scalar block length (see testCreateIndexed).
                self._create(factory, block, displacements)

    def testCreateStruct(self):
        factory = MPI.Datatype.Create_struct
        for dtype1 in datatypes:
            for dtype2 in datatypes:
                for dtype3 in datatypes:
                    dtypes = (dtype1, dtype2, dtype3)
                    blocklengths = list(range(1, len(dtypes) + 1))
                    # Each displacement advances by one extent of the
                    # preceding member type.
                    displacements = [0]
                    for dtype in dtypes[:-1]:
                        displacements.append(displacements[-1] + dtype.extent)
                    self._create(factory, blocklengths, displacements, dtypes)

    def testCreateSubarray(self):
        orders = [MPI.ORDER_C, MPI.ORDER_FORTRAN, 'c', 'f', 'C', 'F']
        for dtype in datatypes:
            factory = dtype.Create_subarray
            for ndim in range(1, 5):
                for size in range(1, 5):
                    for subsize in range(1, size):
                        for start in range(size - subsize):
                            sizes = [size] * ndim
                            subsizes = [subsize] * ndim
                            starts = [start] * ndim
                            for order in orders:
                                self._create(factory, sizes, subsizes,
                                             starts, order)

    def testResized(self):
        for dtype in datatypes:
            factory = dtype.Resized
            for lb in range(-10, 10):
                for extent in range(1, 10):
                    self._create(factory, lb, extent)

    def testGetSetName(self):
        for dtype in datatypes:
            try:
                name = dtype.Get_name()
                self.assertTrue(name)
                dtype.Set_name(name)
            except NotImplementedError:
                return

    def testCommit(self):
        nulltype = MPI.DATATYPE_NULL
        self.assertRaises(MPI.Exception, nulltype.Commit)
        for dtype in datatypes:
            dtype.Commit()

    def testFree(self):
        """Freeing DATATYPE_NULL or any predefined datatype must raise."""
        nulltype = MPI.DATATYPE_NULL
        self.assertRaises(MPI.Exception, nulltype.Free)
        for dtype in (MPI.CHAR, MPI.WCHAR,
                      MPI.SIGNED_CHAR, MPI.UNSIGNED_CHAR,
                      MPI.SHORT, MPI.UNSIGNED_SHORT,
                      MPI.INT, MPI.UNSIGNED, MPI.UNSIGNED_INT,
                      MPI.LONG, MPI.UNSIGNED_LONG,
                      MPI.LONG_LONG, MPI.UNSIGNED_LONG_LONG,
                      MPI.FLOAT, MPI.DOUBLE, MPI.LONG_DOUBLE,
                      MPI.BYTE, MPI.PACKED,
                      MPI.SHORT_INT, MPI.TWOINT, MPI.INT_INT, MPI.LONG_INT,
                      MPI.LONG_LONG_INT, MPI.FLOAT_INT, MPI.DOUBLE_INT,
                      MPI.LONG_DOUBLE_INT,
                      MPI.UB, MPI.LB,):
            self.assertTrue(dtype)
            self.assertRaises(MPI.Exception, dtype.Free)
            # The handle must still be valid after the failed Free().
            self.assertTrue(dtype)


class TestGetAddress(unittest.TestCase):
    """Check MPI.Get_address against the buffer's own base address."""

    def testGetAddress(self):
        from array import array
        location = array('i', range(10))
        addr = MPI.Get_address(location)
        bufptr, _ = location.buffer_info()
        self.assertEqual(addr, bufptr)


_name, _version = MPI._mpi_info()
if _name == 'OpenMPI':
    # NOTE(review): presumably Open MPI does not raise when freeing a
    # predefined datatype, so testFree is disabled there -- TODO confirm.
    del TestDatatype.testFree # XXX for review in the future

if __name__ == '__main__':
    try:
        unittest.main()
    except SystemExit:
        # unittest.main() exits via SystemExit; it is swallowed here, which
        # also discards the failure exit status. Presumably deliberate for
        # the MPI launch environment -- confirm before relying on the code.
        pass