Update to bring Flowchain in for method chaining
OrigamiDream committed May 26, 2023
1 parent 55a6ed2 commit 08fd961
Showing 5 changed files with 22 additions and 42 deletions.
3 changes: 3 additions & 0 deletions gato/__init__.py
@@ -1,2 +1,5 @@
 from gato.config import GatoConfig
 from gato.models import Gato
+from flowchain import enable_tensor_chaining
+
+enable_tensor_chaining()
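This import is the heart of the commit: once enable_tensor_chaining() has run, TensorFlow ops become callable as methods on tensors, so nested tf.* calls can be rewritten as left-to-right chains. A minimal before/after sketch, assuming flowchain mirrors the op names this diff uses (abs, round, cast):

import tensorflow as tf
from flowchain import enable_tensor_chaining

enable_tensor_chaining()  # patches tensors so tf ops are available as methods

x = tf.constant([[-1.4, 2.6], [0.5, -3.1]])

# Before: nested calls read inside-out.
nested = tf.cast(tf.round(tf.abs(x)), tf.int32)

# After: the same pipeline reads left to right.
chained = x.abs().round().cast(tf.int32)

assert bool(tf.reduce_all(nested == chained))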
8 changes: 4 additions & 4 deletions gato/models/__init__.py
@@ -36,15 +36,15 @@ def call(self, inputs, training=None, mask=None):

         ones = tf.ones((input_ids.shape[0], 1, self.config.layer_width), dtype=tf.float32)
         image_embed = self.image_embedding((input_ids, (row_pos, col_pos)), training=training)
-        image_embed *= tf.matmul(encoding[..., 0], ones, transpose_a=True)  # image patch masking
+        image_embed *= encoding[..., 0].transpose().matmul(ones)  # image patch masking

         # continuous value takes from first value of input_ids
         continuous_embed = self.continuous_encoding(input_ids[..., 0])
         continuous_embed = self.discrete_embedding(continuous_embed)
-        continuous_embed *= tf.matmul(encoding[..., 1], ones, transpose_a=True)  # continuous value masking
+        continuous_embed *= encoding[..., 1].transpose().matmul(ones)  # continuous value masking

         discrete_embed = self.discrete_embedding(input_ids[..., 0])
-        discrete_embed *= tf.matmul(encoding[..., 2], ones, transpose_a=True)  # discrete value masking
+        discrete_embed *= encoding[..., 2].transpose().matmul(ones)  # discrete value masking

         # Appendix C.3. Position Encodings > Local Observation Position Encodings
         # add local observation position encodings
@@ -101,7 +101,7 @@ def call(self, inputs, training=None, mask=None):
         patch_size = self.config.img_patch_size
         depth = self.config.input_dim // (patch_size * patch_size)

-        x = tf.reshape(input_ids, (-1, input_ids.shape[1], patch_size, patch_size, depth))
+        x = input_ids.reshape((-1, input_ids.shape[1], patch_size, patch_size, depth))
         x = self.residual_embedding(x)
         x = self.pos_encoding((x, (row_pos, col_pos)))
         return x
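The masking rewrites above all lean on one identity: tf.matmul(a, b, transpose_a=True) equals a matmul on an explicitly transposed a, which is what the chained .transpose().matmul(b) spells out. A toy check with made-up shapes (the real encoding and ones tensors differ):

import tensorflow as tf

a = tf.random.uniform((3, 4))  # stand-in for encoding[..., k]
b = tf.ones((3, 5))            # stand-in for the broadcast `ones` tensor

fused = tf.matmul(a, b, transpose_a=True)  # transpose folded into the matmul
explicit = tf.matmul(tf.transpose(a), b)   # what .transpose().matmul(b) expands to
tf.debugging.assert_near(fused, explicit)  # both are (4, 5)

One caution: for rank-2 tensors the two forms are identical, but tf.transpose's default perm reverses every axis while transpose_a swaps only the last two, so the rewrite relies on these masks being matrices.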
43 changes: 10 additions & 33 deletions gato/models/embedding.py
@@ -7,32 +7,13 @@

 def _randomized_positions(from_v, to_v):
     pos = tf.random.uniform(from_v.shape, minval=0, maxval=1, dtype=tf.float32)
-    pos = pos * tf.cast(to_v - from_v, dtype=tf.float32)
-    pos = tf.cast(pos, dtype=tf.int32)
-    return pos
+    pos = pos * (to_v - from_v).cast(tf.float32)
+    return pos.cast(tf.int32)


 def _rounded_mean_positions(from_v, to_v):
-    pos = tf.cast(from_v + to_v, tf.float32)
-    pos = pos / 2
-    pos = tf.round(pos)
-    return pos
-
-
-def _broadcast(row_pos, col_pos, row_ones, col_ones):
-    # broadcast (5,) to (20,) with column-axis
-    row_pos = tf.expand_dims(row_pos, 1)
-    row_pos = tf.matmul(row_pos, col_ones, transpose_b=True)
-    row_pos = tf.reshape(row_pos, (-1,))
-    row_pos = tf.stop_gradient(row_pos)
-
-    # broadcast (4,) to (20,) with row-axis
-    col_pos = tf.expand_dims(col_pos, 1)
-    col_pos = tf.matmul(row_ones, col_pos, transpose_b=True)
-    col_pos = tf.reshape(col_pos, (-1,))
-    col_pos = tf.stop_gradient(col_pos)
-
-    return row_pos, col_pos
+    pos = (from_v + to_v).cast(tf.float32) / 2.
+    return pos.round()
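For reference, the two surviving helpers now read as near one-liners: _randomized_positions samples an integer offset uniformly inside the interval, and _rounded_mean_positions takes its rounded midpoint. A plain-TF restatement with a toy interval (no flowchain needed):

import tensorflow as tf

from_v = tf.constant([0, 4])
to_v = tf.constant([3, 8])

# _randomized_positions: uniform in [0, 1) scaled by the interval width,
# floored by the int cast, so each entry lands in [0, to_v - from_v).
rand = tf.cast(
    tf.random.uniform(from_v.shape) * tf.cast(to_v - from_v, tf.float32),
    tf.int32)  # e.g. [1, 3]

# _rounded_mean_positions: rounded midpoint of the interval.
mean = tf.round(tf.cast(from_v + to_v, tf.float32) / 2.0)  # [2., 6.]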


 class PatchPositionEncoding(layers.Layer):
@@ -57,7 +38,7 @@ def __init__(self,
         self.col_embedding = layers.Embedding(self.discretize_depth, self.embedding_dim, name='col_embedding')

     def _discretize(self, pos):
-        return tf.round(pos * self.discretize_depth)
+        return (pos * self.discretize_depth).round()

     def _discretize_interval(self, interval):
         pos_from, pos_to = interval
@@ -83,12 +64,9 @@ def call(self, inputs, *args, **kwargs):
         row_pos = _rounded_mean_positions(row_pos_from, row_pos_to)
         col_pos = _rounded_mean_positions(col_pos_from, col_pos_to)

-        col_pos = tf.cast(col_pos, dtype=tf.int32)
-        row_pos = tf.cast(row_pos, dtype=tf.int32)
-
         # > Once row and column position encoding are retrieved from the embedding table,
         # > they are added onto the token embedding produced by the resnet embedding function.
-        return input_ids + self.row_embedding(row_pos) + self.col_embedding(col_pos)
+        return input_ids + self.row_embedding(row_pos.cast(tf.int32)) + self.col_embedding(col_pos.cast(tf.int32))

     def get_config(self):
         config = super(PatchPositionEncoding, self).get_config()
@@ -127,10 +105,10 @@ def call(self, inputs, *args, **kwargs):

         residual = self.conv_proj(self.gn_proj(x))

-        x = tf.nn.gelu(self.gn1(x))
+        x = self.gn1(x).gelu()
         x = self.conv1(x)

-        x = tf.nn.gelu(self.gn2(x))
+        x = self.gn2(x).gelu()
         x = self.conv2(x)

         return x + residual
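Same pattern in the residual unit: the activation moves from a wrapping call to a method, so each norm-activate-convolve step reads top to bottom. A quick equivalence check, assuming flowchain maps tf.nn.gelu onto .gelu() as this hunk implies:

import tensorflow as tf
from flowchain import enable_tensor_chaining

enable_tensor_chaining()

x = tf.random.normal((1, 8, 8, 4))
tf.debugging.assert_near(tf.nn.gelu(x), x.gelu())  # identical activations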
@@ -185,7 +163,7 @@ def call(self, inputs, *args, **kwargs):
             x = block(x)
         if self.conv_proj is not None:
             x = self.conv_proj(x)
-        x = tf.reshape(x, shape=(-1, inputs.shape[1], self.config.layer_width))
+        x = x.reshape((-1, inputs.shape[1], self.config.layer_width))
         return x

     def get_config(self):
@@ -222,8 +200,7 @@ def call(self, inputs, *args, **kwargs):
         embed = self.embedding(obs_pos)

         ones = tf.ones((embed.shape[0], 1, self.config.layer_width), dtype=tf.float32)
-        obs_mask = tf.cast(obs_mask, dtype=tf.float32)
-        obs_mask = tf.matmul(obs_mask, ones, transpose_a=True)
+        obs_mask = obs_mask.cast(tf.float32).transpose().matmul(ones)
         return embed * obs_mask

     def get_config(self):
7 changes: 3 additions & 4 deletions gato/models/tokenizers.py
@@ -7,10 +7,9 @@

 def mu_law_encode(x, mu=100, m=256):
     # Appendix B. Agent Data Tokenization Details
-    sign = tf.math.sign(x)
-    numerator = tf.math.log(tf.abs(x) * mu + 1.0)
+    numerator = tf.math.log(x.abs() * mu + 1.0)
     denominator = tf.math.log(m * mu + 1.0)
-    return (numerator / denominator) * sign
+    return (numerator / denominator) * x.sign()


 def tokenize_continuous_values(x, mu=100, m=256, bins=1024, shift=None):
@@ -21,7 +20,7 @@ def tokenize_continuous_values(x, mu=100, m=256, bins=1024, shift=None):
     # > We use 1024 bins and shift the resulting integers
     # > so they are not overlapping with the ones used for text tokens.
     c = (c + 1) * (bins / 2)
-    c = tf.cast(c, tf.int32)
+    c = c.cast(tf.int32)
     if shift is not None:
         c += shift
     return c
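The encoder implements mu-law companding from Appendix B, sign(x) * log(|x| * mu + 1) / log(M * mu + 1), which squashes inputs into [-1, 1] before tokenize_continuous_values maps them onto integer bins. A plain-TF restatement with a small numeric check (inputs are made up for illustration):

import tensorflow as tf

def mu_law_encode_ref(x, mu=100, m=256):
    # sign(x) * log(|x| * mu + 1) / log(m * mu + 1), output in [-1, 1]
    numerator = tf.math.log(tf.abs(x) * mu + 1.0)
    denominator = tf.math.log(m * mu + 1.0)
    return (numerator / denominator) * tf.math.sign(x)

x = tf.constant([-1.0, 0.0, 1.0])
v = mu_law_encode_ref(x)                       # ~[-0.455, 0.0, 0.455]
tokens = tf.cast((v + 1.0) * 512.0, tf.int32)  # 1024 bins -> ~[279, 512, 744]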
3 changes: 2 additions & 1 deletion setup.py
@@ -2,7 +2,7 @@

 setup(
     name='gato-tf',
-    version='0.0.2',
+    version='0.0.3',
     description='Unofficial Gato: A Generalist Agent',
     url='https://github.com/OrigamiDream/gato.git',
     author='OrigamiDream',
@@ -11,6 +11,7 @@
     packages=find_packages(),
     install_requires=[
         'tensorflow>=2.11',
+        'flowchain>=0.0.4'
     ],
     keywords=[
         'deep learning',
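With the dependency declared, importing the package is enough to switch chaining on for downstream users, since gato/__init__.py now calls enable_tensor_chaining() at import time. A sketch of what a consumer sees, assuming the published 0.0.3 wheel matches this tree:

import tensorflow as tf
import gato  # side effect: tensor chaining is enabled process-wide

x = tf.range(6.0)
print(x.reshape((2, 3)).cast(tf.int32))  # chained methods work on plain tensors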
