// Distributions.java
/*
* Copyright 2019-2022 Foreseeti AB <https://foreseeti.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.mal_lang.lib;
import java.util.List;
public class Distributions {
/**
 * Checks that {@code params} is an acceptable parameter list for the distribution called
 * {@code name}, by delegating to the {@code validate} method of the matching nested class.
 *
 * <p>All parameterless named distributions (for example {@code "Zero"} or {@code "Enabled"})
 * are checked by {@code Combination.validate}.
 *
 * @param name name of the distribution, e.g. {@code "Bernoulli"}
 * @param params parameters supplied for the distribution
 * @throws CompilerException if {@code name} is not one of the supported distributions, or if
 *     the selected distribution rejects {@code params}
 */
public static void validate(String name, List<Double> params) throws CompilerException {
  switch (name) {
    case "Bernoulli":
      Bernoulli.validate(params);
      return;
    case "Binomial":
      Binomial.validate(params);
      return;
    case "Exponential":
      Exponential.validate(params);
      return;
    case "Gamma":
      Gamma.validate(params);
      return;
    case "LogNormal":
      LogNormal.validate(params);
      return;
    case "Pareto":
      Pareto.validate(params);
      return;
    case "TruncatedNormal":
      TruncatedNormal.validate(params);
      return;
    case "Uniform":
      Uniform.validate(params);
      return;
    // Every parameterless distribution shares one validator.
    case "EasyAndCertain":
    case "EasyAndUncertain":
    case "HardAndCertain":
    case "HardAndUncertain":
    case "VeryHardAndCertain":
    case "VeryHardAndUncertain":
    case "Infinity":
    case "Zero":
    case "Enabled":
    case "Disabled":
      Combination.validate(params);
      return;
    default:
      break;
  }
  // Only reached when no case above matched the name.
  throw new CompilerException(String.format("Distribution '%s' is not supported", name));
}
/**
 * Creates the distribution object registered under {@code name}.
 *
 * <p>No validation is performed here: the parameterised distributions read their arguments by
 * position from {@code params}, and the parameterless ones ignore {@code params} entirely.
 *
 * @param name name of the distribution, e.g. {@code "Exponential"}
 * @param params parameters supplied for the distribution
 * @return a new instance of the requested distribution
 * @throws RuntimeException if {@code name} is not one of the supported distributions
 */
public static Distribution getDistribution(String name, List<Double> params) {
  final Distribution created;
  switch (name) {
    case "Bernoulli":
      created = new Bernoulli(params);
      break;
    case "Binomial":
      created = new Binomial(params);
      break;
    case "Exponential":
      created = new Exponential(params);
      break;
    case "Gamma":
      created = new Gamma(params);
      break;
    case "LogNormal":
      created = new LogNormal(params);
      break;
    case "Pareto":
      created = new Pareto(params);
      break;
    case "TruncatedNormal":
      created = new TruncatedNormal(params);
      break;
    case "Uniform":
      created = new Uniform(params);
      break;
    case "EasyAndCertain":
      created = new EasyAndCertain();
      break;
    case "EasyAndUncertain":
      created = new EasyAndUncertain();
      break;
    case "HardAndCertain":
      created = new HardAndCertain();
      break;
    case "HardAndUncertain":
      created = new HardAndUncertain();
      break;
    case "VeryHardAndCertain":
      created = new VeryHardAndCertain();
      break;
    case "VeryHardAndUncertain":
      created = new VeryHardAndUncertain();
      break;
    case "Infinity":
      created = new Infinity();
      break;
    case "Zero":
      created = new Zero();
      break;
    case "Enabled":
      created = new Enabled();
      break;
    case "Disabled":
      created = new Disabled();
      break;
    default:
      throw new RuntimeException(String.format("Distribution '%s' is not supported", name));
  }
  return created;
}
/** Common contract of every distribution declared in this class. */
public interface Distribution {
  /**
   * Returns the mean of this distribution.
   *
   * <p>Some implementations in this file return {@link Double#MAX_VALUE} rather than an
   * infinite value.
   *
   * @return the mean value
   */
  double getMean();
}
/** Bernoulli distribution, parameterised by a single probability. */
public static class Bernoulli implements Distribution {
  /** The probability parameter; also the mean of the distribution. */
  public final double probability;

