Skip to content

Commit

Permalink
Clean up of DataFramePipeline class for consistency.
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbloice committed Dec 17, 2018
1 parent 0f7118e commit 0c321ad
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions Augmentor/Pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
import uuid
import warnings
import numpy as np
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor

# NOTE:
# https://pypi.org/project/futures/ mentions:
# The ProcessPoolExecutor class has known (unfixable) problems on Python 2 and
# should not be relied on for mission critical work.

from tqdm import tqdm
from PIL import Image
Expand Down Expand Up @@ -1758,24 +1763,13 @@ def __init__(self, source_dataframe, image_col, category_col, output_directory="
super(DataFramePipeline, self).__init__(source_directory=None,
output_directory=output_directory,
save_format=save_format)
self._populate(source_dataframe,
image_col,
category_col,
output_directory,
save_format)

def _populate(self,
source_dataframe,
image_col,
category_col,
output_directory,
save_format):

self._populate(source_dataframe, image_col, category_col, output_directory, save_format)

def _populate(self, source_dataframe, image_col, category_col, output_directory, save_format):
# Assume we have an absolute path for the output
# Scan the directory that user supplied.
self.augmentor_images, self.class_labels = scan_dataframe(source_dataframe,
image_col,
category_col,
output_directory)
self.augmentor_images, self.class_labels = scan_dataframe(source_dataframe, image_col, category_col, output_directory)

self._check_images(output_directory)

Expand Down Expand Up @@ -1908,7 +1902,6 @@ def sample(self, n):
if r <= operation.probability:
images_to_return = operation.perform_operation(images_to_return)

# Convert to array data
images_to_return = [np.asarray(x) for x in images_to_return]

if self.labels:
Expand Down

0 comments on commit 0c321ad

Please sign in to comment.