1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.mal_lang.lib;
18
19 import java.util.List;
20
21 public class Distributions {
22
23 public static void validate(String name, List<Double> params) throws CompilerException {
24 switch (name) {
25 case "Bernoulli":
26 Bernoulli.validate(params);
27 break;
28 case "Binomial":
29 Binomial.validate(params);
30 break;
31 case "Exponential":
32 Exponential.validate(params);
33 break;
34 case "Gamma":
35 Gamma.validate(params);
36 break;
37 case "LogNormal":
38 LogNormal.validate(params);
39 break;
40 case "Pareto":
41 Pareto.validate(params);
42 break;
43 case "TruncatedNormal":
44 TruncatedNormal.validate(params);
45 break;
46 case "Uniform":
47 Uniform.validate(params);
48 break;
49 case "EasyAndCertain":
50 case "EasyAndUncertain":
51 case "HardAndCertain":
52 case "HardAndUncertain":
53 case "VeryHardAndCertain":
54 case "VeryHardAndUncertain":
55 case "Infinity":
56 case "Zero":
57 case "Enabled":
58 case "Disabled":
59 Combination.validate(params);
60 break;
61 default:
62 throw new CompilerException(String.format("Distribution '%s' is not supported", name));
63 }
64 }
65
66 public static Distribution getDistribution(String name, List<Double> params) {
67 switch (name) {
68 case "Bernoulli":
69 return new Bernoulli(params);
70 case "Binomial":
71 return new Binomial(params);
72 case "Exponential":
73 return new Exponential(params);
74 case "Gamma":
75 return new Gamma(params);
76 case "LogNormal":
77 return new LogNormal(params);
78 case "Pareto":
79 return new Pareto(params);
80 case "TruncatedNormal":
81 return new TruncatedNormal(params);
82 case "Uniform":
83 return new Uniform(params);
84 case "EasyAndCertain":
85 return new EasyAndCertain();
86 case "EasyAndUncertain":
87 return new EasyAndUncertain();
88 case "HardAndCertain":
89 return new HardAndCertain();
90 case "HardAndUncertain":
91 return new HardAndUncertain();
92 case "VeryHardAndCertain":
93 return new VeryHardAndCertain();
94 case "VeryHardAndUncertain":
95 return new VeryHardAndUncertain();
96 case "Infinity":
97 return new Infinity();
98 case "Zero":
99 return new Zero();
100 case "Enabled":
101 return new Enabled();
102 case "Disabled":
103 return new Disabled();
104 default:
105 throw new RuntimeException(String.format("Distribution '%s' is not supported", name));
106 }
107 }
108
/** A probability distribution that can report its expected value. */
public interface Distribution {
  /**
   * Returns the mean of this distribution.
   *
   * <p>Implementations in this file use {@code Double.MAX_VALUE} to represent an infinite mean.
   *
   * @return the expected value
   */
  double getMean();
}
112
113 public static class Bernoulli implements Distribution {
114 public final double probability;
115
116 public Bernoulli(double probability) {
117 this.probability = probability;
118 }
119
120 public Bernoulli(List<Double> params) {
121 this(params.get(0));
122 }
123
124 public static void validate(List<Double> params) throws CompilerException {
125 if (params == null || params.size() != 1) {
126 throw new CompilerException(
127 "Expected exactly one parameter (probability), for Bernoulli distribution");
128 } else if (params.get(0) > 1) {
129 throw new CompilerException(
130 String.format(
131 "%s is not in valid range '0 <= probability <= 1', for Bernoulli distribution",
132 params.get(0)));
133 }
134 }
135
136 @Override
137 public double getMean() {
138 return probability;
139 }
140
141 @Override
142 public String toString() {
143 return String.format("Bernoulli(%f)", probability);
144 }
145 }
146
147 public static class Binomial implements Distribution {
148 public final int trials;
149 public final double probability;
150
151 public Binomial(int trials, double probability) {
152 this.trials = trials;
153 this.probability = probability;
154 }
155
156 public Binomial(List<Double> params) {
157 this((int) Math.round(params.get(0)), params.get(1));
158 }
159
160 public static void validate(List<Double> params) throws CompilerException {
161 if (params == null || params.size() != 2) {
162 throw new CompilerException(
163 "Expected exactly two parameters (trials, probability), for Binomial distribution");
164 } else if (params.get(1) > 1) {
165 throw new CompilerException(
166 String.format(
167 "%s is not in valid range '0 <= probability <= 1', for Binomial distribution",
168 params.get(1)));
169 }
170 }
171
172 @Override
173 public double getMean() {
174 return trials * probability;
175 }
176
177 @Override
178 public String toString() {
179 return String.format("Binomial(%d, %f)", trials, probability);
180 }
181 }
182
183 public static class Exponential implements Distribution {
184 public final double lambda;
185
186 public Exponential(double lambda) {
187 this.lambda = lambda;
188 }
189
190 public Exponential(List<Double> params) {
191 this(params.get(0));
192 }
193
194 public static void validate(List<Double> params) throws CompilerException {
195 if (params == null || params.size() != 1) {
196 throw new CompilerException(
197 "Expected exactly one parameter (lambda), for Exponential distribution");
198 } else if (params.get(0) <= 0) {
199 throw new CompilerException(
200 String.format(
201 "%s is not in valid range 'lambda > 0', for Exponential distribution",
202 params.get(0)));
203 }
204 }
205
206 @Override
207 public double getMean() {
208 return 1 / lambda;
209 }
210
211 @Override
212 public String toString() {
213 return String.format("Exponential(%f)", lambda);
214 }
215 }
216
217 public static class Gamma implements Distribution {
218 public final double shape;
219 public final double scale;
220
221 public Gamma(double shape, double scale) {
222 this.shape = shape;
223 this.scale = scale;
224 }
225
226 public Gamma(List<Double> params) {
227 this(params.get(0), params.get(1));
228 }
229
230 public static void validate(List<Double> params) throws CompilerException {
231 if (params == null || params.size() != 2) {
232 throw new CompilerException(
233 "Expected exactly two parameters (shape, scale), for Gamma distribution");
234 } else if (params.get(0) <= 0) {
235 throw new CompilerException(
236 String.format(
237 "%s is not in valid range 'shape > 0', for Gamma distribution", params.get(0)));
238 } else if (params.get(1) <= 0) {
239 throw new CompilerException(
240 String.format(
241 "%s is not in valid range 'scale > 0', for Gamma distribution", params.get(1)));
242 }
243 }
244
245 @Override
246 public double getMean() {
247 return shape * scale;
248 }
249
250 @Override
251 public String toString() {
252 return String.format("Gamma(%f, %f)", shape, scale);
253 }
254 }
255
256 public static class LogNormal implements Distribution {
257 public final double mean;
258 public final double standardDeviation;
259
260 public LogNormal(double mean, double standardDeviation) {
261 this.mean = mean;
262 this.standardDeviation = standardDeviation;
263 }
264
265 public LogNormal(List<Double> params) {
266 this(params.get(0), params.get(1));
267 }
268
269 public static void validate(List<Double> params) throws CompilerException {
270 if (params == null || params.size() != 2) {
271 throw new CompilerException(
272 "Expected exactly two parameters (mean, standardDeviation), for LogNormal"
273 + " distribution");
274 } else if (params.get(1) <= 0) {
275 throw new CompilerException(
276 String.format(
277 "%s is not in valid range 'standardDeviation > 0', for LogNormal distribution",
278 params.get(1)));
279 }
280 }
281
282 @Override
283 public double getMean() {
284 return Math.exp(mean + Math.pow(standardDeviation, 2) / 2);
285 }
286
287 @Override
288 public String toString() {
289 return String.format("LogNormal(%f, %f)", mean, standardDeviation);
290 }
291 }
292
293 public static class Pareto implements Distribution {
294 public final double min;
295 public final double shape;
296
297 public Pareto(double min, double shape) {
298 this.min = min;
299 this.shape = shape;
300 }
301
302 public Pareto(List<Double> params) {
303 this(params.get(0), params.get(1));
304 }
305
306 public static void validate(List<Double> params) throws CompilerException {
307 if (params == null || params.size() != 2) {
308 throw new CompilerException(
309 "Expected exactly two parameters (min, shape), for Pareto distribution");
310 } else if (params.get(0) <= 0) {
311 throw new CompilerException(
312 String.format(
313 "%s is not in valid range 'min > 0', for Pareto distribution", params.get(0)));
314 } else if (params.get(1) <= 0) {
315 throw new CompilerException(
316 String.format(
317 "%s is not in valid range 'shape > 0', for Pareto distribution", params.get(1)));
318 }
319 }
320
321 @Override
322 public double getMean() {
323 if (min <= 1) {
324 return Double.MAX_VALUE;
325 } else {
326 return (min * shape) / (min - 1);
327 }
328 }
329
330 @Override
331 public String toString() {
332 return String.format("Pareto(%f, %f)", min, shape);
333 }
334 }
335
336 public static class TruncatedNormal implements Distribution {
337 public final double mean;
338 public final double standardDeviation;
339
340 public TruncatedNormal(double mean, double standardDeviation) {
341 this.mean = mean;
342 this.standardDeviation = standardDeviation;
343 }
344
345 public TruncatedNormal(List<Double> params) {
346 this(params.get(0), params.get(1));
347 }
348
349 public static void validate(List<Double> params) throws CompilerException {
350 if (params == null || params.size() != 2) {
351 throw new CompilerException(
352 "Expected exactly two parameters (mean, standardDeviation), for TruncatedNormal"
353 + " distribution");
354 } else if (params.get(1) <= 0) {
355 throw new CompilerException(
356 String.format(
357 "%s is not in valid range 'standardDeviation > 0', for TruncatedNormal"
358 + " distribution",
359 params.get(1)));
360 }
361 }
362
363 @Override
364 public double getMean() {
365 return mean;
366 }
367
368 @Override
369 public String toString() {
370 return String.format("TruncatedNormal(%f, %f)", mean, standardDeviation);
371 }
372 }
373
374 public static class Uniform implements Distribution {
375 public final double min;
376 public final double max;
377
378 public Uniform(double min, double max) {
379 this.min = min;
380 this.max = max;
381 }
382
383 public Uniform(List<Double> params) {
384 this(params.get(0), params.get(1));
385 }
386
387 public static void validate(List<Double> params) throws CompilerException {
388 if (params == null || params.size() != 2) {
389 throw new CompilerException(
390 "Expected exactly two parameters (min, max), for Uniform distribution");
391 } else if (params.get(0) > params.get(1)) {
392 throw new CompilerException(
393 String.format(
394 "(%s, %s) does not meet requirement 'min <= max', for Uniform distribution",
395 params.get(0), params.get(1)));
396 }
397 }
398
399 @Override
400 public double getMean() {
401 return (min + max) / 2;
402 }
403
404 @Override
405 public String toString() {
406 return String.format("Uniform(%f, %f)", min, max);
407 }
408 }
409
410
411
412 private abstract static class Combination implements Distribution {
413 public static void validate(List<Double> params) throws CompilerException {
414 if (params != null && params.size() != 0) {
415 throw new CompilerException(
416 String.format("Expected exactly zero parameters, for combination distributions"));
417 }
418 }
419 }
420
421 public static class EasyAndCertain extends Combination {
422 public static final Exponential exponential = new Exponential(1);
423
424 @Override
425 public double getMean() {
426 return exponential.getMean();
427 }
428
429 @Override
430 public String toString() {
431 return "EasyAndCertain";
432 }
433 }
434
435 public static class EasyAndUncertain extends Combination {
436 public static final Bernoulli bernoulli = new Bernoulli(0.5);
437
438 @Override
439 public double getMean() {
440 return bernoulli.getMean();
441 }
442
443 @Override
444 public String toString() {
445 return "EasyAndUncertain";
446 }
447 }
448
449 public static class HardAndCertain extends Combination {
450 public static final Exponential exponential = new Exponential(0.1);
451
452 @Override
453 public double getMean() {
454 return exponential.getMean();
455 }
456
457 @Override
458 public String toString() {
459 return "HardAndCertain";
460 }
461 }
462
463 public static class HardAndUncertain extends Combination {
464 public static final Bernoulli bernoulli = new Bernoulli(0.5);
465 public static final Exponential exponential = new Exponential(0.1);
466
467 @Override
468 public double getMean() {
469 return bernoulli.getMean() * exponential.getMean();
470 }
471
472 @Override
473 public String toString() {
474 return "HardAndUncertain";
475 }
476 }
477
478 public static class VeryHardAndCertain extends Combination {
479 public static final Exponential exponential = new Exponential(0.01);
480
481 @Override
482 public double getMean() {
483 return exponential.getMean();
484 }
485
486 @Override
487 public String toString() {
488 return "VeryHardAndCertain";
489 }
490 }
491
492 public static class VeryHardAndUncertain extends Combination {
493 public static final Bernoulli bernoulli = new Bernoulli(0.5);
494 public static final Exponential exponential = new Exponential(0.01);
495
496 @Override
497 public double getMean() {
498 return bernoulli.getMean() * exponential.getMean();
499 }
500
501 @Override
502 public String toString() {
503 return "VeryHardAndUncertain";
504 }
505 }
506
507 public static class Infinity extends Combination {
508 public Infinity() {}
509
510 @Override
511 public double getMean() {
512 return Double.MAX_VALUE;
513 }
514
515 @Override
516 public String toString() {
517 return "Infinity";
518 }
519 }
520
521 public static class Zero extends Combination {
522 @Override
523 public double getMean() {
524 return 0;
525 }
526
527 @Override
528 public String toString() {
529 return "Zero";
530 }
531 }
532
533 public static class Enabled extends Combination {
534 @Override
535 public double getMean() {
536 return 1;
537 }
538
539 @Override
540 public String toString() {
541 return "Enabled";
542 }
543 }
544
545 public static class Disabled extends Combination {
546 @Override
547 public double getMean() {
548 return 0;
549 }
550
551 @Override
552 public String toString() {
553 return "Disabled";
554 }
555 }
556 }