View Javadoc
1   /*
2    * Copyright 2019-2022 Foreseeti AB <https://foreseeti.com>
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *     https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.mal_lang.lib;
18  
19  import java.util.List;
20  
21  public class Distributions {
22  
23    public static void validate(String name, List<Double> params) throws CompilerException {
24      switch (name) {
25        case "Bernoulli":
26          Bernoulli.validate(params);
27          break;
28        case "Binomial":
29          Binomial.validate(params);
30          break;
31        case "Exponential":
32          Exponential.validate(params);
33          break;
34        case "Gamma":
35          Gamma.validate(params);
36          break;
37        case "LogNormal":
38          LogNormal.validate(params);
39          break;
40        case "Pareto":
41          Pareto.validate(params);
42          break;
43        case "TruncatedNormal":
44          TruncatedNormal.validate(params);
45          break;
46        case "Uniform":
47          Uniform.validate(params);
48          break;
49        case "EasyAndCertain":
50        case "EasyAndUncertain":
51        case "HardAndCertain":
52        case "HardAndUncertain":
53        case "VeryHardAndCertain":
54        case "VeryHardAndUncertain":
55        case "Infinity":
56        case "Zero":
57        case "Enabled":
58        case "Disabled":
59          Combination.validate(params);
60          break;
61        default:
62          throw new CompilerException(String.format("Distribution '%s' is not supported", name));
63      }
64    }
65  
66    public static Distribution getDistribution(String name, List<Double> params) {
67      switch (name) {
68        case "Bernoulli":
69          return new Bernoulli(params);
70        case "Binomial":
71          return new Binomial(params);
72        case "Exponential":
73          return new Exponential(params);
74        case "Gamma":
75          return new Gamma(params);
76        case "LogNormal":
77          return new LogNormal(params);
78        case "Pareto":
79          return new Pareto(params);
80        case "TruncatedNormal":
81          return new TruncatedNormal(params);
82        case "Uniform":
83          return new Uniform(params);
84        case "EasyAndCertain":
85          return new EasyAndCertain();
86        case "EasyAndUncertain":
87          return new EasyAndUncertain();
88        case "HardAndCertain":
89          return new HardAndCertain();
90        case "HardAndUncertain":
91          return new HardAndUncertain();
92        case "VeryHardAndCertain":
93          return new VeryHardAndCertain();
94        case "VeryHardAndUncertain":
95          return new VeryHardAndUncertain();
96        case "Infinity":
97          return new Infinity();
98        case "Zero":
99          return new Zero();
100       case "Enabled":
101         return new Enabled();
102       case "Disabled":
103         return new Disabled();
104       default:
105         throw new RuntimeException(String.format("Distribution '%s' is not supported", name));
106     }
107   }
108 
109   public interface Distribution {
110     double getMean();
111   }
112 
113   public static class Bernoulli implements Distribution {
114     public final double probability;
115 
116     public Bernoulli(double probability) {
117       this.probability = probability;
118     }
119 
120     public Bernoulli(List<Double> params) {
121       this(params.get(0));
122     }
123 
124     public static void validate(List<Double> params) throws CompilerException {
125       if (params == null || params.size() != 1) {
126         throw new CompilerException(
127             "Expected exactly one parameter (probability), for Bernoulli distribution");
128       } else if (params.get(0) > 1) {
129         throw new CompilerException(
130             String.format(
131                 "%s is not in valid range '0 <= probability <= 1', for Bernoulli distribution",
132                 params.get(0)));
133       }
134     }
135 
136     @Override
137     public double getMean() {
138       return probability;
139     }
140 
141     @Override
142     public String toString() {
143       return String.format("Bernoulli(%f)", probability);
144     }
145   }
146 
147   public static class Binomial implements Distribution {
148     public final int trials;
149     public final double probability;
150 
151     public Binomial(int trials, double probability) {
152       this.trials = trials;
153       this.probability = probability;
154     }
155 
156     public Binomial(List<Double> params) {
157       this((int) Math.round(params.get(0)), params.get(1));
158     }
159 
160     public static void validate(List<Double> params) throws CompilerException {
161       if (params == null || params.size() != 2) {
162         throw new CompilerException(
163             "Expected exactly two parameters (trials, probability), for Binomial distribution");
164       } else if (params.get(1) > 1) {
165         throw new CompilerException(
166             String.format(
167                 "%s is not in valid range '0 <= probability <= 1', for Binomial distribution",
168                 params.get(1)));
169       }
170     }
171 
172     @Override
173     public double getMean() {
174       return trials * probability;
175     }
176 
177     @Override
178     public String toString() {
179       return String.format("Binomial(%d, %f)", trials, probability);
180     }
181   }
182 
183   public static class Exponential implements Distribution {
184     public final double lambda;
185 
186     public Exponential(double lambda) {
187       this.lambda = lambda;
188     }
189 
190     public Exponential(List<Double> params) {
191       this(params.get(0));
192     }
193 
194     public static void validate(List<Double> params) throws CompilerException {
195       if (params == null || params.size() != 1) {
196         throw new CompilerException(
197             "Expected exactly one parameter (lambda), for Exponential distribution");
198       } else if (params.get(0) <= 0) {
199         throw new CompilerException(
200             String.format(
201                 "%s is not in valid range 'lambda > 0', for Exponential distribution",
202                 params.get(0)));
203       }
204     }
205 
206     @Override
207     public double getMean() {
208       return 1 / lambda;
209     }
210 
211     @Override
212     public String toString() {
213       return String.format("Exponential(%f)", lambda);
214     }
215   }
216 
217   public static class Gamma implements Distribution {
218     public final double shape;
219     public final double scale;
220 
221     public Gamma(double shape, double scale) {
222       this.shape = shape;
223       this.scale = scale;
224     }
225 
226     public Gamma(List<Double> params) {
227       this(params.get(0), params.get(1));
228     }
229 
230     public static void validate(List<Double> params) throws CompilerException {
231       if (params == null || params.size() != 2) {
232         throw new CompilerException(
233             "Expected exactly two parameters (shape, scale), for Gamma distribution");
234       } else if (params.get(0) <= 0) {
235         throw new CompilerException(
236             String.format(
237                 "%s is not in valid range 'shape > 0', for Gamma distribution", params.get(0)));
238       } else if (params.get(1) <= 0) {
239         throw new CompilerException(
240             String.format(
241                 "%s is not in valid range 'scale > 0', for Gamma distribution", params.get(1)));
242       }
243     }
244 
245     @Override
246     public double getMean() {
247       return shape * scale;
248     }
249 
250     @Override
251     public String toString() {
252       return String.format("Gamma(%f, %f)", shape, scale);
253     }
254   }
255 
256   public static class LogNormal implements Distribution {
257     public final double mean;
258     public final double standardDeviation;
259 
260     public LogNormal(double mean, double standardDeviation) {
261       this.mean = mean;
262       this.standardDeviation = standardDeviation;
263     }
264 
265     public LogNormal(List<Double> params) {
266       this(params.get(0), params.get(1));
267     }
268 
269     public static void validate(List<Double> params) throws CompilerException {
270       if (params == null || params.size() != 2) {
271         throw new CompilerException(
272             "Expected exactly two parameters (mean, standardDeviation), for LogNormal"
273                 + " distribution");
274       } else if (params.get(1) <= 0) {
275         throw new CompilerException(
276             String.format(
277                 "%s is not in valid range 'standardDeviation > 0', for LogNormal distribution",
278                 params.get(1)));
279       }
280     }
281 
282     @Override
283     public double getMean() {
284       return Math.exp(mean + Math.pow(standardDeviation, 2) / 2);
285     }
286 
287     @Override
288     public String toString() {
289       return String.format("LogNormal(%f, %f)", mean, standardDeviation);
290     }
291   }
292 
293   public static class Pareto implements Distribution {
294     public final double min;
295     public final double shape;
296 
297     public Pareto(double min, double shape) {
298       this.min = min;
299       this.shape = shape;
300     }
301 
302     public Pareto(List<Double> params) {
303       this(params.get(0), params.get(1));
304     }
305 
306     public static void validate(List<Double> params) throws CompilerException {
307       if (params == null || params.size() != 2) {
308         throw new CompilerException(
309             "Expected exactly two parameters (min, shape), for Pareto distribution");
310       } else if (params.get(0) <= 0) {
311         throw new CompilerException(
312             String.format(
313                 "%s is not in valid range 'min > 0', for Pareto distribution", params.get(0)));
314       } else if (params.get(1) <= 0) {
315         throw new CompilerException(
316             String.format(
317                 "%s is not in valid range 'shape > 0', for Pareto distribution", params.get(1)));
318       }
319     }
320 
321     @Override
322     public double getMean() {
323       if (min <= 1) {
324         return Double.MAX_VALUE;
325       } else {
326         return (min * shape) / (min - 1);
327       }
328     }
329 
330     @Override
331     public String toString() {
332       return String.format("Pareto(%f, %f)", min, shape);
333     }
334   }
335 
336   public static class TruncatedNormal implements Distribution {
337     public final double mean;
338     public final double standardDeviation;
339 
340     public TruncatedNormal(double mean, double standardDeviation) {
341       this.mean = mean;
342       this.standardDeviation = standardDeviation;
343     }
344 
345     public TruncatedNormal(List<Double> params) {
346       this(params.get(0), params.get(1));
347     }
348 
349     public static void validate(List<Double> params) throws CompilerException {
350       if (params == null || params.size() != 2) {
351         throw new CompilerException(
352             "Expected exactly two parameters (mean, standardDeviation), for TruncatedNormal"
353                 + " distribution");
354       } else if (params.get(1) <= 0) {
355         throw new CompilerException(
356             String.format(
357                 "%s is not in valid range 'standardDeviation > 0', for TruncatedNormal"
358                     + " distribution",
359                 params.get(1)));
360       }
361     }
362 
363     @Override
364     public double getMean() {
365       return mean;
366     }
367 
368     @Override
369     public String toString() {
370       return String.format("TruncatedNormal(%f, %f)", mean, standardDeviation);
371     }
372   }
373 
374   public static class Uniform implements Distribution {
375     public final double min;
376     public final double max;
377 
378     public Uniform(double min, double max) {
379       this.min = min;
380       this.max = max;
381     }
382 
383     public Uniform(List<Double> params) {
384       this(params.get(0), params.get(1));
385     }
386 
387     public static void validate(List<Double> params) throws CompilerException {
388       if (params == null || params.size() != 2) {
389         throw new CompilerException(
390             "Expected exactly two parameters (min, max), for Uniform distribution");
391       } else if (params.get(0) > params.get(1)) {
392         throw new CompilerException(
393             String.format(
394                 "(%s, %s) does not meet requirement 'min <= max', for Uniform distribution",
395                 params.get(0), params.get(1)));
396       }
397     }
398 
399     @Override
400     public double getMean() {
401       return (min + max) / 2;
402     }
403 
404     @Override
405     public String toString() {
406       return String.format("Uniform(%f, %f)", min, max);
407     }
408   }
409 
410   /*Custom, combinations of above defined distributions*/
411 
412   private abstract static class Combination implements Distribution {
413     public static void validate(List<Double> params) throws CompilerException {
414       if (params != null && params.size() != 0) {
415         throw new CompilerException(
416             String.format("Expected exactly zero parameters, for combination distributions"));
417       }
418     }
419   }
420 
421   public static class EasyAndCertain extends Combination {
422     public static final Exponential exponential = new Exponential(1);
423 
424     @Override
425     public double getMean() {
426       return exponential.getMean();
427     }
428 
429     @Override
430     public String toString() {
431       return "EasyAndCertain";
432     }
433   }
434 
435   public static class EasyAndUncertain extends Combination {
436     public static final Bernoulli bernoulli = new Bernoulli(0.5);
437 
438     @Override
439     public double getMean() {
440       return bernoulli.getMean();
441     }
442 
443     @Override
444     public String toString() {
445       return "EasyAndUncertain";
446     }
447   }
448 
449   public static class HardAndCertain extends Combination {
450     public static final Exponential exponential = new Exponential(0.1);
451 
452     @Override
453     public double getMean() {
454       return exponential.getMean();
455     }
456 
457     @Override
458     public String toString() {
459       return "HardAndCertain";
460     }
461   }
462 
463   public static class HardAndUncertain extends Combination {
464     public static final Bernoulli bernoulli = new Bernoulli(0.5);
465     public static final Exponential exponential = new Exponential(0.1);
466 
467     @Override
468     public double getMean() {
469       return bernoulli.getMean() * exponential.getMean();
470     }
471 
472     @Override
473     public String toString() {
474       return "HardAndUncertain";
475     }
476   }
477 
478   public static class VeryHardAndCertain extends Combination {
479     public static final Exponential exponential = new Exponential(0.01);
480 
481     @Override
482     public double getMean() {
483       return exponential.getMean();
484     }
485 
486     @Override
487     public String toString() {
488       return "VeryHardAndCertain";
489     }
490   }
491 
492   public static class VeryHardAndUncertain extends Combination {
493     public static final Bernoulli bernoulli = new Bernoulli(0.5);
494     public static final Exponential exponential = new Exponential(0.01);
495 
496     @Override
497     public double getMean() {
498       return bernoulli.getMean() * exponential.getMean();
499     }
500 
501     @Override
502     public String toString() {
503       return "VeryHardAndUncertain";
504     }
505   }
506 
507   public static class Infinity extends Combination {
508     public Infinity() {}
509 
510     @Override
511     public double getMean() {
512       return Double.MAX_VALUE;
513     }
514 
515     @Override
516     public String toString() {
517       return "Infinity";
518     }
519   }
520 
521   public static class Zero extends Combination {
522     @Override
523     public double getMean() {
524       return 0;
525     }
526 
527     @Override
528     public String toString() {
529       return "Zero";
530     }
531   }
532 
533   public static class Enabled extends Combination {
534     @Override
535     public double getMean() {
536       return 1;
537     }
538 
539     @Override
540     public String toString() {
541       return "Enabled";
542     }
543   }
544 
545   public static class Disabled extends Combination {
546     @Override
547     public double getMean() {
548       return 0;
549     }
550 
551     @Override
552     public String toString() {
553       return "Disabled";
554     }
555   }
556 }