// Copyright 2019, The Jelly Bean World Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

/// Transition of an agent during a single simulation step.
public struct AgentTransition {
  /// State of the agent before the simulation step was performed.
  public let previousState: AgentState

  /// State of the agent after the simulation step was performed.
  public let currentState: AgentState

  @inlinable
  public init(previousState: AgentState, currentState: AgentState) {
    self.previousState = previousState
    self.currentState = currentState
  }
}

/// Reward function that computes a scalar reward value for agent transitions.
public protocol Reward {
  /// Returns a reward value for the provided transition.
  ///
  /// - Parameter transition: Agent transition for which to compute a reward.
  /// - Returns: Reward value for the provided transition.
  func callAsFunction(for transition: AgentTransition) -> Float
}

/// Reward function that always returns zero.
public struct ZeroReward: Reward {
  public init() {}

  @inlinable
  public func callAsFunction(for transition: AgentTransition) -> Float {
    0
  }
}
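
// Illustrative sketch (not part of the original file): a custom `Reward` conformance that
// applies a constant per-step penalty. `callAsFunction` allows reward values to be computed
// using call syntax. The usage lines assume `transition` is an `AgentTransition` obtained
// from the simulator.
//
//   struct StepPenalty: Reward {
//     let penalty: Float
//
//     func callAsFunction(for transition: AgentTransition) -> Float {
//       -penalty
//     }
//   }
//
//   let stepPenalty = StepPenalty(penalty: 0.01)
//   let value = stepPenalty(for: transition)  // Always equal to -0.01.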

/// Simple reward function that scores agent transitions and can be composed using `+`.
public enum SimpleReward: Reward, Equatable {
  /// Reward that is always equal to zero.
  case zero

  /// Constant reward provided for every transition (i.e., for every action taken).
  case action(value: Float)

  /// Reward proportional to the number of `item`s collected during the transition.
  case collect(item: Item, value: Float)

  /// Penalty proportional to the number of `item`s collected during the transition.
  case avoid(item: Item, value: Float)

  /// Reward provided whenever the agent moves farther away from the origin.
  case explore(value: Float)

  /// Sum of multiple reward functions.
  indirect case combined([SimpleReward])

  /// Adds two reward functions. The resulting reward is equal to the sum of the rewards
  /// computed by the two functions.
  @inlinable
  public static func + (lhs: SimpleReward, rhs: SimpleReward) -> SimpleReward {
    .combined([lhs, rhs])
  }

  /// Returns a reward value for the provided transition.
  ///
  /// - Parameter transition: Agent transition for which to compute a reward.
  /// - Returns: Reward value for the provided transition.
  @inlinable
  public func callAsFunction(for transition: AgentTransition) -> Float {
    switch self {
    case .zero:
      return 0
    case let .action(value):
      return value
    case let .collect(item, value):
      let currentItemCount = transition.currentState.items[item] ?? 0
      let previousItemCount = transition.previousState.items[item] ?? 0
      return Float(currentItemCount - previousItemCount) * value
    case let .avoid(item, value):
      let currentItemCount = transition.currentState.items[item] ?? 0
      let previousItemCount = transition.previousState.items[item] ?? 0
      return Float(previousItemCount - currentItemCount) * value
    case let .explore(value):
      let x = Float(transition.currentState.position.x)
      let y = Float(transition.currentState.position.y)
      let previousX = Float(transition.previousState.position.x)
      let previousY = Float(transition.previousState.position.y)
      let distance = x * x + y * y
      let previousDistance = previousX * previousX + previousY * previousY
      return distance > previousDistance ? value : 0.0
    case let .combined(rewards):
      return rewards.map { $0(for: transition) }.reduce(0, +)
    }
  }
}

extension SimpleReward: CustomStringConvertible {
  public var description: String {
    switch self {
    case .zero:
      return "Zero"
    case let .action(value):
      return "Action[\(String(format: "%.2f", value))]"
    case let .collect(item, value):
      return "Collect[\(item.description), \(String(format: "%.2f", value))]"
    case let .avoid(item, value):
      return "Avoid[\(item.description), \(String(format: "%.2f", value))]"
    case let .explore(value):
      return "Explore[\(String(format: "%.2f", value))]"
    case let .combined(rewards):
      return rewards.map { $0.description }.joined(separator: " ∧ ")
    }
  }
}
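
// Illustrative sketch (not part of the original file): composing simple reward functions with
// `+`, assuming `jellyBean` is an `Item` defined in the surrounding simulation configuration
// and `transition` is an `AgentTransition` obtained from the simulator.
//
//   let reward: SimpleReward = .collect(item: jellyBean, value: 1.0) + .explore(value: 0.1)
//   let value = reward(for: transition)  // Sum of the two component reward values.
//   print(reward)  // e.g. "Collect[JellyBean, 1.00] ∧ Explore[0.10]", depending on the item's description.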

/// Reward function schedule that specifies which reward function to use at each time step.
/// This is useful for representing never-ending learning settings that require adaptation.
public protocol RewardSchedule {
  /// Returns the reward function to use for the specified time step.
  func reward(forStep step: UInt64) -> Reward
}

/// Fixed reward function schedule that uses the same reward function for all time steps.
public struct FixedReward: RewardSchedule {
  /// Reward function used for all time steps.
  public let reward: Reward

  public init(_ reward: Reward) {
    self.reward = reward
  }

  public func reward(forStep step: UInt64) -> Reward {
    reward
  }
}
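
// Illustrative sketch (not part of the original file): a schedule that uses the same reward
// function for every time step.
//
//   let schedule = FixedReward(SimpleReward.explore(value: 0.1))
//   let reward = schedule.reward(forStep: 42)  // `.explore(value: 0.1)` at every step.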

/// Reward function schedule that cycles through a sequence of reward functions, using each one
/// for a specified number of time steps before moving on to the next.
public struct CyclicalSchedule: RewardSchedule {
  /// Reward functions paired with the number of time steps for which each one is used.
  public let rewards: [(Reward, UInt64)]

  /// Total number of time steps in one full cycle (i.e., the sum of all durations).
  public let cycleDuration: UInt64

  public init(_ rewards: [(Reward, UInt64)]) {
    precondition(!rewards.isEmpty, "'CyclicalSchedule' requires at least one reward function.")
    self.rewards = rewards
    self.cycleDuration = rewards.map { $0.1 }.reduce(0, +)
  }

  public func reward(forStep step: UInt64) -> Reward {
    let step = step % cycleDuration
    var cumulativeDuration: UInt64 = 0
    for reward in rewards {
      if step < cumulativeDuration + reward.1 { return reward.0 }
      cumulativeDuration += reward.1
    }
    return rewards.last!.0
  }
}
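
// Illustrative sketch (not part of the original file): a schedule that alternates between
// rewarding collection of one item and avoidance of another, assuming `banana` and `onion`
// are `Item`s defined in the surrounding simulation configuration.
//
//   let schedule = CyclicalSchedule([
//     (SimpleReward.collect(item: banana, value: 1.0), 1_000),
//     (SimpleReward.avoid(item: onion, value: 2.0), 500),
//   ])
//   schedule.reward(forStep: 500)    // `.collect(item: banana, value: 1.0)`
//   schedule.reward(forStep: 1_200)  // `.avoid(item: onion, value: 2.0)`
//   schedule.reward(forStep: 1_700)  // `.collect(item: banana, value: 1.0)` (the cycle repeats).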