Sophie

Sophie

distrib > Mandriva > 2008.1 > i586 > by-pkgid > 2fe96174012fea2d88f752857a5bea1d > files > 83

python-mpi4py-0.6.0-4mdv2008.1.i586.rpm

import unittest
from mpi4py import MPI
from array import array

# Maps an array typecode (array.array ``typecode`` / numpy ``dtype.char``)
# to the MPI datatype used to describe items of that type.
typemap = {
    # 'b': MPI.SIGNED_CHAR,
    'h': MPI.SHORT,
    'i': MPI.INT,
    'l': MPI.LONG,
    'f': MPI.FLOAT,
    'd': MPI.DOUBLE,
}

# Registry of available array backends.  Each entry appended below is a
# (buffer factories, array constructor, equality predicate) triple.
arrayimpl = []

try:
    import array

    def mk_dtype_array(a, datatype):
        # Prefer the explicitly supplied datatype when it is truthy;
        # otherwise derive it from the array's typecode via typemap.
        if datatype:
            return datatype
        return typemap[a.typecode]

    # Buffer factories.  All share the signature (a, dt=None, s=1, c=None):
    #   a  -- the array.array instance
    #   dt -- datatype override (falls back to typemap when falsy)
    #   s  -- divisor applied to the item count
    #   c  -- item count override
    # NOTE: the count is written as ``c or <length>``, so ``c=0`` is treated
    # the same as ``c=None`` and falls back to the full length.

    def mk_buf_array_1(a, dt=None, s=1, c=None):
        # MPI.Buffer built from the array object itself.
        nitems = (c or len(a)) // s
        return MPI.Buffer(a, nitems, mk_dtype_array(a, dt))

    def mk_buf_array_2(a, dt=None, s=1, c=None):
        # MPI.Buffer built from the raw address reported by buffer_info().
        # NOTE(review): the False is presumably a read-only flag -- confirm
        # against MPI.Buffer.
        address, length = a.buffer_info()
        nitems = (c or length) // s
        return MPI.Buffer((address, False), nitems, mk_dtype_array(a, dt))

    def mk_buf_array_3(a, dt=None, s=1, c=None):
        # Plain tuple (array, count, datatype) instead of an MPI.Buffer.
        nitems = (c or len(a)) // s
        return (a, nitems, mk_dtype_array(a, dt))

    def mk_buf_array_4(a, dt=None, s=1, c=None):
        # Plain list [(address, False), count, datatype].
        address, length = a.buffer_info()
        nitems = (c or length) // s
        return [(address, False), nitems, mk_dtype_array(a, dt)]

    mk_buf_array = (
        mk_buf_array_1,
        mk_buf_array_2,
        mk_buf_array_3,
        mk_buf_array_4,
    )

    def mk_arr_array(typecode, init):
        # Construct an array.array of the given typecode from ``init``.
        return array.array(typecode, init)

    def eq_arr_array(a, b):
        # array.array instances compare element-wise with ``==``.
        return a == b

    arrayimpl.append((mk_buf_array, mk_arr_array, eq_arr_array))
except ImportError:
    pass