  /**
   * Creates a Bernoulli distribution.
   *
   * @param probability the probability parameter
   */
  public Bernoulli(double probability) {
    this.probability = probability;
  }

  /**
   * Creates a Bernoulli distribution from a parameter list.
   *
   * @param params list whose first element is the probability
   */
  public Bernoulli(List<Double> params) {
    this(params.get(0));
  }

  /**
   * Checks that {@code params} holds exactly one probability within {@code [0, 1]}.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly one
   *     element, or the probability lies outside {@code 0 <= probability <= 1}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 1) {
      throw new CompilerException(
          "Expected exactly one parameter (probability), for Bernoulli distribution");
    }
    double candidate = params.get(0);
    // Expressed as a negated in-range test so that negative values and NaN are both rejected;
    // checking only "> 1" would let them through.
    if (!(candidate >= 0 && candidate <= 1)) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range '0 <= probability <= 1', for Bernoulli distribution",
              params.get(0)));
    }
  }

  /** Returns the probability, which is the mean of a Bernoulli distribution. */
  @Override
  public double getMean() {
    return probability;
  }

  @Override
  public String toString() {
    return String.format("Bernoulli(%f)", probability);
  }
}
/** Binomial distribution, parameterised by a number of trials and a probability. */
public static class Binomial implements Distribution {
  /** Number of trials. */
  public final int trials;

  /** The probability parameter. */
  public final double probability;

  /**
   * Creates a Binomial distribution.
   *
   * @param trials number of trials
   * @param probability the probability parameter
   */
  public Binomial(int trials, double probability) {
    this.trials = trials;
    this.probability = probability;
  }

  /**
   * Creates a Binomial distribution from a parameter list.
   *
   * @param params list holding the number of trials (rounded to the nearest integer) followed
   *     by the probability
   */
  public Binomial(List<Double> params) {
    this((int) Math.round(params.get(0)), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values and that the probability is within
   * {@code [0, 1]}. The number of trials is not range-checked.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, or the probability lies outside {@code 0 <= probability <= 1}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (trials, probability), for Binomial distribution");
    }
    double candidate = params.get(1);
    // Expressed as a negated in-range test so that negative values and NaN are both rejected;
    // checking only "> 1" would let them through.
    if (!(candidate >= 0 && candidate <= 1)) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range '0 <= probability <= 1', for Binomial distribution",
              params.get(1)));
    }
  }

  /** Returns {@code trials * probability}. */
  @Override
  public double getMean() {
    return trials * probability;
  }

  @Override
  public String toString() {
    return String.format("Binomial(%d, %f)", trials, probability);
  }
}
/** Exponential distribution, parameterised by {@code lambda}. */
public static class Exponential implements Distribution {
  /** The {@code lambda} parameter. */
  public final double lambda;

  /**
   * Creates an Exponential distribution.
   *
   * @param lambda the {@code lambda} parameter
   */
  public Exponential(double lambda) {
    this.lambda = lambda;
  }

  /**
   * Creates an Exponential distribution from a parameter list.
   *
   * @param params list whose first element is {@code lambda}
   */
  public Exponential(List<Double> params) {
    this(params.get(0));
  }

  /**
   * Checks that {@code params} holds exactly one value and that it is not {@code <= 0}.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly one
   *     element, or {@code lambda <= 0}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 1) {
      throw new CompilerException(
          "Expected exactly one parameter (lambda), for Exponential distribution");
    }
    double rate = params.get(0);
    if (rate <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'lambda > 0', for Exponential distribution", rate));
    }
  }

  /** Returns {@code 1 / lambda}. */
  @Override
  public double getMean() {
    return 1.0 / lambda;
  }

  @Override
  public String toString() {
    return String.format("Exponential(%f)", lambda);
  }
}
/** Gamma distribution, parameterised by shape and scale. */
public static class Gamma implements Distribution {
  /** The shape parameter. */
  public final double shape;

  /** The scale parameter. */
  public final double scale;

