Skip to content

Commit

Permalink
Remove square arena restriction.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Sep 4, 2024
1 parent b7c9e86 commit 7802806
Showing 1 changed file with 115 additions and 82 deletions.
197 changes: 115 additions & 82 deletions flybody/tasks/arenas/hills.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Arenas for the fruitfly walker."""

import numpy as np
from scipy import ndimage

Expand All @@ -23,16 +25,16 @@ def terrain_bowl(physics,
random_state=None):
"""Generate a bowl-shaped terrain.
Args:
physics: Current physics instance.
bump_scale: Spatial extent of bumps.
elevation_z: Returned terrain will be normalized between [0, elevation_z].
tanh_rel_radius: Radius of bowl, relative to half-length of arena.
tanh_sharpness: Sharpness of tanh.
Args:
physics: Current physics instance.
bump_scale: Spatial extent of bumps.
elevation_z: Returned terrain will be normalized between [0, elevation_z].
tanh_rel_radius: Radius of bowl, relative to half-length of arena.
tanh_sharpness: Sharpness of tanh.
Returns:
terrain (nrow, ncol).
"""
Returns:
terrain (nrow, ncol).
"""
size = physics.model.hfield_size[0, :2] # half-lengths! (e.g., radius)
nrow = physics.model.hfield_nrow[0]
ncol = physics.model.hfield_ncol[0]
Expand Down Expand Up @@ -60,16 +62,16 @@ def terrain_bowl(physics,
def add_sine_bumps(terrain, arena_size, wavelength=5., phase=0., height=1.):
"""Add sine-like bumps to an existing terrain.
Args:
terrain: Initial terrain, (nrow, ncol).
arena_size: Half-length (aka radius) of arena, shape (2,).
wavelength: Wavelength of bumps, in actual world length.
phase: Phase of sine.
height: Amplitude of bumps, in actual world length.
Args:
terrain: Initial terrain, (nrow, ncol).
arena_size: Half-length (aka radius) of arena, shape (2,).
wavelength: Wavelength of bumps, in actual world length units.
phase: Phase of sine.
height: Amplitude of bumps, in actual world length units.
Returns:
terrain: Initial terrain with added sine bumps, (nrow, ncol).
"""
Returns:
terrain: Initial terrain with added sine bumps, (nrow, ncol).
"""
_, ncol = terrain.shape
x_axis = np.linspace(-arena_size[0], arena_size[0], ncol)
bumps = height * 0.5 * (np.sin(2 * np.pi / wavelength * x_axis + phase) +
Expand All @@ -91,22 +93,22 @@ def add_sine_trench(terrain,
sigma=0.2):
"""Add sine-shaped trench to terrain.
Args:
terrain: Initial terrain, (nrow, ncol).
arena_size: Half-lengths (aka radius) of arena, cm, shape (2,).
wavelength: Sine wavelength, in cm.
amplitude: Sine amplitude, in cm.
phase: Sine phase, rad.
start_x: x of trench entrance, cm.
end_x: x of trench end, cm.
width: Width of trench before smoothing, cm.
height: Height of trench, cm.
sigma: Terrain smoothing stddev, in cm.
Returns:
terrain: Initial terrain with added sine trench, (nrow, ncol).
sine: Trench sine, not used in the task but can be used for analysis etc.
"""
Args:
terrain: Initial terrain, (nrow, ncol).
arena_size: Half-lengths (aka radius) of arena, cm, shape (2,).
wavelength: Sine wavelength, in cm.
amplitude: Sine amplitude, in cm.
phase: Sine phase, rad.
start_x: x of trench entrance, cm.
end_x: x of trench end, cm.
width: Width of trench before smoothing, cm.
height: Height of trench, cm.
sigma: Terrain smoothing stddev, in cm.
Returns:
terrain: Initial terrain with added sine trench, (nrow, ncol).
sine: Trench sine, not used in the task but can be used for analysis etc.
"""
nrow, ncol = terrain.shape
idx_from, _ = pos_to_terrain_idx(start_x, 0, arena_size, nrow, ncol)
idx_to, _ = pos_to_terrain_idx(end_x, 0, arena_size, nrow, ncol)
Expand All @@ -131,38 +133,48 @@ def add_sine_trench(terrain,
class Hills(composer.Arena):
"""A hilly arena.
Args:
name: Name of the arena.
dim: Half-length of the actual arena (this is the `radius`.)
aesthetic: Aesthetic of the arena.
hfield_elevation_z=1, hfield_base_z: hfield asset parameters.
grid_density: number of hfield grid points per unit length of the actual
floor. For example, if grid density == 10, the number of hfield grid
points in 1x1 square of actual floor is 100.
elevation_z_range: Range of elevation of horizon mountains.
"""
Args:
name: Name of the arena.
dim (tuple or int): Half-length of the actual arena (this is the `radius`).
If a tuple is provided, then it's (radius_x, radius_y).
aesthetic: Aesthetic of the arena.
hfield_elevation_z, hfield_base_z: hfield asset parameters.
grid_density (tuple or int): number of hfield grid points per unit length
of the actual floor. For example, if grid density == 10, the number of
hfield grid points in 1x1 square of actual floor is 100.
If a tuple is provided, it's (density_x, density_y).
elevation_z_range: Range of elevation of horizon mountains.
"""

def _build(self,
name='hills',
dim=20,
dim=(20, 20),
aesthetic='outdoor_natural',
hfield_elevation_z=1,
hfield_base_z=0.05,
grid_density=10,
grid_density=(10, 10),
elevation_z_range=(4., 5.)):
super()._build(name=name)

size = (dim, dim)
if isinstance(dim, tuple):
# Potentially rectangular arena.
size = dim
else:
# Square arena.
size = (dim, dim)
if not isinstance(grid_density, tuple):
grid_density = (grid_density, grid_density)
self._elevation_z_range = elevation_z_range

self._hfield = self._mjcf_root.asset.add(
'hfield',
name='terrain',
nrow=((2 * grid_density * size[0]) // 2) * 2 + 1,
ncol=((2 * grid_density * size[1]) // 2) * 2 + 1,
nrow=((2 * grid_density[1] * size[1]) // 2) * 2 + 1,
ncol=((2 * grid_density[0] * size[0]) // 2) * 2 + 1,
size=size + (hfield_elevation_z, hfield_base_z))

if aesthetic != 'default':

ground_info = locomotion_arenas_assets.get_ground_texture_info(
aesthetic)
sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
Expand Down Expand Up @@ -198,18 +210,39 @@ def _build(self,
name='groundplane',
size=list(size) + [0.5],
material=self._material)

else:

self._ground_texture = self._mjcf_root.asset.add(
'texture',
rgb1=[.2, .3, .4],
rgb2=[.1, .2, .3],
type='2d',
builtin='checker',
name='groundplane',
width=200,
height=200,
mark='edge',
markrgb=[0.8, 0.8, 0.8])
self._ground_material = self._mjcf_root.asset.add(
'material',
name='groundplane',
texrepeat=[2, 2], # Makes white squares exactly 1x1 length units.
texuniform=True,
reflectance=2,
texture=self._ground_texture)

self._terrain_geom = self._mjcf_root.worldbody.add(
'geom',
name='terrain',
type='hfield',
rgba=(0.2, 0.3, 0.4, 1),
pos=(0, 0, -0.01),
hfield='terrain')
hfield='terrain',
material=self._ground_material)
self._ground_geom = self._mjcf_root.worldbody.add(
'geom',
type='plane',
name='groundplane',
pos=(0, 0, -0.01),
rgba=(0.2, 0.3, 0.4, 1),
size=list(size) + [0.5])

Expand Down Expand Up @@ -256,24 +289,24 @@ def ground_geoms(self):
class SineTrench(Hills):
"""A hilly arena.
Args:
name: Name of the arena.
dim: Half-length of the actual arena (this is the `radius`.)
aesthetic: Aesthetic of the arena.
hfield_elevation_z=1, hfield_base_z: hfield asset parameters.
grid_density: number of hfield grid points per unit length of the actual
floor. For example, if grid density == 10, the number of hfield grid
points in 1x1 square of actual floor is 100.
elevation_z_range: Range of elevation of horizon mountains.
start_offset_range: Range of x-offset of trench entrance.
trench_len_range: Range of trench length.
phase_range: Range of sine phase.
wavelength_range: Range of sine wavelength.
amplitude_range: Range of sine amplitude.
width_range: Range of trench width (see implementation how it's calculated).
height_range: Range of trench height.
sigma_range: Range of terrain smoothing stddev.
"""
Args:
name: Name of the arena.
dim: Half-length of the actual arena (this is the `radius`.)
aesthetic: Aesthetic of the arena.
hfield_elevation_z=1, hfield_base_z: hfield asset parameters.
grid_density: number of hfield grid points per unit length of the actual
floor. For example, if grid density == 10, the number of hfield grid
points in 1x1 square of actual floor is 100.
elevation_z_range: Range of elevation of horizon mountains.
start_offset_range: Range of x-offset of trench entrance.
trench_len_range: Range of trench length.
phase_range: Range of sine phase.
wavelength_range: Range of sine wavelength.
amplitude_range: Range of sine amplitude.
width_range: Range of trench width (see implementation how it's calculated).
height_range: Range of trench height.
sigma_range: Range of terrain smoothing stddev.
"""

def _build(self,
name='sine_trench',
Expand Down Expand Up @@ -368,18 +401,18 @@ def trench_specs(self):
class SineBumps(Hills):
"""A hilly arena with sinusoidal bumps.
Args:
name: Name of the arena.
dim: Half-length of the actual arena (this is the `radius`.)
aesthetic: Aesthetic of the arena.
hfield_elevation_z=1, hfield_base_z: hfield asset parameters.
grid_density: number of hfield grid points per unit length of the actual
floor. For example, if grid density == 10, the number of hfield grid
points in 1x1 square of actual floor is 100.
elevation_z_range: Range of elevation of horizon mountains.
phase_range: Range of sine phase.
wavelength_range: Range of sine wavelength.
height_range: Range of sine amplitude.
Args:
name: Name of the arena.
dim: Half-length of the actual arena (this is the `radius`.)
aesthetic: Aesthetic of the arena.
hfield_elevation_z=1, hfield_base_z: hfield asset parameters.
grid_density: number of hfield grid points per unit length of the actual
floor. For example, if grid density == 10, the number of hfield grid
points in 1x1 square of actual floor is 100.
elevation_z_range: Range of elevation of horizon mountains.
phase_range: Range of sine phase.
wavelength_range: Range of sine wavelength.
height_range: Range of sine amplitude.
"""

def _build(self,
Expand Down

0 comments on commit 7802806

Please sign in to comment.