try:
    import numpy

    def mk_dtype_numpy(a, datatype):
        # Prefer the explicitly supplied datatype when it is truthy;
        # otherwise derive it from the array's dtype character via typemap.
        if datatype:
            return datatype
        return typemap[a.dtype.char]

    # Buffer factories.  All share the signature (a, dt=None, s=1, c=None):
    #   a  -- the numpy array
    #   dt -- datatype override (falls back to typemap when falsy)
    #   s  -- divisor applied to the item count
    #   c  -- item count override
    # NOTE: the count is written as ``c or a.size``, so ``c=0`` is treated
    # the same as ``c=None`` and falls back to the full size.
    # Variants 1-3 wrap the memory in an MPI.Buffer; variants 4-6 return the
    # equivalent plain tuple.  Within each group the memory is given as the
    # array itself, its ``.data`` attribute, or the ``'data'`` entry of
    # ``__array_interface__``.

    def mk_buf_numpy_1(a, dt=None, s=1, c=None):
        nitems = (c or a.size) // s
        return MPI.Buffer(a, nitems, mk_dtype_numpy(a, dt))

    def mk_buf_numpy_2(a, dt=None, s=1, c=None):
        memory = a.data
        nitems = (c or a.size) // s
        return MPI.Buffer(memory, nitems, mk_dtype_numpy(a, dt))

    def mk_buf_numpy_3(a, dt=None, s=1, c=None):
        memory = a.__array_interface__['data']
        nitems = (c or a.size) // s
        return MPI.Buffer(memory, nitems, mk_dtype_numpy(a, dt))

    def mk_buf_numpy_4(a, dt=None, s=1, c=None):
        nitems = (c or a.size) // s
        return (a, nitems, mk_dtype_numpy(a, dt))

    def mk_buf_numpy_5(a, dt=None, s=1, c=None):
        memory = a.data
        nitems = (c or a.size) // s
        return (memory, nitems, mk_dtype_numpy(a, dt))

    def mk_buf_numpy_6(a, dt=None, s=1, c=None):
        memory = a.__array_interface__['data']
        nitems = (c or a.size) // s
        return (memory, nitems, mk_dtype_numpy(a, dt))

    mk_buf_numpy = (
        mk_buf_numpy_1,
        mk_buf_numpy_2,
        mk_buf_numpy_3,
        mk_buf_numpy_4,
        mk_buf_numpy_5,
        mk_buf_numpy_6,
    )

    def mk_arr_numpy(typecode, init):
        # Construct a numpy array from ``init`` with dtype ``typecode``.
        return numpy.array(init, dtype=typecode)

    def eq_arr_numpy(a, b):
        # Tolerance-based comparison via numpy.allclose.
        return numpy.allclose(a, b)

    arrayimpl.append((mk_buf_numpy, mk_arr_numpy, eq_arr_numpy))
except ImportError:
    pass


def maxvalue(a):
    """Return a large value considered safe to store in array ``a``.

    The element type is read from ``a.typecode`` or, when that attribute is
    missing, from ``a.dtype.char``.  Typecode ``'f'`` yields ``1e30`` and
    ``'d'`` yields ``1e300``; every other typecode yields
    ``2 ** (a.itemsize * 7) - 1``.
    """
    try:
        code = a.typecode
    except AttributeError:
        code = a.dtype.char
    float_limits = {'f': 1e30, 'd': 1e300}
    if code in float_limits:
        return float_limits[code]
    # Integer types: 7 bits per byte of item size.
    bits = a.itemsize * 7
    return 2 ** bits - 1