  /**
   * Creates a Gamma distribution.
   *
   * @param shape the shape parameter
   * @param scale the scale parameter
   */
  public Gamma(double shape, double scale) {
    this.shape = shape;
    this.scale = scale;
  }

  /**
   * Creates a Gamma distribution from a parameter list.
   *
   * @param params list holding the shape followed by the scale
   */
  public Gamma(List<Double> params) {
    this(params.get(0), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values, neither of which is {@code <= 0}.
   * The shape is checked before the scale.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, {@code shape <= 0} or {@code scale <= 0}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (shape, scale), for Gamma distribution");
    }
    double shapeValue = params.get(0);
    if (shapeValue <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'shape > 0', for Gamma distribution", shapeValue));
    }
    double scaleValue = params.get(1);
    if (scaleValue <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'scale > 0', for Gamma distribution", scaleValue));
    }
  }

  /** Returns {@code shape * scale}. */
  @Override
  public double getMean() {
    return shape * scale;
  }

  @Override
  public String toString() {
    return String.format("Gamma(%f, %f)", shape, scale);
  }
}
/** LogNormal distribution, parameterised by {@code mean} and {@code standardDeviation}. */
public static class LogNormal implements Distribution {
  /** The {@code mean} parameter. */
  public final double mean;

  /** The {@code standardDeviation} parameter. */
  public final double standardDeviation;

  /**
   * Creates a LogNormal distribution.
   *
   * @param mean the {@code mean} parameter
   * @param standardDeviation the {@code standardDeviation} parameter
   */
  public LogNormal(double mean, double standardDeviation) {
    this.mean = mean;
    this.standardDeviation = standardDeviation;
  }

  /**
   * Creates a LogNormal distribution from a parameter list.
   *
   * @param params list holding the mean followed by the standard deviation
   */
  public LogNormal(List<Double> params) {
    this(params.get(0), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values and that the standard deviation is
   * not {@code <= 0}. The mean is not range-checked.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, or {@code standardDeviation <= 0}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (mean, standardDeviation), for LogNormal"
              + " distribution");
    }
    double sigma = params.get(1);
    if (sigma <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'standardDeviation > 0', for LogNormal distribution",
              sigma));
    }
  }

  /** Returns {@code exp(mean + standardDeviation^2 / 2)}. */
  @Override
  public double getMean() {
    final double sigmaSquared = Math.pow(standardDeviation, 2);
    final double exponent = mean + sigmaSquared / 2;
    return Math.exp(exponent);
  }

  @Override
  public String toString() {
    return String.format("LogNormal(%f, %f)", mean, standardDeviation);
  }
}
/** Pareto distribution, parameterised by its minimum value and its shape. */
public static class Pareto implements Distribution {
  /** The {@code min} parameter. */
  public final double min;

  /** The {@code shape} parameter. */
  public final double shape;

  /**
   * Creates a Pareto distribution.
   *
   * @param min the {@code min} parameter
   * @param shape the {@code shape} parameter
   */
  public Pareto(double min, double shape) {
    this.min = min;
    this.shape = shape;
  }

  /**
   * Creates a Pareto distribution from a parameter list.
   *
   * @param params list holding {@code min} followed by {@code shape}
   */
  public Pareto(List<Double> params) {
    this(params.get(0), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values, neither of which is {@code <= 0}.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, {@code min <= 0} or {@code shape <= 0}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (min, shape), for Pareto distribution");
    } else if (params.get(0) <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'min > 0', for Pareto distribution", params.get(0)));
    } else if (params.get(1) <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'shape > 0', for Pareto distribution", params.get(1)));
    }
  }

  /**
   * Returns the mean {@code shape * min / (shape - 1)}.
   *
   * <p>The mean of a Pareto distribution is finite only when {@code shape > 1}; for
   * {@code shape <= 1} this method returns {@link Double#MAX_VALUE} in place of infinity.
   */
  @Override
  public double getMean() {
    // Whether the mean is finite depends on the shape, not on the minimum value.
    if (shape <= 1) {
      return Double.MAX_VALUE;
    }
    return (min * shape) / (shape - 1);
  }

