Sophie

Sophie

distrib > Mandriva > 2008.1 > i586 > by-pkgid > 2fe96174012fea2d88f752857a5bea1d > files > 84

python-mpi4py-0.6.0-4mdv2008.1.i586.rpm

import sys, unittest
from mpi4py import MPI

class TestCommBase(object):

    def testProperties(self):
        comm = self.COMM
        self.assertEqual(comm.Get_size(), comm.size)
        self.assertEqual(comm.Get_rank(), comm.rank)
        self.assertEqual(comm.Is_intra(), comm.is_intra)
        self.assertEqual(comm.Is_inter(), comm.is_inter)
        self.assertEqual(comm.Get_topology(), comm.topology)

    def testGroup(self):
        comm  = self.COMM
        group = self.COMM.Get_group()
        self.assertEqual(comm.Get_size(), group.Get_size())
        self.assertEqual(comm.Get_rank(), group.Get_rank())
        group.Free()
        self.assertEqual(group, MPI.GROUP_NULL)

    def testContructor(self):
        comm = MPI.Comm(self.COMM)
        self.assertEqual(comm, self.COMM)
        self.assertNotEqual(comm, MPI.COMM_NULL)

    def testCloneFree(self):
        comm = self.COMM.Clone()
        comm.Free()
        self.assertEqual(comm, MPI.COMM_NULL)

    def testCompare(self):
        results = (MPI.IDENT, MPI.CONGRUENT, MPI.SIMILAR, MPI.UNEQUAL)
        ccmp = MPI.Comm.Compare(self.COMM, MPI.COMM_WORLD)
        self.assertTrue(ccmp in results)
        ccmp = MPI.Comm.Compare(self.COMM, self.COMM)
        self.assertEqual(ccmp, MPI.IDENT)
        comm = self.COMM.Dup()
        ccmp = MPI.Comm.Compare(self.COMM, comm)
        comm.Free()
        self.assertEqual(ccmp, MPI.CONGRUENT)
        cmp_null = lambda: MPI.Comm.Compare(self.COMM, MPI.COMM_NULL)
        self.assertRaises(MPI.Exception, cmp_null)

    def testBarrier(self):
        self.COMM.Barrier()

    def testBcast(self):
        messages = [ None, 7, 3.14, 1+2j, 'mpi4py',
                    [None, 7, 3.14, 1+2j, 'mpi4py'],
                    (None, 7, 3.14, 1+2j, 'mpi4py'),
                     dict(k1=None,k2=7,k3=3.14,k4=1+2j,k5='mpi4py')]
        for smess in messages:
            rmess = self.COMM.Bcast(smess)
            self.assertEqual(smess, rmess)
            for root in range(self.COMM.Get_size()):
                rmess = self.COMM.Bcast(smess, root)
                self.assertEqual(smess, rmess)

    def testGather(self):
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        for root in range(size):
            rmess = self.COMM.Gather(rank, None, root)
            if rank == root:
                self.assertEqual(rmess, list(range(size)))
            else:
                self.assertEqual(rmess, None)

    def testScatter(self):
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        for root in range(size):
            smess = [rank + root] * size
            rmess = self.COMM.Scatter(smess, None, root)
            self.assertEqual(rmess, root * 2)

    def testAllgather(self):
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        rmess = self.COMM.Allgather(rank, None)
        self.assertEqual(rmess, list(range(size)))

    def testAlltoall(self):
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        smess = [rank] * size
        rmess = self.COMM.Alltoall(smess, None)
        self.assertEqual(rmess, list(range(size)))

    def testGetSetName(self):
        try:
            name = self.COMM.Get_name()
            self.COMM.Set_name('comm')
            self.assertEqual(self.COMM.Get_name(), 'comm')
            self.COMM.Set_name(name)
            self.assertEqual(self.COMM.Get_name(), name)
        except NotImplementedError:
            pass

    def testGetParent(self):
        try:
            parent = MPI.Comm.Get_parent()
            if parent == MPI.COMM_NULL:
                self.assertRaises(MPI.Exception, parent.Disconnect)
        except NotImplementedError:
            pass


class TestCommNull(unittest.TestCase):

    def testNull(self):
        COMM_NULL = MPI.COMM_NULL
        comm_null = MPI.Comm(COMM_NULL)
        self.assertEqual(comm_null, COMM_NULL)
        cmpnull = lambda: MPI.Comm.Compare(COMM_NULL, COMM_NULL)
        self.assertRaises(MPI.Exception, cmpnull)
        for meth in ('Get_size', 'Get_rank', 'Is_inter', 'Is_intra',
                     'Get_group', 'Get_topology', 'Get_errhandler',
                     'Clone', 'Free', 'Abort'):
            bound_method = getattr(COMM_NULL, meth)
            self.assertRaises(MPI.Exception, bound_method)
            try:
                bound_method()
                raise AssertionError
            except MPI.Exception:
                exc = sys.exc_info()[1]
                self.assertEqual(exc.Get_error_class(), MPI.ERR_COMM)

    def testInterNull(self):
        comm_null = MPI.Intercomm(MPI.COMM_NULL)
        self.assertRaises(MPI.Exception, comm_null.Get_remote_group)
        self.assertRaises(MPI.Exception, comm_null.Get_remote_size)
        comm_null_Dup    = lambda : comm_null.Dup()
        comm_null_Create = lambda : comm_null.Create(group=MPI.GROUP_EMPTY)
        comm_null_Split  = lambda : comm_null.Split(color=0,key=0)
        comm_null_Merge  = lambda : comm_null.Merge(high=True)
        self.assertRaises(MPI.Exception, comm_null_Dup)
        self.assertRaises(MPI.Exception, comm_null_Create)
        self.assertRaises(MPI.Exception, comm_null_Split)
        self.assertRaises(MPI.Exception, comm_null_Merge)


class TestCommSelf(TestCommBase, unittest.TestCase):
    def setUp(self):
        self.COMM = MPI.COMM_SELF
    def testSize(self):
        size = self.COMM.Get_size()
        self.assertEqual(size, 1)
    def testRank(self):
        rank = self.COMM.Get_rank()
        self.assertEqual(rank, 0)
    def testFree(self):
        self.assertRaises(MPI.Exception, MPI.COMM_SELF.Free)
        self.assertRaises(MPI.Exception, MPI.__COMM_SELF__.Free)


class TestCommWorld(TestCommBase, unittest.TestCase):
    def setUp(self):
        self.COMM = MPI.COMM_WORLD
    def testSize(self):
        size = self.COMM.Get_size()
        self.assertTrue(size >= 1)
    def testRank(self):
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        self.assertTrue(rank >= 0 and rank < size)
    def testFree(self):
        self.assertRaises(MPI.Exception, MPI.COMM_WORLD.Free)
        self.assertRaises(MPI.Exception, MPI.__COMM_WORLD__.Free)


if __name__ == '__main__':
    try:
        unittest.main()
    except SystemExit:
        pass