class TestCollBufVecBase(object):
    """Mixin holding tests for the vector ("v") buffer collectives of COMM.

    Concrete test cases combine this class with ``unittest.TestCase`` and
    override ``COMM``.  Every test loops over each backend in ``arrayimpl``,
    each (typecode, datatype) pair in ``typemap``, every rank as ``root`` and
    every per-rank item count in ``range(size)``.

    The multi-row buffers used here hold ``size * size`` items, viewed as
    ``size`` rows of ``size`` items with row ``i`` at displacement
    ``i * size``; only the first ``count`` items of a row take part in the
    communication.

    NOTE: the test methods are documented with comments rather than
    docstrings so that unittest's per-test descriptions are left unchanged.
    NOTE: the buffer factories (first item of each ``arrayimpl`` entry) and
    the equality predicate (third item) are not used by these tests.
    """

    COMM = MPI.COMM_NULL

    def _check_rows(self, buf, size, count, root):
        # Verify a buffer laid out as `size` rows of `size` items: the first
        # `count` items of every row must equal `root`; the remaining items
        # must still hold the -1 fill value.
        for i in range(size):
            row = buf[i*size:(i+1)*size]
            for value in row[:count]:
                self.assertEqual(value, root)
            for value in row[count:]:
                self.assertEqual(value, -1)

    def testGatherv(self):
        # Every rank sends `count` copies of `root`; only the root rank
        # supplies a receive buffer and verifies the gathered rows.
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        for _mkbufs, mkarray, _equal in arrayimpl:
            for typecode, datatype in typemap.items():
                for root in range(size):
                    for count in range(size):
                        sbuf = mkarray(typecode, [root] * count)
                        rbuf = mkarray(typecode, [-1] * (size * size))
                        counts = [count] * size
                        # Materialise as a list so a real sequence (not a
                        # lazy range object) is handed over on Python 3.
                        displs = list(range(0, size * size, size))
                        recvbuf = [rbuf, (counts, displs), datatype]
                        if rank != root:
                            recvbuf = None
                        self.COMM.Barrier()
                        self.COMM.Gatherv([sbuf, count, datatype],
                                          recvbuf, root)
                        self.COMM.Barrier()
                        if recvbuf is not None:
                            self._check_rows(rbuf, size, count, root)

    def testScatterv(self):
        # The root rank scatters `count` items from each row of a buffer
        # filled with `root`; every rank checks that all `count` received
        # items equal `root`.
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        for _mkbufs, mkarray, _equal in arrayimpl:
            for typecode, datatype in typemap.items():
                for root in range(size):
                    for count in range(size):
                        sbuf = mkarray(typecode, [root] * (size * size))
                        rbuf = mkarray(typecode, [-1] * count)
                        counts = [count] * size
                        displs = list(range(0, size * size, size))
                        sendbuf = [sbuf, [counts, displs], datatype]
                        if rank != root:
                            sendbuf = None
                        self.COMM.Scatterv(sendbuf,
                                           [rbuf, count, datatype], root)
                        for value in rbuf:
                            self.assertEqual(value, root)

    def testAllgatherv(self):
        # Every rank contributes `count` copies of `root` (used here only as
        # the fill value) and verifies the gathered rows.
        size = self.COMM.Get_size()
        for _mkbufs, mkarray, _equal in arrayimpl:
            for typecode, datatype in typemap.items():
                for root in range(size):
                    for count in range(size):
                        sbuf = mkarray(typecode, [root] * count)
                        rbuf = mkarray(typecode, [-1] * (size * size))
                        counts = [count] * size
                        displs = list(range(0, size * size, size))
                        sendbuf = [sbuf, count, datatype]
                        recvbuf = (rbuf, counts, displs, datatype)
                        self.COMM.Allgatherv(sendbuf, recvbuf)
                        self._check_rows(rbuf, size, count, root)

    def testAlltoallv(self):
        # Every rank exchanges `count` items per row of a buffer filled with
        # `root` (used here only as the fill value) and verifies the
        # received rows.
        size = self.COMM.Get_size()
        for _mkbufs, mkarray, _equal in arrayimpl:
            for typecode, datatype in typemap.items():
                for root in range(size):
                    for count in range(size):
                        sarr = mkarray(typecode, [root] * (size * size))
                        rarr = mkarray(typecode, [-1] * (size * size))
                        counts = [count] * size
                        displs = list(range(0, size * size, size))
                        sendbuf = [sarr, counts, displs, datatype]
                        recvbuf = (rarr, counts, displs, datatype)
                        self.COMM.Alltoallv(sendbuf, recvbuf)
                        self._check_rows(rarr, size, count, root)


class TestCollBufVecSelf(TestCollBufVecBase, unittest.TestCase):
    """Run the tests inherited from TestCollBufVecBase on ``MPI.COMM_SELF``."""
    COMM = MPI.COMM_SELF

class TestCollBufVecWorld(TestCollBufVecBase, unittest.TestCase):
    """Run the tests inherited from TestCollBufVecBase on ``MPI.COMM_WORLD``."""
    COMM = MPI.COMM_WORLD

if __name__ == '__main__':
    # unittest.main() finishes by raising SystemExit; catching it lets the
    # script return normally instead of exiting with unittest's status code.
    # NOTE(review): presumably done so that a failing rank does not make the
    # MPI launcher abort the whole job -- confirm before relying on the exit
    # status of this script.
    try:
        unittest.main()
    except SystemExit:
        pass