Skip to content

Document main concepts #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions scala-rl-core/src/main/scala/com/scalarl/Agent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
*
* The nodes are:
*
* \- State nodes, with edges leading out to each possible action. \- Action nodes, with edges
* leading out to (reward, state) pairs.
* - State nodes, with edges leading out to each possible action.
* - Action nodes, with edges leading out to (reward, state) pairs.
*
* Policies are maps of State => Map[A, Weight]. I don't know that I have a policy that is NOT
* that.
Expand All @@ -15,8 +15,8 @@
*
* So to get the value of an ACTION node you need either:
*
* \- To track it directly, with an ActionValueFn, or \- to estimate it with some model of the
* dynamics of the system.
* - To track it directly, with an ActionValueFn, or
* - to estimate it with some model of the dynamics of the system.
*
* TODO - Key questions: \- Can I rethink the interface here? Can StateValueFn instances ONLY be
* calculated for... rings where the weights add up to 1? "Affine combination" is the key idea
Expand Down
23 changes: 22 additions & 1 deletion scala-rl-core/src/main/scala/com/scalarl/SARS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,38 @@ package com.scalarl

import cats.Functor

/** Chunk that you get back for playing an episode.
/** Represents a single step in a reinforcement learning episode.
*
* SARS stands for State-Action-Reward-State, capturing the complete transition:
* - The initial state the agent was in
* - The action the agent took
* - The reward received for taking that action
* - The next state the environment transitioned to
*/
final case class SARS[Obs, A, R, S[_]](
state: State[Obs, A, R, S],
action: A,
reward: R,
nextState: State[Obs, A, R, S]
) {

/** Maps the observation type of this SARS to a new type.
*
* @param f
* The function to transform the observation from type Obs to type P
* @param S
* Evidence that S has a Functor instance
*/
def mapObservation[P](f: Obs => P)(implicit S: Functor[S]): SARS[P, A, R, S] =
SARS(state.mapObservation(f), action, reward, nextState.mapObservation(f))

/** Maps the reward type of this SARS to a new type.
*
* @param f
* The function to transform the reward from type R to type T
* @param S
* Evidence that S has a Functor instance
*/
def mapReward[T](f: R => T)(implicit S: Functor[S]): SARS[Obs, A, T, S] =
SARS(state.mapReward(f), action, f(reward), nextState.mapReward(f))

Expand Down
12 changes: 9 additions & 3 deletions scala-rl-core/src/main/scala/com/scalarl/State.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,17 @@ trait State[Obs, A, @specialized(Int, Long, Float, Double) R, M[_]] { self =>
def actions: Set[A] = dynamics.keySet
def act(action: A): M[(R, This)] = dynamics.getOrElse(action, invalidMove)

/** Returns a list of possible actions to take from this state. To specify the terminal state,
* return an empty set.
*/
def isTerminal: Boolean = actions.isEmpty

/** Maps the observation type of this state to a new type.
*
* @param f
* The function to transform the observation from type Obs to type P
* @param M
* Evidence that M has a Functor instance
* @return
* A new State with observations of type P but the same actions and rewards
*/
def mapObservation[P](
f: Obs => P
)(implicit M: Functor[M]): State[P, A, R, M] =
Expand Down
7 changes: 7 additions & 0 deletions scala-rl-core/src/main/scala/com/scalarl/Time.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package com.scalarl

/** A value class wrapper around Long that allows us to talk about time ticking and evolution in a
* type-safe way.
*
* This class provides methods for incrementing time, comparing time values, and basic arithmetic
* operations, while maintaining type safety through the AnyVal wrapper.
*/

case class Time(value: Long) extends AnyVal {
def tick: Time = Time(value + 1)
def -(r: Time) = value - r.value
Expand Down
40 changes: 39 additions & 1 deletion scala-rl-core/src/main/scala/com/scalarl/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@ import scala.language.higherKinds
object Util {
import cats.syntax.functor._

/** Here we provide various "missing" typeclass instances sewing together algebird typeclasses and
* implementing typeclasses for rainier types.
*/
object Instances {
// this lets us sort AveragedValue instances...
implicit val averageValueOrd: Ordering[AveragedValue] =
Ordering.by(_.value)

// shows how to extract the averaged value out from the accumulating data structure
implicit val avToDouble: ToDouble[AveragedValue] =
ToDouble.instance(_.value)

// Module instance, representing a module that can scale AveragedValue by some scalar double.
implicit val avModule: Module[Double, AveragedValue] =
Module.from((r, av) => AveragedValue(av.count, r * av.value))

// easy, just expose this implicitly.
implicit val realRing: Ring[Real] = RealRing

// trivial VectorSpace, showing that the cats.Id monad (and any Ring R) form a vectorspace.
implicit def idVectorSpace[R](implicit R: Ring[R]): VectorSpace[R, Id] =
VectorSpace.from[R, Id](R.times(_, _))

// Ring instance for rainer Reals.
object RealRing extends Ring[Real] {
override def one = Real.one
override def zero = Real.zero
Expand All @@ -37,11 +46,33 @@ object Util {
}
}

def confine[A](a: A, min: A, max: A)(implicit ord: Ordering[A]): A =
/** Clamps a value between a minimum and maximum value.
*
* This function ensures that the input value `a` is not less than `min` and not greater than
* `max`, returning the clamped value.
*
* @param a
* The value to clamp.
* @param min
* The minimum value.
*/
def clamp[A](a: A, min: A, max: A)(implicit ord: Ordering[A]): A =
ord.min(ord.max(a, min), max)

/** Creates a Map from a set of keys using a function to generate values.
*
* This function takes a set of keys and a function that maps each key to a value, returning a
* Map with the keys and their corresponding values.
*
* @param keys
*/
def makeMap[K, V](keys: Set[K])(f: K => V): Map[K, V] = makeMapUnsafe(keys)(f)

/** similar to makeMap, but doesn't guarantee that there are not duplicate keys. If keys contains
* duplicates, later keys override earlier keys.
*
* @param keys
*/
def makeMapUnsafe[K, V](keys: TraversableOnce[K])(f: K => V): Map[K, V] =
keys.foldLeft(Map.empty[K, V]) { case (m, k) =>
m.updated(k, f(k))
Expand All @@ -53,21 +84,28 @@ object Util {
def updateWith[K, V](m: Map[K, V], k: K)(f: Option[V] => V): Map[K, V] =
m.updated(k, f(m.get(k)))

/** Merges a key and a value into a map using a semigroup to combine values. */
def mergeV[K, V: Semigroup](m: Map[K, V], k: K, delta: V): Map[K, V] =
updateWith(m, k) {
case None => delta
case Some(v) => Semigroup.plus[V](v, delta)
}

/** Finds the keys with the maximum values in a map.
*/
def maxKeys[A, B: Ordering](m: Map[A, B]): Set[A] = allMaxBy(m.keySet)(m(_))

/** Returns the set of keys that map (via `f`) to the maximal B, out of all `as` transformed.
*/
def allMaxBy[A, B: Ordering](as: Set[A])(f: A => B): Set[A] =
if (as.isEmpty) Set.empty
else {
val maxB = f(as.maxBy(f))
as.filter(a => Ordering[B].equiv(maxB, f(a)))
}

/** Iterates a monadic function `f` `n` of times using the starting value `a`.
*/
def iterateM[M[_], A](
n: Int
)(a: A)(f: A => M[A])(implicit M: Monad[M]): M[A] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ package algebra
import cats.Id
import com.twitter.algebird.Ring

/** Another attempt at a better thing, here... but I don't know if this solves my problem of needing
* to compose up the stack,
/** This is not currently used! Another attempt at a better thing, here... but I don't know if this
* solves my problem of needing to compose up the stack.
*
* I had a note about this in [[Agent]].
*/
trait AffineCombination[M[_], R] {
implicit def ring: Ring[R]
Expand Down
30 changes: 28 additions & 2 deletions scala-rl-core/src/main/scala/com/scalarl/algebra/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,40 @@ package algebra

import com.twitter.algebird.{Group, Ring, VectorSpace}

/** This class represents a module. For the required properties see:
/** This class represents an abstract-algebraic "module". A module is a generalization of vector
* spaces that allows scalars to come from a ring instead of a field. It consists of:
*
* https://en.wikipedia.org/wiki/Module_(mathematics)
* - An abelian group (G, +) representing the elements that can be scaled
* - A ring (R, +, *) representing the scalars
* - A scaling operation R × G → G that satisfies:
* - r(g₁ + g₂) = rg₁ + rg₂ (distributivity over group addition)
* - (r₁ + r₂)g = r₁g + r₂g (distributivity over ring addition)
* - (r₁r₂)g = r₁(r₂g) (compatibility with ring multiplication)
* - 1g = g (identity scalar)
*
* For more details see: https://en.wikipedia.org/wiki/Module_(mathematics)
*/
object Module {
// the default module!
type DModule[T] = Module[Double, T]

/** This method is used to get the default module for a given type.
*
* @param M
* The module to get.
* @return
* The default module for the given type.
*/
@inline final def apply[R, G](implicit M: Module[R, G]): Module[R, G] = M

/** supplies an implicit module, given an implicitly-available Ring for some type R.
*/
implicit def ringModule[R: Ring]: Module[R, R] = from(Ring.times(_, _))

/** Given an implicit ring and group, accepts a scaleFn that shows how to perform scalar
* multiplication between elements of the ring and the group and returns a new module over R and
* G.
*/
def from[R, G](
scaleFn: (R, G) => G
)(implicit R: Ring[R], G: Group[G]): Module[R, G] =
Expand All @@ -24,6 +47,9 @@ object Module {
if (R.isNonZero(r)) scaleFn(r, g) else G.zero
}

/* Algebird's vector space is generic on the container type C, and implicitly pulls in a group on
C[F]. We are a little more general.
*/
def fromVectorSpace[F, C[_]](implicit
R: Ring[F],
V: VectorSpace[F, C]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ object Weight {

implicit val timesMonoid: Monoid[Weight] = Monoid.from(One)(_ * _)
implicit val ord: Ordering[Weight] = Ordering.by(_.w)
implicit val toDouble: ToDouble[Weight] = ToDouble.instance(_.w)
}
4 changes: 4 additions & 0 deletions scala-rl-core/src/main/scala/com/scalarl/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@ package com
/** Functional reinforcement learning in Scala.
*/
package object scalarl {

/** Type alias for [[com.scalarl.rainier.Categorical]], which represents a finite discrete
* probability distribution.
*/
type Cat[+T] = rainier.Categorical[T]
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.scalarl.algebra.ToDouble
import scala.annotation.tailrec
import scala.collection.immutable.Queue

/** A finite discrete distribution.
/** Identical to rainier's `Categorical`, except written with `Double` instead of `Real`.
*
* @param pmfSeq
* A map with keys corresponding to the possible outcomes and values corresponding to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object CarRental {

case class Inventory(n: Int, maxN: Int) {
def -(m: Move): Inventory = this + -m
def +(m: Move): Inventory = Inventory(Util.confine(n + m.n, 0, maxN), maxN)
def +(m: Move): Inventory = Inventory(Util.clamp(n + m.n, 0, maxN), maxN)
def update(rentals: Move, returns: Move): Inventory =
Inventory(
math.min(n - rentals.n + returns.n, maxN),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object Grid {
/** Returns a row that's guaranteed to sit within the range specified by numColumns.
*/
def confine(numRows: Int): Row =
Row(Util.confine(value, 0, numRows - 1))
Row(Util.clamp(value, 0, numRows - 1))

def isWithin(numRows: Int): Boolean = value >= 0 && value < numRows
def assertWithin(numRows: Int): Try[Row] =
Expand All @@ -45,7 +45,7 @@ object Grid {
/** Returns a column that's guaranteed to sit within the range specified by numColumns.
*/
def confine(numColumns: Int): Col =
Col(Util.confine(value, 0, numColumns - 1))
Col(Util.clamp(value, 0, numColumns - 1))

def isWithin(numColumns: Int): Boolean = value >= 0 && value < numColumns
def assertWithin(numColumns: Int): Try[Col] =
Expand Down