Skip to content

Commit

Permalink
refactor Visualizer to internally use ClassConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Apr 5, 2024
1 parent 4155e4e commit e3fd1ef
Showing 1 changed file with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import matplotlib.pyplot as plt

from rastervision.pipeline.file_system import make_dir
from rastervision.core.data import ClassConfig
from rastervision.pytorch_learner.utils import (
deserialize_albumentation_transform, validate_albumentation_transform,
MinMaxNormalize)
from rastervision.pytorch_learner.learner_config import (
RGBTuple,
ChannelInds,
ensure_class_colors,
validate_channel_display_groups,
get_default_channel_display_groups,
)
Expand Down Expand Up @@ -60,14 +60,21 @@ def __init__(self,
title is a string that will be used as the title of the subplot
for that group.
"""
self.class_names = class_names
self.class_colors = ensure_class_colors(self.class_names, class_colors)
self.class_config = ClassConfig(names=class_names, colors=class_colors)
if transform is None:
transform = A.to_dict(MinMaxNormalize())
self.transform = validate_albumentation_transform(transform)
self._channel_display_groups = validate_channel_display_groups(
channel_display_groups)

@property
def class_names(self):
return self.class_config.names

@property
def class_colors(self):
return self.class_config.colors

@abstractmethod
def plot_xyz(self,
axs,
Expand Down

0 comments on commit e3fd1ef

Please sign in to comment.