<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
  "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">

<html xmlns="http://www.w3.org/1999/xhtml">
  <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
    
    <title>GMM classification &mdash; scikits.learn v0.6.0 documentation</title>
    <link rel="stylesheet" href="../../_static/nature.css" type="text/css" />
    <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
    <script type="text/javascript">
      var DOCUMENTATION_OPTIONS = {
        URL_ROOT:    '../../',
        VERSION:     '0.6.0',
        COLLAPSE_INDEX: false,
        FILE_SUFFIX: '.html',
        HAS_SOURCE:  true
      };
    </script>
    <script type="text/javascript" src="../../_static/jquery.js"></script>
    <script type="text/javascript" src="../../_static/underscore.js"></script>
    <script type="text/javascript" src="../../_static/doctools.js"></script>
    <link rel="shortcut icon" href="../../_static/favicon.ico"/>
    <link rel="author" title="About these documents" href="../../about.html" />
    <link rel="top" title="scikits.learn v0.6.0 documentation" href="../../index.html" />
    <link rel="up" title="Examples" href="../index.html" />
    <link rel="next" title="Density Estimation for a mixture of Gaussians" href="plot_gmm_pdf.html" />
    <link rel="prev" title="Gaussian Mixture Model Ellipsoids" href="plot_gmm.html" /> 
  </head>
  <body>
    <div class="header-wrapper">
      <div class="header">
          <p class="logo"><a href="../../index.html">
            <img src="../../_static/scikit-learn-logo-small.png" alt="Logo"/>
          </a>
          </p><div class="navbar">
          <ul>
            <li><a href="../../install.html">Download</a></li>
            <li><a href="../../support.html">Support</a></li>
            <li><a href="../../user_guide.html">User Guide</a></li>
            <li><a href="../index.html">Examples</a></li>
            <li><a href="../../developers/index.html">Development</a></li>
       </ul>


          </div> <!-- end navbar --></div>
    </div>

    <div class="content-wrapper">

    <!-- <div id="blue_tile"></div> -->

        <div class="sphinxsidebar">
        <div class="rel">
          <a href="plot_gmm.html" title="Gaussian Mixture Model Ellipsoids"
             accesskey="P">previous</a> |
          <a href="plot_gmm_pdf.html" title="Density Estimation for a mixture of Gaussians"
             accesskey="N">next</a> |
          <a href="../../genindex.html" title="General Index"
             accesskey="I">index</a>
        </div>
        

        <h3>Contents</h3>
         <ul>
<li><a class="reference internal" href="#">GMM classification</a></li>
</ul>


        

        </div>

      <div class="content">
            
      <div class="documentwrapper">
        <div class="bodywrapper">
          <div class="body">
            
  <div class="section" id="gmm-classification">
<span id="example-mixture-plot-gmm-classifier-py"></span><h1>GMM classification<a class="headerlink" href="#gmm-classification" title="Permalink to this headline">¶</a></h1>
<p>Demonstration of <em class="xref std std-ref">gmm</em> for classification.</p>
<p>Plots predicted labels on both training and held-out test data using a
variety of GMM classifiers on the iris dataset.</p>
<p>Compares GMMs with spherical, diagonal, full, and tied covariance
matrices, listed in increasing order of performance.  Although one would
expect full covariance to perform best in general, it is prone to
overfitting on small datasets and generalizes poorly to held-out
test data.</p>
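<p>One way to see why the full covariance model overfits more easily is to
count the free covariance parameters each structure has to estimate. The
sketch below is only an illustration: the helper
<tt>covariance_param_count</tt> is a hypothetical name, not part of
scikits.learn, and it assumes the usual conventions (one variance per
component for spherical, one variance per feature and component for
diagonal, a single shared matrix for tied, one matrix per component for
full). It is evaluated for three components and four features, as in the
iris example.</p>

```python
def covariance_param_count(cvtype, n_components, n_features):
    """Count free covariance parameters for one covariance structure.

    Hypothetical helper for illustration; not a scikits.learn function.
    """
    # Free entries of one symmetric n_features x n_features matrix.
    full_matrix = n_features * (n_features + 1) // 2
    counts = {
        'spherical': n_components,           # one variance per component
        'diag': n_components * n_features,   # one variance per feature and component
        'tied': full_matrix,                 # one full matrix shared by all components
        'full': n_components * full_matrix,  # one full matrix per component
    }
    return counts[cvtype]


# Iris: 3 classes (components) and 4 features.
iris_counts = dict((name, covariance_param_count(name, 3, 4))
                   for name in ['spherical', 'diag', 'tied', 'full'])
print(iris_counts)
```

<p>Under these assumptions the full model estimates 30 covariance
parameters against 10 for the tied model, all from roughly a hundred
training samples.</p>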
<p>On the plots, training data is shown as dots, while test data is shown as
crosses. The iris dataset is four-dimensional. Only the first two
dimensions are shown here, so some points that appear to overlap are in
fact separated along the other dimensions.</p>
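<p>The effect of showing only two of the four dimensions can be checked with
a small calculation. The two samples below are made up for illustration (they
are not taken from the iris data) and <tt>distance</tt> is a hypothetical
helper: the points coincide in the first two coordinates but are clearly
apart once all four are taken into account.</p>

```python
import math


def distance(a, b, dims=None):
    """Euclidean distance, optionally using only the first `dims` coordinates."""
    if dims is not None:
        a, b = a[:dims], b[:dims]
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Hypothetical 4-D samples that share their first two coordinates.
p = [5.0, 3.0, 1.5, 0.2]
q = [5.0, 3.0, 4.5, 1.5]

d2 = distance(p, q, dims=2)  # distance as it appears on the 2-D plot
d4 = distance(p, q)          # distance in the full 4-D space
print(d2, d4)
```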
<img alt="auto_examples/mixture/images/plot_gmm_classifier.png" class="align-center" src="auto_examples/mixture/images/plot_gmm_classifier.png" />
<p><strong>Python source code:</strong> <a class="reference download internal" href="../../_downloads/plot_gmm_classifier.py"><tt class="xref download docutils literal"><span class="pre">plot_gmm_classifier.py</span></tt></a></p>
<div class="highlight-python"><div class="highlight"><pre><span class="k">print</span> <span class="n">__doc__</span>

<span class="c"># Author: Ron Weiss &lt;ronweiss@gmail.com&gt;, Gael Varoquaux</span>
<span class="c"># License: BSD Style.</span>

<span class="c"># $Id$</span>

<span class="kn">import</span> <span class="nn">pylab</span> <span class="kn">as</span> <span class="nn">pl</span>
<span class="kn">import</span> <span class="nn">matplotlib</span> <span class="kn">as</span> <span class="nn">mpl</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>

<span class="kn">from</span> <span class="nn">scikits.learn</span> <span class="kn">import</span> <span class="n">datasets</span>
<span class="kn">from</span> <span class="nn">scikits.learn.cross_val</span> <span class="kn">import</span> <span class="n">StratifiedKFold</span>
<span class="kn">from</span> <span class="nn">scikits.learn.mixture</span> <span class="kn">import</span> <span class="n">GMM</span>

<span class="k">def</span> <span class="nf">make_ellipses</span><span class="p">(</span><span class="n">gmm</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="s">&#39;rgb&#39;</span><span class="p">):</span>
        <span class="n">v</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">gmm</span><span class="o">.</span><span class="n">covars</span><span class="p">[</span><span class="n">n</span><span class="p">][:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">])</span>
        <span class="n">u</span> <span class="o">=</span> <span class="n">w</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">w</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">angle</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arctan</span><span class="p">(</span><span class="n">u</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">/</span><span class="n">u</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">angle</span> <span class="o">=</span> <span class="mi">180</span> <span class="o">*</span> <span class="n">angle</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span> <span class="c"># convert to degrees</span>
        <span class="n">v</span> <span class="o">*=</span> <span class="mi">9</span>
        <span class="n">ell</span> <span class="o">=</span> <span class="n">mpl</span><span class="o">.</span><span class="n">patches</span><span class="o">.</span><span class="n">Ellipse</span><span class="p">(</span><span class="n">gmm</span><span class="o">.</span><span class="n">means</span><span class="p">[</span><span class="n">n</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">v</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">v</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">180</span> <span class="o">+</span> <span class="n">angle</span><span class="p">,</span>
                                  <span class="n">color</span><span class="o">=</span><span class="n">color</span><span class="p">)</span>
        <span class="n">ell</span><span class="o">.</span><span class="n">set_clip_box</span><span class="p">(</span><span class="n">ax</span><span class="o">.</span><span class="n">bbox</span><span class="p">)</span>
        <span class="n">ell</span><span class="o">.</span><span class="n">set_alpha</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span>
        <span class="n">ax</span><span class="o">.</span><span class="n">add_artist</span><span class="p">(</span><span class="n">ell</span><span class="p">)</span>

<span class="n">iris</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">load_iris</span><span class="p">()</span>

<span class="c"># Break up the dataset into non-overlapping training (75%) and testing</span>
<span class="c"># (25%) sets.</span>
<span class="n">skf</span> <span class="o">=</span> <span class="n">StratifiedKFold</span><span class="p">(</span><span class="n">iris</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<span class="c"># Only take the first fold.</span>
<span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">skf</span><span class="p">))</span>


<span class="n">X_train</span> <span class="o">=</span> <span class="n">iris</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span>
<span class="n">y_train</span> <span class="o">=</span> <span class="n">iris</span><span class="o">.</span><span class="n">target</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span>
<span class="n">X_test</span> <span class="o">=</span> <span class="n">iris</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">test_index</span><span class="p">]</span>
<span class="n">y_test</span> <span class="o">=</span> <span class="n">iris</span><span class="o">.</span><span class="n">target</span><span class="p">[</span><span class="n">test_index</span><span class="p">]</span>

<span class="n">n_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">y_train</span><span class="p">))</span>

<span class="c"># Try GMMs using different types of covariances.</span>
<span class="n">classifiers</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">GMM</span><span class="p">(</span><span class="n">n_states</span><span class="o">=</span><span class="n">n_classes</span><span class="p">,</span> <span class="n">cvtype</span><span class="o">=</span><span class="n">x</span><span class="p">))</span>
                    <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="s">&#39;spherical&#39;</span><span class="p">,</span> <span class="s">&#39;diag&#39;</span><span class="p">,</span> <span class="s">&#39;tied&#39;</span><span class="p">,</span> <span class="s">&#39;full&#39;</span><span class="p">])</span>

<span class="n">n_classifiers</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">classifiers</span><span class="p">)</span>

<span class="n">pl</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="o">*</span><span class="n">n_classifiers</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
<span class="n">pl</span><span class="o">.</span><span class="n">subplots_adjust</span><span class="p">(</span><span class="n">bottom</span><span class="o">=.</span><span class="mo">01</span><span class="p">,</span> <span class="n">top</span><span class="o">=</span><span class="mf">0.95</span><span class="p">,</span> <span class="n">hspace</span><span class="o">=.</span><span class="mi">15</span><span class="p">,</span> <span class="n">wspace</span><span class="o">=.</span><span class="mo">05</span><span class="p">,</span>
                   <span class="n">left</span><span class="o">=.</span><span class="mo">01</span><span class="p">,</span> <span class="n">right</span><span class="o">=.</span><span class="mi">99</span><span class="p">)</span>


<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">classifier</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">classifiers</span><span class="o">.</span><span class="n">iteritems</span><span class="p">()):</span>
    <span class="c"># Since we have class labels for the training data, we can</span>
    <span class="c"># initialize the GMM parameters in a supervised manner.</span>
    <span class="n">classifier</span><span class="o">.</span><span class="n">means</span> <span class="o">=</span> <span class="p">[</span><span class="n">X_train</span><span class="p">[</span><span class="n">y_train</span> <span class="o">==</span> <span class="n">i</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
                        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">xrange</span><span class="p">(</span><span class="n">n_classes</span><span class="p">)]</span>

    <span class="c"># Train the other parameters using the EM algorithm.</span>
    <span class="n">classifier</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">init_params</span><span class="o">=</span><span class="s">&#39;wc&#39;</span><span class="p">,</span> <span class="n">n_iter</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>

    <span class="n">h</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">n_classifiers</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">make_ellipses</span><span class="p">(</span><span class="n">classifier</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>

    <span class="c"># Plot every sample of the dataset as a small dot</span>
    <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="s">&#39;rgb&#39;</span><span class="p">):</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">iris</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">iris</span><span class="o">.</span><span class="n">target</span> <span class="o">==</span> <span class="n">n</span><span class="p">]</span>
        <span class="n">pl</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">data</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">data</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="mf">0.8</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">color</span><span class="p">,</span>
                    <span class="n">label</span><span class="o">=</span><span class="n">iris</span><span class="o">.</span><span class="n">target_names</span><span class="p">[</span><span class="n">n</span><span class="p">])</span>
    <span class="c"># Plot the test data with crosses</span>
    <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="s">&#39;rgb&#39;</span><span class="p">):</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">[</span><span class="n">y_test</span> <span class="o">==</span> <span class="n">n</span><span class="p">]</span>
        <span class="n">pl</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">data</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">&#39;x&#39;</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">color</span><span class="p">)</span>

    <span class="n">y_train_pred</span> <span class="o">=</span> <span class="n">classifier</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_train</span><span class="p">)</span>
    <span class="n">train_accuracy</span>  <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">y_train_pred</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span> <span class="o">==</span> <span class="n">y_train</span><span class="o">.</span><span class="n">ravel</span><span class="p">())</span> <span class="o">*</span> <span class="mi">100</span>
    <span class="n">pl</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">,</span> <span class="s">&#39;Train accuracy: </span><span class="si">%.1f</span><span class="s">&#39;</span> <span class="o">%</span> <span class="n">train_accuracy</span><span class="p">,</span>
                    <span class="n">transform</span><span class="o">=</span><span class="n">h</span><span class="o">.</span><span class="n">transAxes</span><span class="p">)</span>

    <span class="n">y_test_pred</span> <span class="o">=</span> <span class="n">classifier</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
    <span class="n">test_accuracy</span>  <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">y_test_pred</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span> <span class="o">==</span> <span class="n">y_test</span><span class="o">.</span><span class="n">ravel</span><span class="p">())</span> <span class="o">*</span> <span class="mi">100</span>
    <span class="n">pl</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="s">&#39;Test accuracy: </span><span class="si">%.1f</span><span class="s">&#39;</span> <span class="o">%</span> <span class="n">test_accuracy</span><span class="p">,</span>
                    <span class="n">transform</span><span class="o">=</span><span class="n">h</span><span class="o">.</span><span class="n">transAxes</span><span class="p">)</span>

    <span class="n">pl</span><span class="o">.</span><span class="n">xticks</span><span class="p">(())</span>
    <span class="n">pl</span><span class="o">.</span><span class="n">yticks</span><span class="p">(())</span>
    <span class="n">pl</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>

<span class="n">pl</span><span class="o">.</span><span class="n">legend</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="s">&#39;lower right&#39;</span><span class="p">,</span> <span class="n">prop</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">12</span><span class="p">))</span>


<span class="n">pl</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
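<p>The supervised initialization and the accuracy figures printed on the
plots can be reproduced by hand. The following is a minimal pure-Python
sketch on made-up toy data; <tt>class_means</tt> and
<tt>accuracy_percent</tt> are hypothetical helper names, not scikits.learn
functions. It mirrors the two steps of the script: one mean vector per class
computed from the labelled training samples, and the percentage of
predictions that match the true labels.</p>

```python
def class_means(X, y, n_classes):
    """Return one mean vector per class, computed from labelled samples."""
    means = []
    for c in range(n_classes):
        rows = [x for x, label in zip(X, y) if label == c]
        # zip(*rows) iterates over the columns (features) of this class.
        means.append([sum(col) / float(len(rows)) for col in zip(*rows)])
    return means


def accuracy_percent(y_pred, y_true):
    """Percentage of predictions equal to the true labels."""
    hits = sum(1 for p, t in zip(y_pred, y_true) if p == t)
    return 100.0 * hits / len(y_true)


# Toy data, invented for this illustration.
X = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [12.0, 14.0]]
y = [0, 0, 1, 1]

means = class_means(X, y, 2)
acc = accuracy_percent([0, 1, 1, 1], y)
print(means, acc)
```

<p>In the script above the class means only seed the EM algorithm;
<tt>fit</tt> then refines them together with the weights and
covariances.</p>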
</div>


          </div>
        </div>
      </div>
        <div class="clearer"></div>
      </div>
    </div>

    <div class="footer">
        <p style="text-align: center">This documentation is for
        scikits.learn version 0.6.0</p>
        &copy; 2010, scikits.learn developers (BSD License).
      Created using <a href="http://sphinx.pocoo.org/">Sphinx</a> 1.0.5. Design by <a href="http://webylimonada.com">Web y Limonada</a>.
    </div>
  </body>
</html>