Skip to content

Commit

Permalink
first working version
Browse files Browse the repository at this point in the history
  • Loading branch information
sitzmann committed Jan 25, 2022
1 parent 4a1410b commit 2bbf863
Show file tree
Hide file tree
Showing 45 changed files with 579 additions and 266 deletions.
Empty file added __init__.py
Empty file.
28 changes: 28 additions & 0 deletions cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding=utf-8
# directories
input_images_dir = 'input_images/'
media_out_dir = 'media_out/'
data_dir = 'data/'
reference_digits_dir = data_dir + 'reference_digits/'

# dimensions
resize_factor = 20
field_height_m, field_width_m = 100, 37 # dimensions of an actual field, not the tactics board
resolution_x, resolution_y = 1920, 1080
max_num_windows = 5
digit_target_size = 28

#
min_intensity, max_intensity = 0, 255
ksize_sharpening = 0.8
ksize_blur_thresholded = 3

# field detection
ksize_initial_blur = 15
field_detection_poly_epsilon = 150

radius_players_cm = 1
player_radius_lb, player_radius_ub = 0.75, 0.9

# digits
font_size = 0.15 * resize_factor
File renamed without changes.
Binary file added data/reference_digits/1/22-Jan-2022__18-28-33.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/1/22-Jan-2022__18-28-37.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/1/22-Jan-2022__18-30-05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/1/22-Jan-2022__18-30-10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/2/22-Jan-2022__18-28-32.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/2/22-Jan-2022__18-28-38.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/2/22-Jan-2022__18-30-04.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/2/22-Jan-2022__18-30-11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/3/22-Jan-2022__18-28-34.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/3/22-Jan-2022__18-28-39.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/3/22-Jan-2022__18-30-03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/3/22-Jan-2022__18-30-08.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/4/22-Jan-2022__18-28-36.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/4/22-Jan-2022__18-28-40.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/4/22-Jan-2022__18-30-02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/4/22-Jan-2022__18-30-07.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/5/22-Jan-2022__18-28-43.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/5/22-Jan-2022__18-28-44.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/5/22-Jan-2022__18-29-58.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/5/22-Jan-2022__18-29-59.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/6/22-Jan-2022__18-28-47.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/6/22-Jan-2022__18-28-48.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/6/22-Jan-2022__18-29-54.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/6/22-Jan-2022__18-29-56.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/7/22-Jan-2022__18-28-42.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/reference_digits/7/22-Jan-2022__18-28-45.png
Binary file added data/reference_digits/7/22-Jan-2022__18-29-57.png
Binary file added data/reference_digits/7/22-Jan-2022__18-30-00.png
80 changes: 59 additions & 21 deletions digit_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,71 @@
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from glob import glob
import matplotlib.pyplot as plt
import cv2
import cfg


model = None # lazy initialization
def classify_image(image):
global model
if model is None:
model = Net()
state_dict_path = "mnist_cnn.pt"
global_model = None # lazy initialization
global_features = None


def get_model():
global global_model
if global_model is None:
global_model = Net()
state_dict_path = cfg.data_dir + "mnist_cnn.pt"
state_dict = torch.load(state_dict_path)
model.load_state_dict(state_dict)
global_model.load_state_dict(state_dict)
return global_model


def get_reference_features():
global global_features
if global_features is None:
global_features = []
for label in range(1, 8):
file_paths = glob(f'{cfg.reference_digits_dir}/{label}/*.png')
features_i = []
for img_path in file_paths:
img = cv2.imread(img_path)[:, :, 0]
features_i.append(get_features(img).unsqueeze(0))
features_i = torch.cat(features_i, dim=0)
global_features.append(features_i)
global_features = torch.cat([f.unsqueeze(0) for f in global_features])
return global_features


def get_batch(image):
batch = torch.tensor(image, dtype=torch.float32).view(1, 1, 28, 28)
transform = transforms.Normalize((0.1307,), (0.3081,))
batch = transform(batch)
log_probs = model(batch)[0]
prediction = torch.argmax(log_probs[1:8]).item() + 1
entropy = -sum(torch.exp(log_probs) * log_probs)
return batch


def classify_image_by_examples(image, show=False):
features = get_features(image)
reference_features = get_reference_features()
min_mses = ((reference_features - features)**2).mean(axis=2).min(axis=1).values
pred = min_mses.argmin().item() + 1
show_prediction(get_batch(image), pred, min_mses) if show else None
return pred


def show_prediction(batch, prediction, mses):
plt.gray()
plt.imshow(batch.view(28, 28, 1))
plt.gcf().suptitle(str(prediction))
mses = [f'{mse:.2f}' for mse in mses]
plt.gcf().suptitle(f'{prediction} {mses}')
plt.show()
return prediction, entropy




def get_features(image):
model = get_model()
batch = get_batch(image)
pred, features = model(batch)
return features[0]