  @Override
  public String toString() {
    return String.format("Pareto(%f, %f)", min, shape);
  }
}
/** TruncatedNormal distribution, parameterised by {@code mean} and {@code standardDeviation}. */
public static class TruncatedNormal implements Distribution {
  /** The {@code mean} parameter. */
  public final double mean;

  /** The {@code standardDeviation} parameter. */
  public final double standardDeviation;

  /**
   * Creates a TruncatedNormal distribution.
   *
   * @param mean the {@code mean} parameter
   * @param standardDeviation the {@code standardDeviation} parameter
   */
  public TruncatedNormal(double mean, double standardDeviation) {
    this.mean = mean;
    this.standardDeviation = standardDeviation;
  }

  /**
   * Creates a TruncatedNormal distribution from a parameter list.
   *
   * @param params list holding the mean followed by the standard deviation
   */
  public TruncatedNormal(List<Double> params) {
    this(params.get(0), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values and that the standard deviation is
   * not {@code <= 0}. The mean is not range-checked.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, or {@code standardDeviation <= 0}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (mean, standardDeviation), for TruncatedNormal"
              + " distribution");
    }
    double sigma = params.get(1);
    if (sigma <= 0) {
      throw new CompilerException(
          String.format(
              "%s is not in valid range 'standardDeviation > 0', for TruncatedNormal"
                  + " distribution",
              sigma));
    }
  }

  /** Returns the {@code mean} parameter unchanged. */
  @Override
  public double getMean() {
    return this.mean;
  }

  @Override
  public String toString() {
    return String.format("TruncatedNormal(%f, %f)", mean, standardDeviation);
  }
}
/** Uniform distribution over the interval from {@code min} to {@code max}. */
public static class Uniform implements Distribution {
  /** The {@code min} parameter. */
  public final double min;

  /** The {@code max} parameter. */
  public final double max;

  /**
   * Creates a Uniform distribution.
   *
   * @param min the {@code min} parameter
   * @param max the {@code max} parameter
   */
  public Uniform(double min, double max) {
    this.min = min;
    this.max = max;
  }

  /**
   * Creates a Uniform distribution from a parameter list.
   *
   * @param params list holding {@code min} followed by {@code max}
   */
  public Uniform(List<Double> params) {
    this(params.get(0), params.get(1));
  }

  /**
   * Checks that {@code params} holds exactly two values and that the first is not greater
   * than the second.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} is {@code null}, does not contain exactly two
   *     elements, or {@code min > max}
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.size() != 2) {
      throw new CompilerException(
          "Expected exactly two parameters (min, max), for Uniform distribution");
    }
    double lower = params.get(0);
    double upper = params.get(1);
    if (lower > upper) {
      throw new CompilerException(
          String.format(
              "(%s, %s) does not meet requirement 'min <= max', for Uniform distribution",
              lower, upper));
    }
  }

  /** Returns the midpoint {@code (min + max) / 2}. */
  @Override
  public double getMean() {
    final double sum = min + max;
    return sum / 2;
  }

  @Override
  public String toString() {
    return String.format("Uniform(%f, %f)", min, max);
  }
}
/* Custom distributions: parameterless combinations of the distributions defined above. */
/** Base class of the named distributions that take no parameters. */
private abstract static class Combination implements Distribution {
  /**
   * Checks that no parameters were supplied; both {@code null} and an empty list are accepted.
   *
   * @param params parameters to check
   * @throws CompilerException if {@code params} contains at least one element
   */
  public static void validate(List<Double> params) throws CompilerException {
    if (params == null || params.isEmpty()) {
      return;
    }
    throw new CompilerException(
        "Expected exactly zero parameters, for combination distributions");
  }
}
/** Parameterless distribution backed by a fixed {@code Exponential(1)}. */
public static class EasyAndCertain extends Combination {
  public static final Exponential exponential = new Exponential(1.0);

  /** Returns the mean of the backing exponential distribution. */
  @Override
  public double getMean() {
    final double exponentialMean = exponential.getMean();
    return exponentialMean;
  }

