Skip to content

Commit

Permalink
CLI based init
Browse files Browse the repository at this point in the history
  • Loading branch information
jdvin committed Jan 6, 2025
1 parent 7fb03ed commit 10bf047
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 49 deletions.
12 changes: 12 additions & 0 deletions fs/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
grey = (88, 110, 117)
magenta = (211, 54, 130)

COLOURS = {
"yellow": yellow,
"beige": beige,
"darkbeige": darkbeige,
"orange": orange,
"blue": blue,
"red": red,
"green": green,
"grey": grey,
"magenta": magenta,
}


class Particle:
def __init__(self, x, y):
Expand Down
162 changes: 113 additions & 49 deletions fs/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import argparse
import random

import pygame
import sys
import numpy as np
from elements import ELEMENTS, Particle, Metal, Water, Sand, Acid
from elements import COLOURS, ELEMENTS, Particle, Metal, Water, Sand, Acid
from utils import bezier


@dataclass
class Config:
width: int = 400
height: int = 400
ms_per_frame: float = 1000 / 20 # 20 fps. Set to 0 to run as fast as possible.
scale: int = 2
aircolor: tuple[int, int, int] = (0, 0, 0)
width: int
height: int
ms_per_frame: float
scale: int
aircolor: tuple[int, int, int]


@dataclass
Expand All @@ -35,9 +36,9 @@ class PenStroke:

@dataclass
class SimulationConfig(Config):
data_path: str = "data"
max_frames: int = 1000
n_strokes: int = 5
data_path: str
max_frames: int
n_strokes: int


class Renderer(ABC):
Expand Down Expand Up @@ -71,6 +72,26 @@ def draw(self, state: dict[tuple[int, int], Particle], config: Config):
pygame.display.update()


class ReplayRenderer(Renderer):
def __init__(self, config: SimulationConfig):
self.frames = np.load(config.data_path + "/frames.npy")
self.frame_idx = 0
pygame.init()
self.window = pygame.display.set_mode((config.width, config.height))
pygame.display.set_caption("Falling Sand Replay")

def draw(self, state: dict[tuple[int, int], Particle], config: Config):
if self.frame_idx >= len(self.frames):
pygame.quit()
sys.exit()

frame = self.frames[self.frame_idx]
surface = pygame.surfarray.make_surface(frame)
self.window.blit(surface, (0, 0))
pygame.display.update()
self.frame_idx += 1


class SimulationRenderer(Renderer):
def __init__(self, config: SimulationConfig):
self.window = np.memmap(
Expand Down Expand Up @@ -157,6 +178,14 @@ def update(self, state: dict[tuple[int, int], Particle]):
self.active_element = Acid


class DummyInputHandler(InputHandler):
def __init__(self, config: Config):
self.config = config

def update(self, state: dict[tuple[int, int], Particle]):
pass


class SimulationInputHandler(InputHandler):
def __init__(self, config: SimulationConfig):
self.n_strokes = config.n_strokes
Expand Down Expand Up @@ -218,42 +247,6 @@ def update(self, state: dict[tuple[int, int], Particle]):
self.action_idx += 1


def replay(path: str):
frames = np.load(path + "/frames.npy")
actions = np.load(path + "/actions.npy")
config = SimulationConfig()

pygame.init()
window = pygame.display.set_mode((config.width, config.height))
pygame.display.set_caption("Falling Sand Replay")
clock = pygame.time.Clock()

frame_idx = 0
frame_time = 0

while frame_idx < len(frames):
frame_time += clock.tick()

for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
return

if frame_time < config.ms_per_frame:
continue

# Convert numpy array to pygame surface
frame = frames[frame_idx]
surface = pygame.surfarray.make_surface(frame)
window.blit(surface, (0, 0))
pygame.display.update()

frame_idx += 1
frame_time = 0

pygame.quit()


@dataclass
class Engine:
config: Config
Expand Down Expand Up @@ -281,11 +274,82 @@ def run(self):
frame_time = 0


def create_arg_parser():
parser = argparse.ArgumentParser(description="Falling Sand Simulation")

parser.add_argument(
"--input-handler",
choices=["pygame", "simulation"],
default="pygame",
help="Select input handler type",
)

parser.add_argument(
"--renderer",
choices=["pygame", "simulation", "replay"],
default="pygame",
help="Select renderer type",
)

# Config options.
parser.add_argument("--width", type=int, default=400, help="Window width")
parser.add_argument("--height", type=int, default=400, help="Window height")
parser.add_argument(
"--ms-per-frame",
type=float,
default=50.0,
help="Milliseconds per frame (default: 50.0 for 20fps)",
)
parser.add_argument("--scale", type=int, default=2, help="Pixel scale factor")
parser.add_argument("--aircolor", type=str, default="black", help="Air color")
parser.add_argument(
"--data-path", default="data", help="Path for saving/loading simulation data"
)
parser.add_argument(
"--max-frames",
type=int,
default=1000,
help="Maximum number of frames for simulation",
)
parser.add_argument(
"--n-strokes", type=int, default=5, help="Number of strokes for simulation"
)

return parser


def main():
config = SimulationConfig()
# input_handler = PygameInputHandler(config)
input_handler = SimulationInputHandler(config)
renderer = PygameRenderer(config)
parser = create_arg_parser()
args = parser.parse_args()

config = SimulationConfig(
width=args.width,
height=args.height,
ms_per_frame=args.ms_per_frame,
scale=args.scale,
aircolor=COLOURS[args.aircolor],
data_path=args.data_path,
max_frames=args.max_frames,
n_strokes=args.n_strokes,
)

renderers = {
"pygame": PygameRenderer,
"simulation": SimulationRenderer,
"replay": ReplayRenderer,
}
renderer = renderers[args.renderer](config)

input_handlers = {
"pygame": PygameInputHandler,
"simulation": SimulationInputHandler,
"dummy": DummyInputHandler,
}
# In replay mode, no input is accepted.
input_handler = input_handlers[
args.input_handler if args.renderer != "replay" else "dummy"
](config)

engine = Engine(config, renderer, input_handler)
engine.run()

Expand Down

0 comments on commit 10bf047

Please sign in to comment.