class Net(nn.Module):
Expand All @@ -56,12 +94,12 @@ def forward(self, x):
x = self.dropout1(x)
x = self.bn2(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
features = self.fc1(x)
x = F.relu(features)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
return output, features


def train(args, model, device, train_loader, optimizer, epoch):
Expand Down Expand Up @@ -143,9 +181,9 @@ def main():
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('data', train=True, download=True,
dataset1 = datasets.MNIST('training_data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('data', train=False,
dataset2 = datasets.MNIST('training_data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
Expand All @@ -160,7 +198,7 @@ def main():
# scheduler.step()

if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
torch.save(model.state_dict(), "data/mnist_cnn.pt")


if __name__ == '__main__':
Expand Down
139 changes: 139 additions & 0 deletions drawer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import cairo
import cv2
import numpy as np
from typing import Optional
import export_scene
import os
import shutil
import state
import cfg


def main():
init_context()
draw_field()
draw_player(10, 0)
show(surface)


def draw_scene(state: state.State):
init_context()
draw_field()
[draw_player(player, (0.6, 0, 0)) for player in state.players_team_1]
[draw_player(player, (0, 0, 0.6)) for player in state.players_team_2]
return surface


def animate_scene(state_1, state_2, num_steps, name='animation'):
output_path = f'{cfg.media_out_dir}{name}'
shutil.rmtree(output_path, ignore_errors=True)
os.makedirs(output_path, exist_ok=True)
for i, frac in enumerate(np.linspace(0, 1, num_steps)):
state = state_1 * (1 - frac) + state_2 * frac
surface = draw_scene(state)
path = f'{output_path}/{i:04d}.png'
surface.write_to_png(path)
export_scene.save_video(output_path, fps=10)
shutil.rmtree(output_path)


# globals
width, height, endzone_height = 37, 100, 18
border = 3
scale = 8
ctx: Optional[cairo.Context] = None
surface: Optional[cairo.ImageSurface] = None


def draw_field():
draw_background()
ctx.move_to(*m2p([border, border]))
select_line_brush()
path = [[width, 0], [0, height], [-width, 0], [0, -height]]
for line in path:
rel_line(*line)
for y in [endzone_height, height - endzone_height]:
move_to(0, y)
rel_line(width, 0)
ctx.stroke()


def draw_background():
ctx.set_source_rgb(0.3, 0.5, 0.3)
ctx.move_to(*m2p([0, 0]))
w = width + 2 * border
h = height + 2 * border
path = [[w, 0], [0, h], [-w, 0], [0, -h]]
for line in path:
rel_line(*line)
ctx.fill()

@np.vectorize
def m2p(x, rounded=True):
x = scale * x
if rounded:
x = int(x)
return x


def move_to(x, y):
ctx.move_to(*m2p([border + x, border + y]))


def rel_line(x, y):
ctx.rel_line_to(*m2p([x, y]))


def show(surface, filename='temp', wait=10):
os.makedirs(cfg.media_out_dir, exist_ok=True)
path = f'{cfg.media_out_dir}/{filename}.png'
surface.write_to_png(path)
cv2.imshow(filename, cv2.imread(path))
window_x = cfg.resolution_x - 50 - m2p(width + 2*border)
cv2.moveWindow(filename, window_x, 0)
cv2.waitKey(wait)


def select_line_brush():
ctx.set_source_rgb(0.7, 0.7, 0.7) # Solid color
ctx.set_line_width(m2p(0.3))


def init_context():
global ctx, surface
size_meters = np.array([width + 2 * border, height + 2 * border])
size_pixels = m2p(size_meters / 2, rounded=True) * 2
surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, *size_pixels)
ctx = cairo.Context(surface)


def draw_player(player: state.Player, color):
ctx.set_line_width(0)
ctx.set_source_rgb(*color)
radius = m2p(1, rounded=False)
move_to(*player.pos)
ctx.arc(*ctx.get_current_point(), radius, 0, 2 * np.pi)
ctx.fill()
ctx.set_source_rgb(0.8, 0.8, 0.8)
ctx.select_font_face("Serif", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_BOLD)
ctx.set_font_size(m2p(2))
pos = [m2p(player.pos[0] + border), m2p(player.pos[1] + border)]
text(player.label, pos, player.orientation)


def text(string, pos, orientation):
ctx.save()
fascent, fdescent, fheight, fxadvance, fyadvance = ctx.font_extents()
x_off, y_off, tw, th = ctx.text_extents(string)[:4]
nx = -tw/2
ny = fheight/2
ctx.translate(pos[0], pos[1])
ctx.rotate(orientation / -180 * np.pi)
ctx.translate(nx, ny)
ctx.move_to(0, -4)
ctx.show_text(string)
ctx.restore()


if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions export_scene.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import moviepy.video.io.ImageSequenceClip
from glob import glob


def save_video(img_folder, fps=30):
image_files = sorted(glob(f'{img_folder}/*.png'))
clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)
clip.write_videofile(f'{img_folder}.mp4')

Binary file added input_images/disc.jpg
Binary file added input_images/ho-stack-1.jpg
Binary file added input_images/ho-stack-2.jpg
Binary file added input_images/old_boards/first_board.jpg
Binary file added input_images/old_boards/second_board.jpg
28 changes: 28 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import scan
import drawer
images_dir = 'input_images/'


def main():
for image_path in [images_dir + 'ho-stack-1.jpg', images_dir + 'ho-stack-2.jpg']:
# image_path = images_dir + 'disc.jpg'
show_digits = False
show_circles = True
record_examples = False
player_positions = scan.scan(image_path, show_digits, show_circles, record_examples)
surface = drawer.draw_scene(player_positions)
drawer.show(surface, wait=0)


def animate():
img_1, img_2 = images_dir + 'ho-stack-1.jpg', images_dir + 'ho-stack-2.jpg'
show_circles = False
show_digits = False
state_1 = scan.scan(img_1, show_digits, show_circles)
state_2 = scan.scan(img_2, show_digits, show_circles)
drawer.animate_scene(state_1, state_2, 30)


if __name__ == '__main__':
main()
# animate()
Binary file added media_out/animation.mp4
Binary file not shown.
Binary file added media_out/temp.png
Loading

0 comments on commit 2bbf863

Please sign in to comment.