  @Override
  public String toString() {
    return "EasyAndCertain";
  }
}
/** Parameterless distribution backed by a fixed {@code Bernoulli(0.5)}. */
public static class EasyAndUncertain extends Combination {
  public static final Bernoulli bernoulli = new Bernoulli(0.5);

  /** Returns the mean of the backing Bernoulli distribution. */
  @Override
  public double getMean() {
    final double bernoulliMean = bernoulli.getMean();
    return bernoulliMean;
  }

  @Override
  public String toString() {
    return "EasyAndUncertain";
  }
}
/** Parameterless distribution backed by a fixed {@code Exponential(0.1)}. */
public static class HardAndCertain extends Combination {
  public static final Exponential exponential = new Exponential(0.1);

  /** Returns the mean of the backing exponential distribution. */
  @Override
  public double getMean() {
    final double exponentialMean = exponential.getMean();
    return exponentialMean;
  }

  @Override
  public String toString() {
    return "HardAndCertain";
  }
}
/**
 * Parameterless distribution combining a fixed {@code Bernoulli(0.5)} with a fixed
 * {@code Exponential(0.1)}.
 */
public static class HardAndUncertain extends Combination {
  public static final Bernoulli bernoulli = new Bernoulli(0.5);
  public static final Exponential exponential = new Exponential(0.1);

  /** Returns the product of the Bernoulli mean and the exponential mean. */
  @Override
  public double getMean() {
    final double bernoulliMean = bernoulli.getMean();
    final double exponentialMean = exponential.getMean();
    return bernoulliMean * exponentialMean;
  }

  @Override
  public String toString() {
    return "HardAndUncertain";
  }
}
/** Parameterless distribution backed by a fixed {@code Exponential(0.01)}. */
public static class VeryHardAndCertain extends Combination {
  public static final Exponential exponential = new Exponential(0.01);

  /** Returns the mean of the backing exponential distribution. */
  @Override
  public double getMean() {
    final double exponentialMean = exponential.getMean();
    return exponentialMean;
  }

  @Override
  public String toString() {
    return "VeryHardAndCertain";
  }
}
/**
 * Parameterless distribution combining a fixed {@code Bernoulli(0.5)} with a fixed
 * {@code Exponential(0.01)}.
 */
public static class VeryHardAndUncertain extends Combination {
  public static final Bernoulli bernoulli = new Bernoulli(0.5);
  public static final Exponential exponential = new Exponential(0.01);

  /** Returns the product of the Bernoulli mean and the exponential mean. */
  @Override
  public double getMean() {
    final double bernoulliMean = bernoulli.getMean();
    final double exponentialMean = exponential.getMean();
    return bernoulliMean * exponentialMean;
  }

  @Override
  public String toString() {
    return "VeryHardAndUncertain";
  }
}
/** Parameterless distribution whose mean is reported as {@link Double#MAX_VALUE}. */
public static class Infinity extends Combination {
  private static final double MEAN = Double.MAX_VALUE;

  /** Creates the distribution; it carries no state. */
  public Infinity() {}

  /** Returns {@link Double#MAX_VALUE}. */
  @Override
  public double getMean() {
    return MEAN;
  }

  @Override
  public String toString() {
    return "Infinity";
  }
}
/** Parameterless distribution whose mean is reported as {@code 0}. */
public static class Zero extends Combination {
  private static final double MEAN = 0;

  /** Returns {@code 0}. */
  @Override
  public double getMean() {
    return MEAN;
  }

  @Override
  public String toString() {
    return "Zero";
  }
}
/** Parameterless distribution whose mean is reported as {@code 1}. */
public static class Enabled extends Combination {
  private static final double MEAN = 1;

  /** Returns {@code 1}. */
  @Override
  public double getMean() {
    return MEAN;
  }

  @Override
  public String toString() {
    return "Enabled";
  }
}
/** Parameterless distribution whose mean is reported as {@code 0}. */
public static class Disabled extends Combination {
  private static final double MEAN = 0;

  /** Returns {@code 0}. */
  @Override
  public double getMean() {
    return MEAN;
  }

  @Override
  public String toString() {
    return "Disabled";
  }
}
}