View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.distribution;
19  
20  import java.io.Serializable;
21  
22  import org.apache.commons.math.MathException;
23  import org.apache.commons.math.MathRuntimeException;
24  import org.apache.commons.math.MaxIterationsExceededException;
25  import org.apache.commons.math.special.Erf;
26  
27  /**
28   * Default implementation of
29   * {@link org.apache.commons.math.distribution.NormalDistribution}.
30   *
31   * @version $Revision: 772119 $ $Date: 2009-05-06 05:43:28 -0400 (Wed, 06 May 2009) $
32   */
33  public class NormalDistributionImpl extends AbstractContinuousDistribution 
34          implements NormalDistribution, Serializable {
35      
36      /** Serializable version identifier */
37      private static final long serialVersionUID = 8589540077390120676L;
38  
39      /** &sqrt;(2 π) */
40      private static final double SQRT2PI = Math.sqrt(2 * Math.PI);
41  
42      /** The mean of this distribution. */
43      private double mean = 0;
44      
45      /** The standard deviation of this distribution. */
46      private double standardDeviation = 1;
47  
48      /**
49       * Create a normal distribution using the given mean and standard deviation.
50       * @param mean mean for this distribution
51       * @param sd standard deviation for this distribution
52       */
53      public NormalDistributionImpl(double mean, double sd){
54          super();
55          setMean(mean);
56          setStandardDeviation(sd);
57      }
58      
59      /**
60       * Creates normal distribution with the mean equal to zero and standard
61       * deviation equal to one. 
62       */
63      public NormalDistributionImpl(){
64          this(0.0, 1.0);
65      }
66      
67      /**
68       * Access the mean.
69       * @return mean for this distribution
70       */ 
71      public double getMean() {
72          return mean;
73      }
74      
75      /**
76       * Modify the mean.
77       * @param mean for this distribution
78       */
79      public void setMean(double mean) {
80          this.mean = mean;
81      }
82  
83      /**
84       * Access the standard deviation.
85       * @return standard deviation for this distribution
86       */
87      public double getStandardDeviation() {
88          return standardDeviation;
89      }
90  
91      /**
92       * Modify the standard deviation.
93       * @param sd standard deviation for this distribution
94       * @throws IllegalArgumentException if <code>sd</code> is not positive.
95       */
96      public void setStandardDeviation(double sd) {
97          if (sd <= 0.0) {
98              throw MathRuntimeException.createIllegalArgumentException(
99                    "standard deviation must be positive ({0})",
100                   sd);
101         }       
102         standardDeviation = sd;
103     }
104 
105     /**
106      * Return the probability density for a particular point.
107      *
108      * @param x The point at which the density should be computed.
109      * @return The pdf at point x.
110      */
111     public double density(Double x) {
112         double x0 = x - getMean();
113         return Math.exp(-x0 * x0 / (2 * getStandardDeviation() * getStandardDeviation())) / (getStandardDeviation() * SQRT2PI);
114     }
115 
116     /**
117      * For this distribution, X, this method returns P(X &lt; <code>x</code>).
118      * @param x the value at which the CDF is evaluated.
119      * @return CDF evaluted at <code>x</code>. 
120      * @throws MathException if the algorithm fails to converge; unless
121      * x is more than 20 standard deviations from the mean, in which case the
122      * convergence exception is caught and 0 or 1 is returned.
123      */
124     public double cumulativeProbability(double x) throws MathException {
125         try {
126             return 0.5 * (1.0 + Erf.erf((x - mean) /
127                     (standardDeviation * Math.sqrt(2.0))));
128         } catch (MaxIterationsExceededException ex) {
129             if (x < (mean - 20 * standardDeviation)) { // JDK 1.5 blows at 38
130                 return 0.0d;
131             } else if (x > (mean + 20 * standardDeviation)) {
132                 return 1.0d;
133             } else {
134                 throw ex;
135             }
136         }
137     }
138     
139     /**
140      * For this distribution, X, this method returns the critical point x, such
141      * that P(X &lt; x) = <code>p</code>.
142      * <p>
143      * Returns <code>Double.NEGATIVE_INFINITY</code> for p=0 and 
144      * <code>Double.POSITIVE_INFINITY</code> for p=1.</p>
145      *
146      * @param p the desired probability
147      * @return x, such that P(X &lt; x) = <code>p</code>
148      * @throws MathException if the inverse cumulative probability can not be
149      *         computed due to convergence or other numerical errors.
150      * @throws IllegalArgumentException if <code>p</code> is not a valid
151      *         probability.
152      */
153     @Override
154     public double inverseCumulativeProbability(final double p) 
155     throws MathException {
156         if (p == 0) {
157             return Double.NEGATIVE_INFINITY;
158         }
159         if (p == 1) {
160             return Double.POSITIVE_INFINITY;
161         }
162         return super.inverseCumulativeProbability(p);
163     }
164     
165     /**
166      * Access the domain value lower bound, based on <code>p</code>, used to
167      * bracket a CDF root.  This method is used by
168      * {@link #inverseCumulativeProbability(double)} to find critical values.
169      * 
170      * @param p the desired probability for the critical value
171      * @return domain value lower bound, i.e.
172      *         P(X &lt; <i>lower bound</i>) &lt; <code>p</code> 
173      */
174     @Override
175     protected double getDomainLowerBound(double p) {
176         double ret;
177 
178         if (p < .5) {
179             ret = -Double.MAX_VALUE;
180         } else {
181             ret = getMean();
182         }
183         
184         return ret;
185     }
186 
187     /**
188      * Access the domain value upper bound, based on <code>p</code>, used to
189      * bracket a CDF root.  This method is used by
190      * {@link #inverseCumulativeProbability(double)} to find critical values.
191      * 
192      * @param p the desired probability for the critical value
193      * @return domain value upper bound, i.e.
194      *         P(X &lt; <i>upper bound</i>) &gt; <code>p</code> 
195      */
196     @Override
197     protected double getDomainUpperBound(double p) {
198         double ret;
199 
200         if (p < .5) {
201             ret = getMean();
202         } else {
203             ret = Double.MAX_VALUE;
204         }
205         
206         return ret;
207     }
208 
209     /**
210      * Access the initial domain value, based on <code>p</code>, used to
211      * bracket a CDF root.  This method is used by
212      * {@link #inverseCumulativeProbability(double)} to find critical values.
213      * 
214      * @param p the desired probability for the critical value
215      * @return initial domain value
216      */
217     @Override
218     protected double getInitialDomain(double p) {
219         double ret;
220 
221         if (p < .5) {
222             ret = getMean() - getStandardDeviation();
223         } else if (p > .5) {
224             ret = getMean() + getStandardDeviation();
225         } else {
226             ret = getMean();
227         }
228         
229         return ret;
230     }
231 }