Skip to content

Commit

Permalink
Merge pull request #183 from sumadithya/random-choices-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Xrayez authored Feb 20, 2022
2 parents 1b3dd96 + 953c1a0 commit 677b0d4
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 15 deletions.
34 changes: 20 additions & 14 deletions core/math/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ Variant Random::pop(const Variant &p_from) {
return Variant();
}

Array Random::choices(const Variant &p_from, int p_count, const PoolIntArray &p_weights, bool p_is_cumulative) {
Array Random::choices(const Variant &p_sequence, int p_count, const PoolIntArray &p_weights, bool p_is_cumulative) {
int sum = 0;
LocalVector<int, int> cumulative_weights;
LocalVector<int, int> weights;
LocalVector<int, int> indices;
Array weighted_choices;

if ((p_from.get_type() == Variant::DICTIONARY) && p_weights.empty()) {
Dictionary dict = p_from;
if ((p_sequence.get_type() == Variant::DICTIONARY) && p_weights.empty()) {
Dictionary dict = p_sequence;
Array w = dict.values();
for (int i = 0; i < w.size(); i++) {
weights.push_back(w[i]);
Expand All @@ -158,20 +158,25 @@ Array Random::choices(const Variant &p_from, int p_count, const PoolIntArray &p_
}
}

ERR_FAIL_COND_V_MSG(weights.empty() && p_is_cumulative, Array(), "Cumulative weights cannot be empty.");
if (!weights.empty()) {
if (p_is_cumulative) {
int prev_cumulative = weights[0];
for(int i = 1; i < weights.size(); i++) {
ERR_FAIL_COND_V_MSG(weights[i] < 0, Array(), "Weights must be non-negative integers.");
ERR_FAIL_COND_V_MSG(weights[i] < prev_cumulative, Array(), "Cumulative weights must be non-decreasing.");
prev_cumulative = weights[i];
}
sum = weights[weights.size() - 1];
cumulative_weights = weights;
} else {
for (int i = 0; i < weights.size(); i++) {
if (weights[i] < 0) {
ERR_FAIL_V_MSG(Array(), "Weights must be positive integers.");
} else {
sum += weights[i];
cumulative_weights.push_back(sum);
}
ERR_FAIL_COND_V_MSG(weights[i] < 0, Array(), "Weights must be non-negative integers.");
sum += weights[i];
cumulative_weights.push_back(sum);
}
}
ERR_FAIL_COND_V_MSG(sum == 0, Array(), "Sum of weights cannot be zero");

for (int i = 0; i < p_count; i++) {
int left = 0;
Expand All @@ -190,9 +195,9 @@ Array Random::choices(const Variant &p_from, int p_count, const PoolIntArray &p_
}
}

switch (p_from.get_type()) {
switch (p_sequence.get_type()) {
case Variant::STRING: {
String str = p_from;
String str = p_sequence;
ERR_FAIL_COND_V_MSG(str.empty(), Variant(), "String is empty.");
if (weights.empty()) {
for (int i = 0; i < p_count; i++) {
Expand All @@ -214,7 +219,7 @@ Array Random::choices(const Variant &p_from, int p_count, const PoolIntArray &p_
case Variant::POOL_VECTOR3_ARRAY:
case Variant::POOL_COLOR_ARRAY:
case Variant::ARRAY: {
Array arr = p_from;
Array arr = p_sequence;
ERR_FAIL_COND_V_MSG(arr.empty(), Variant(), "Array is empty.");

if (weights.empty()) {
Expand All @@ -230,8 +235,9 @@ Array Random::choices(const Variant &p_from, int p_count, const PoolIntArray &p_
return weighted_choices;
} break;
case Variant::DICTIONARY: {
Dictionary dict = p_from;
Dictionary dict = p_sequence;
ERR_FAIL_COND_V_MSG(dict.empty(), Variant(), "Dictionary is empty.");
ERR_FAIL_COND_V_MSG(((dict.size() != weights.size()) && (!weights.empty())), Variant(), "Size of weights does not match.");
for (int i = 0; i < p_count; i++) {
weighted_choices.push_back(dict.get_key_at_index(indices[i]));
}
Expand Down Expand Up @@ -276,7 +282,7 @@ void Random::_bind_methods() {
ClassDB::bind_method(D_METHOD("range", "from", "to"), &Random::range);
ClassDB::bind_method(D_METHOD("pick", "from"), &Random::pick);
ClassDB::bind_method(D_METHOD("pop", "from"), &Random::pop);
ClassDB::bind_method(D_METHOD("choices", "from", "count", "weights", "cumulative"), &Random::choices, DEFVAL(1), DEFVAL(Variant()), DEFVAL(false));
ClassDB::bind_method(D_METHOD("choices", "sequence", "count", "weights", "cumulative"), &Random::choices, DEFVAL(1), DEFVAL(Variant()), DEFVAL(false));
ClassDB::bind_method(D_METHOD("shuffle", "array"), &Random::shuffle);
ClassDB::bind_method(D_METHOD("decision", "probability"), &Random::decision);

Expand Down
5 changes: 4 additions & 1 deletion doc/Random.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
</method>
<method name="choices">
<return type="Array" />
<argument index="0" name="from" type="Variant" />
<argument index="0" name="sequence" type="Variant" />
<argument index="1" name="count" type="int" default="1" />
<argument index="2" name="weights" type="PoolIntArray" default="null" />
<argument index="3" name="cumulative" type="bool" default="false" />
<description>
Returns an [Array] of randomly picked elements from a [code]sequence[/code], with the number of elements equal to [code]count[/code]. The elements are picked according to integer [code]weights[/code] or an array of values from the [code]sequence[/code] if it's a [Dictionary] and if [code]weights[/codes] is empty.
All elements of [code]weights[/code] must be non-negative integers, and must contain at least one non-zero element if [code]weights[/code] is not empty. Additionally, the order of integers should be non-decreasing if [code]cumulative[/code] is [code]true[/code].
If [code]weights[/code] is not empty and if [code]sequence[/code] is not a [Dictionary], then the size of [code]weights[/code] must be equal to the size of [code]sequence[/code].
</description>
</method>
<method name="color_hsv">
Expand Down
56 changes: 56 additions & 0 deletions tests/project/goost/core/math/test_random.gd
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,62 @@ func test_pick():

Engine.print_error_messages = true

func test_choices():
var rng = Random.new_instance()

rng.seed = 58885
var elements = rng.choices(["Godot", Color.blue, "Goost", Color.red], 4, [1,3,6,9])
assert_eq(elements, [Color.red, Color.red, "Goost", "Godot"])

rng.seed = 222
elements = rng.choices("Goost", 7, [1,14,6,9,5])
assert_eq(elements, ['G', 'o', 't', 'G', 's', 's', 't'])

rng.seed = 335
elements = rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4)
assert_eq(elements, ['Godex', 'Godot', 'Godex', 'Godex'])

rng.seed = 335
elements = rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [])
assert_eq(elements, ['Godex', 'Godot', 'Godex', 'Godex'])

rng.seed = 335
elements = rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, 9, 16])
assert_eq(elements, ['Godex', 'Godex', 'Godex', 'Godot'])

rng.seed = 335
elements = rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [], true)
assert_eq(elements, ['Godot', 'Goost', 'Godot', 'Goost'])

rng.seed = 335
elements = rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, 9, 16], true)
assert_eq(elements, ['Godex', 'Godot', 'Godot', 'Goost'])

Engine.print_error_messages = false

assert_eq(rng.choices(""), Array([]))
assert_eq(rng.choices([]), Array([]))

# unequal sizes
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, 9, 16, 18], true), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, 9], true), Array([]))
assert_eq(rng.choices(["Godot", "Goost", "Godex"], 4, [4, 9, 16, 18], true), Array([]))
assert_eq(rng.choices(["Godot", "Goost", "Godex"], 4, [4, 9], true), Array([]))

# decreasing/ negative
assert_eq(rng.choices({"Godot": 3, "Goost": -8, "Godex": 10}, 4, [], false), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": -8, "Godex": 10}, 4, [], true), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 7}, 4, [], true), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, -9, 16, 18], false), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, -9, 16, 18], true), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [4, 9, 6, 18], true), Array([]))
assert_eq(rng.choices({"roman" : 22, 22 : 25, BoxShape.new() : BoxShape.new()}, 37, PoolIntArray([]), true), Array([]))

# All zero weights
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [0, 0, 0], true), Array([]))
assert_eq(rng.choices({"Godot": 3, "Goost": 8, "Godex": 10}, 4, [0, 0, 0], false), Array([]))

Engine.print_error_messages = true

func test_pop():
var rng = Random.new_instance()
Expand Down

0 comments on commit 677b0d4

Please sign in to comment.