From 44a32567236e3962ea0429be8182b1a0c6bc274b Mon Sep 17 00:00:00 2001
From: Sanketh Varamballi
Date: Tue, 18 Oct 2022 11:36:14 -0400
Subject: [PATCH] Added static typing to *_data classes in data_readers (#677)

* added static typing for csv_data
* added static typing for graph_data
* added static typing to parquet_data
* added static typing for text_data
* removed if statement
* changed repeated conditionals to single assert
* fixed formatting
* changed data_type from Optional[str] to str
* removed extra casts
* removed cast to self.delimiter
* cleaned up omitted list
* changed base_data test to work with new static typing
* removed IO casts in parquet_data
* removed options cast in text_data
* fixed pre-commit failure

Co-authored-by: Taylor Turner
---
 dataprofiler/data_readers/avro_data.py     |   2 +-
 dataprofiler/data_readers/base_data.py     |   2 +-
 dataprofiler/data_readers/csv_data.py      | 104 +++++++++------
 dataprofiler/data_readers/graph_data.py    | 118 +++++++++++-------
 dataprofiler/data_readers/json_data.py     |   2 +-
 dataprofiler/data_readers/parquet_data.py  |  52 +++++---
 dataprofiler/data_readers/text_data.py     |  38 ++++--
 .../tests/data_readers/test_base_data.py   |   3 +-
 8 files changed, 203 insertions(+), 118 deletions(-)

diff --git a/dataprofiler/data_readers/avro_data.py b/dataprofiler/data_readers/avro_data.py
index cafb97e5c..97a517546 100644
--- a/dataprofiler/data_readers/avro_data.py
+++ b/dataprofiler/data_readers/avro_data.py
@@ -14,7 +14,7 @@
 class AVROData(JSONData, BaseData):
     """AVROData class to save and load spreadsheet data."""

-    data_type: Optional[str] = "avro"
+    data_type: str = "avro"

     def __init__(
         self,
diff --git a/dataprofiler/data_readers/base_data.py b/dataprofiler/data_readers/base_data.py
index 97bda59f8..e9d2c69d8 100644
--- a/dataprofiler/data_readers/base_data.py
+++ b/dataprofiler/data_readers/base_data.py
@@ -17,7 +17,7 @@
 class BaseData(object):
     """Abstract class for data loading and saving."""

-    data_type: Optional[str] = None
+    data_type: str
     info: Optional[str] = None

     def __init__(
diff --git a/dataprofiler/data_readers/csv_data.py b/dataprofiler/data_readers/csv_data.py
index 1467509c8..aa52ba8f0 100644
--- a/dataprofiler/data_readers/csv_data.py
+++ b/dataprofiler/data_readers/csv_data.py
@@ -3,8 +3,10 @@
 import random
 import re
 from collections import Counter
+from typing import Dict, List, Optional, Tuple, Union, cast

 import numpy as np
+import pandas as pd
 from six import StringIO

 from . import data_utils
@@ -18,9 +20,14 @@
 class CSVData(SpreadSheetDataMixin, BaseData):
     """SpreadsheetData class to save and load spreadsheet data."""

-    data_type = "csv"
+    data_type: str = "csv"

-    def __init__(self, input_file_path=None, data=None, options=None):
+    def __init__(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[pd.DataFrame] = None,
+        options: Optional[Dict] = None,
+    ):
         """
         Initialize Data class for loading datasets of type CSV.

@@ -71,15 +78,15 @@ def __init__(self, input_file_path=None, data=None, options=None):
         # _selected_columns: columns being selected from the entire dataset
         # _header: any information pertaining to the file header.
         self._data_formats["records"] = self._get_data_as_records
-        self.SAMPLES_PER_LINE_DEFAULT = options.get("record_samples_per_line", 1)
-        self._selected_data_format = options.get("data_format", "dataframe")
-        self._delimiter = options.get("delimiter", None)
-        self._quotechar = options.get("quotechar", None)
-        self._selected_columns = options.get("selected_columns", list())
-        self._header = options.get("header", "auto")
-        self._checked_header = "header" in options and self._header != "auto"
-        self._default_delimiter = ","
-        self._default_quotechar = '"'
+        self.SAMPLES_PER_LINE_DEFAULT: int = options.get("record_samples_per_line", 1)
+        self._selected_data_format: str = options.get("data_format", "dataframe")
+        self._delimiter: Optional[str] = options.get("delimiter", None)
+        self._quotechar: Optional[str] = options.get("quotechar", None)
+        self._selected_columns: List[str] = options.get("selected_columns", list())
+        self._header: Optional[Union[str, int]] = options.get("header", "auto")
+        self._checked_header: bool = "header" in options and self._header != "auto"
+        self._default_delimiter: str = ","
+        self._default_quotechar: str = '"'

         if data is not None:
             self._load_data(data)
@@ -89,31 +96,32 @@ def __init__(self, input_file_path=None, data=None, options=None):
             self._quotechar = self._default_quotechar

     @property
-    def selected_columns(self):
+    def selected_columns(self) -> List[str]:
         """Return selected columns."""
         return self._selected_columns

     @property
-    def delimiter(self):
+    def delimiter(self) -> Optional[str]:
         """Return delimiter."""
         return self._delimiter

     @property
-    def quotechar(self):
+    def quotechar(self) -> Optional[str]:
         """Return quotechar."""
         return self._quotechar

     @property
-    def header(self):
+    def header(self) -> Optional[Union[str, int]]:
         """Return header."""
         return self._header

     @property
-    def is_structured(self):
+    def is_structured(self) -> bool:
         """Determine compatibility with StructuredProfiler."""
         return self.data_format == "dataframe"

-    def _check_and_return_options(self, options):
+    @staticmethod
+    def _check_and_return_options(options: Optional[Dict]) -> Dict:
         """
         Ensure options are valid inputs to the data reader.

@@ -121,7 +129,7 @@ def _check_and_return_options(self, options):
         :type options: dict
         :return: None
         """
-        options = super()._check_and_return_options(options)
+        options = super(CSVData, CSVData)._check_and_return_options(options)

         if "header" in options:
             value = options["header"]
@@ -164,8 +172,11 @@ def _check_and_return_options(self, options):

     @staticmethod
     def _guess_delimiter_and_quotechar(
-        data_as_str, quotechar=None, preferred=[",", "\t"], omitted=['"', "'"]
-    ):
+        data_as_str: str,
+        quotechar: Optional[str] = None,
+        preferred: List[str] = [",", "\t"],
+        omitted: List[str] = ['"', "'"],
+    ) -> Tuple[Optional[str], Optional[str]]:
         r"""
         Automatically check for what delimiter exists in a text document.

@@ -186,7 +197,10 @@ def _guess_delimiter_and_quotechar(
         vocab = Counter(data_as_str)
         if "\n" in vocab:
             vocab.pop("\n")
-        for char in omitted + [quotechar]:
+        omitted_list: list[str] = omitted
+        if quotechar is not None:
+            omitted_list = omitted + [quotechar]
+        for char in omitted_list:
             if char in vocab:
                 vocab.pop(char)

@@ -320,13 +334,13 @@ def _guess_delimiter_and_quotechar(

     @staticmethod
     def _guess_header_row(
-        data_as_str,
-        suggested_delimiter=None,
-        suggested_quotechar=None,
-        diff_thresh=0.1,
-        none_thresh=0.5,
-        str_thresh=0.9,
-    ):
+        data_as_str: str,
+        suggested_delimiter: Optional[str] = None,
+        suggested_quotechar: Optional[str] = None,
+        diff_thresh: float = 0.1,
+        none_thresh: float = 0.5,
+        str_thresh: float = 0.9,
+    ) -> Optional[int]:
         r"""
         Attempt to select the best row for which a header would be valid.

@@ -359,7 +373,7 @@ def _guess_header_row(
             quotechar = '"'

         # Determine type for every cell
-        header_check_list = []
+        header_check_list: List[List[str]] = []
         only_string_flag = True  # Requires additional checks

         for row in data_as_str.split("\n"):
@@ -378,7 +392,7 @@ def _guess_header_row(

         # Flags differences in types between each row (true/false)
         potential_header = header_check_list[0]
-        differences = []
+        differences: List[List[bool]] = []
         for i in range(0, len(header_check_list)):
             differences.append([])

@@ -517,7 +531,7 @@ def _guess_header_row(

         return row_classic_header_ends

-    def _load_data_from_str(self, data_as_str):
+    def _load_data_from_str(self, data_as_str: str) -> pd.DataFrame:
         """Load the data into memory from the str."""
         delimiter, quotechar = None, None
         if not self._delimiter or not self._quotechar:
@@ -535,12 +549,12 @@ def _load_data_from_str(self, data_as_str):
         return data_utils.read_csv_df(
             data_buffered,
             self.delimiter,
-            self.header,
+            cast(Optional[int], self.header),
             self.selected_columns,
             read_in_string=True,
         )

-    def _load_data_from_file(self, input_file_path):
+    def _load_data_from_file(self, input_file_path: str) -> pd.DataFrame:
         """Load the data into memory from the file."""
         data_as_str = data_utils.load_as_str_from_file(
             input_file_path, self.file_encoding
@@ -556,8 +570,11 @@ def _load_data_from_file(self, input_file_path):
             self._quotechar = quotechar

         if self._header == "auto":
-            self._header = self._guess_header_row(
-                data_as_str, self._delimiter, self._quotechar
+            self._header = cast(
+                int,
+                self._guess_header_row(
+                    data_as_str, self._delimiter, self._quotechar
+                ),
             )
             self._checked_header = True

@@ -581,13 +598,13 @@ def _load_data_from_file(self, input_file_path):
         return data_utils.read_csv_df(
             input_file_path,
             self.delimiter,
-            self.header,
+            cast(Optional[int], self.header),
             self.selected_columns,
             read_in_string=True,
-            encoding=self.file_encoding,
+            encoding=cast(str, self.file_encoding),
         )

-    def _get_data_as_records(self, data):
+    def _get_data_as_records(self, data: pd.DataFrame) -> List[str]:
         """Return data as records."""
         sep = self.delimiter if self.delimiter else self._default_delimiter
         quote = self.quotechar if self.quotechar else self._default_quotechar
@@ -596,7 +613,7 @@ def _get_data_as_records(self, data):
         return super(CSVData, self)._get_data_as_records(data)

     @classmethod
-    def is_match(cls, file_path, options=None):
+    def is_match(cls, file_path: str, options: Optional[Dict] = None) -> bool:
         """
         Check if first 1000 lines of given file has valid delimited format.

@@ -716,7 +733,12 @@ def is_match(cls, file_path, options=None):
         # Assume not a CSV
         return False

-    def reload(self, input_file_path=None, data=None, options=None):
+    def reload(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[pd.DataFrame] = None,
+        options: Optional[Dict] = None,
+    ):
         """
         Reload the data class with a new dataset.

@@ -737,4 +759,4 @@ def reload(self, input_file_path=None, data=None, options=None):
             header=self.header, delimiter=self.delimiter, quotechar=self.quotechar
         )
         super(CSVData, self).reload(input_file_path, data, options)
-        self.__init__(self.input_file_path, data, options)
+        self.__init__(self.input_file_path, data, options)  # type: ignore
diff --git a/dataprofiler/data_readers/graph_data.py b/dataprofiler/data_readers/graph_data.py
index 1709230d9..8fa7b044e 100644
--- a/dataprofiler/data_readers/graph_data.py
+++ b/dataprofiler/data_readers/graph_data.py
@@ -1,5 +1,6 @@
 """Contains class for identifying, reading, and loading graph data."""
 import csv
+from typing import Dict, List, Optional, Union, cast

 import networkx as nx

@@ -12,9 +13,14 @@
 class GraphData(BaseData):
     """GraphData class to identify, read, and load graph data."""

-    data_type = "graph"
+    data_type: str = "graph"

-    def __init__(self, input_file_path=None, data=None, options=None):
+    def __init__(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[nx.Graph] = None,
+        options: Optional[Dict] = None,
+    ) -> None:
         """
         Initialize Data class for identifying, reading, and loading graph data.

@@ -56,26 +62,28 @@ def __init__(self, input_file_path=None, data=None, options=None):
         options = self._check_and_return_options(options)
         BaseData.__init__(self, input_file_path, data, options)

-        self._source_node = options.get("source_node", None)
-        self._destination_node = options.get("destination_node", None)
-        self._target_keywords = options.get(
+        self._source_node: Optional[int] = options.get("source_node", None)
+        self._destination_node: Optional[int] = options.get("destination_node", None)
+        self._target_keywords: List[str] = options.get(
             "target_keywords", ["target", "destination", "dst"]
         )
-        self._source_keywords = options.get(
+        self._source_keywords: List[str] = options.get(
             "source_keywords", ["source", "src", "origin"]
         )
-        self._graph_keywords = options.get("graph_keywords", ["node"])
-        self._column_names = options.get("column_names", None)
-        self._delimiter = options.get("delimiter", None)
-        self._quotechar = options.get("quotechar", None)
-        self._header = options.get("header", "auto")
-        self._checked_header = "header" in options and self._header != "auto"
+        self._graph_keywords: List[str] = options.get("graph_keywords", ["node"])
+        self._column_names: Optional[List[str]] = options.get("column_names", None)
+        self._delimiter: Optional[str] = options.get("delimiter", None)
+        self._quotechar: Optional[str] = options.get("quotechar", None)
+        self._header: Optional[Union[str, int]] = options.get("header", "auto")
+        self._checked_header: bool = "header" in options and self._header != "auto"

         if data is not None:
             self._load_data(data)

     @classmethod
-    def _find_target_string_in_column(self, column_names, keyword_list):
+    def _find_target_string_in_column(
+        self, column_names: List[str], keyword_list: List[str]
+    ) -> int:
         """Find out if col name contains keyword that could refer to target node col."""
         column_name_symbols = ["_", ".", "-"]
         has_target = False
@@ -102,9 +110,15 @@ def _find_target_string_in_column(self, column_names, keyword_list):
         return target_index

     @classmethod
-    def csv_column_names(cls, file_path, header, delimiter, encoding="utf-8"):
+    def csv_column_names(
+        cls,
+        file_path: str,
+        header: Optional[int],
+        delimiter: Optional[str],
+        encoding: str = "utf-8",
+    ) -> List[str]:
         """Fetch a list of column names from the csv file."""
-        column_names = []
+        column_names: List[str] = []
         if delimiter is None:
             delimiter = ","
         if header is None:
@@ -117,10 +131,9 @@ def csv_column_names(cls, file_path, header, delimiter, encoding="utf-8"):
             row_count = 0
             for row in csv_reader:
                 if row_count is header:
-                    column_names.append(row)
+                    column_names = row
                     break
                 row_count += 1
-        column_names = column_names[0]

         # replace all whitespaces in the column names
         for index in range(0, len(column_names)):
@@ -128,7 +141,7 @@ def csv_column_names(cls, file_path, header, delimiter, encoding="utf-8"):
         return column_names

     @classmethod
-    def is_match(cls, file_path, options=None):
+    def is_match(cls, file_path: str, options: Optional[Dict] = None) -> bool:
         """
         Determine whether the file is a graph.

@@ -141,24 +154,32 @@ def is_match(cls, file_path, options=None):
             options = dict()
         if not CSVData.is_match(file_path, options):
             return False
-        header = options.get("header", 0)
-        delimiter = options.get("delimiter", ",")
-        encoding = options.get("encoding", "utf-8")
-        column_names = cls.csv_column_names(file_path, header, delimiter, encoding)
-        source_keywords = options.get("source_keywords", ["source", "src", "origin"])
-        target_keywords = options.get(
+        header: int = options.get("header", 0)
+        delimiter: str = options.get("delimiter", ",")
+        encoding: str = options.get("encoding", "utf-8")
+        column_names: List[str] = cls.csv_column_names(
+            file_path, header, delimiter, encoding
+        )
+        source_keywords: List[str] = options.get(
+            "source_keywords", ["source", "src", "origin"]
+        )
+        target_keywords: List[str] = options.get(
             "target_keywords", ["target", "destination", "dst"]
         )
-        graph_keywords = options.get("graph_keywords", ["node"])
-        source_index = cls._find_target_string_in_column(column_names, source_keywords)
-        destination_index = cls._find_target_string_in_column(
+        graph_keywords: List[str] = options.get("graph_keywords", ["node"])
+        source_index: int = cls._find_target_string_in_column(
+            column_names, source_keywords
+        )
+        destination_index: int = cls._find_target_string_in_column(
             column_names, target_keywords
         )
-        graph_index = cls._find_target_string_in_column(column_names, graph_keywords)
+        graph_index: int = cls._find_target_string_in_column(
+            column_names, graph_keywords
+        )

-        has_source = True if source_index >= 0 else False
-        has_target = True if destination_index >= 0 else False
-        has_graph_data = True if graph_index >= 0 else False
+        has_source: bool = True if source_index >= 0 else False
+        has_target: bool = True if destination_index >= 0 else False
+        has_graph_data: bool = True if graph_index >= 0 else False

         if has_target and has_source and has_graph_data:
             options.update(source_node=source_index)
@@ -169,9 +190,10 @@ def is_match(cls, file_path, options=None):
             return True
         return False

-    def _format_data_networkx(self):
+    def _format_data_networkx(self) -> nx.Graph:
         """Format the input file into a networkX graph."""
         networkx_graph = nx.Graph()
+        assert self.input_file_path is not None

         # read lines from csv
         if not self._checked_header or not self._delimiter:
@@ -212,23 +234,27 @@ def _format_data_networkx(self):
         if count_delimiter_last == num_lines_read:
             self._delimiter = None

-        if self._column_names is None:
+        if (
+            self._column_names is None
+            and isinstance(self._header, int)
+            and self.file_encoding is not None
+        ):
             self._column_names = self.csv_column_names(
                 self.input_file_path, self._header, self._delimiter, self.file_encoding
             )
-        if self._source_node is None:
+        if self._source_node is None and self._column_names is not None:
             self._source_node = self._find_target_string_in_column(
                 self._column_names, self._source_keywords
             )
-        if self._destination_node is None:
+        if self._destination_node is None and self._column_names is not None:
             self._destination_node = self._find_target_string_in_column(
                 self._column_names, self._target_keywords
             )

         data_as_pd = data_utils.read_csv_df(
-            self.input_file_path,
+            cast(str, self.input_file_path),
             self._delimiter,
-            self._header,
+            cast(Optional[int], self._header),
             [],
             read_in_string=True,
             encoding=self.file_encoding,
@@ -237,33 +263,33 @@ def _format_data_networkx(self):
         csv_as_list = data_as_pd.values.tolist()

         # grab list of edges from source/dest nodes
-        for line in range(0, len(csv_as_list)):
+        for line_index in range(0, len(csv_as_list)):
             # fetch attributes in columns
             attributes = dict()
             for column in range(0, len(csv_as_list[0])):
-                if csv_as_list[line][column] is None:
+                if csv_as_list[line_index][column] is None:
                     continue
                 if (
                     column is not self._source_node
                     or column is not self._destination_node
-                ):
+                ) and self._column_names is not None:
                     attributes[self._column_names[column]] = float(
-                        csv_as_list[line][column]
+                        csv_as_list[line_index][column]
                     )
                 elif column is self._source_node or column is self._destination_node:
                     networkx_graph.add_node(
-                        self.check_integer(csv_as_list[line][column])
+                        self.check_integer(csv_as_list[line_index][column])
                     )
             networkx_graph.add_edge(
-                self.check_integer(csv_as_list[line][self._source_node]),
-                self.check_integer(csv_as_list[line][self._destination_node]),
+                self.check_integer(csv_as_list[line_index][self._source_node]),
+                self.check_integer(csv_as_list[line_index][self._destination_node]),
                 **attributes
             )

         # get NetworkX object from list
         return networkx_graph

-    def _load_data(self, data=None):
+    def _load_data(self, data: Optional[nx.Graph] = None) -> nx.Graph:
         if data is not None:
             if not isinstance(data, nx.Graph):
                 raise ValueError("Only NetworkX Graph objects allowed as input data.")
@@ -271,7 +297,7 @@ def _load_data(self, data=None):
         else:
             self._data = self._format_data_networkx()

-    def check_integer(self, string):
+    def check_integer(self, string: str) -> Union[int, str]:
         """Check whether string is integer and output integer."""
         stringVal = string
         if string[0] == ("-", "+"):
diff --git a/dataprofiler/data_readers/json_data.py b/dataprofiler/data_readers/json_data.py
index c2e5ec22a..607f07857 100644
--- a/dataprofiler/data_readers/json_data.py
+++ b/dataprofiler/data_readers/json_data.py
@@ -18,7 +18,7 @@
 class JSONData(SpreadSheetDataMixin, BaseData):
     """SpreadsheetData class to save and load spreadsheet data."""

-    data_type: Optional[str] = "json"
+    data_type: str = "json"

     def __init__(
         self,
diff --git a/dataprofiler/data_readers/parquet_data.py b/dataprofiler/data_readers/parquet_data.py
index 5d92e7cb7..3fdb26721 100644
--- a/dataprofiler/data_readers/parquet_data.py
+++ b/dataprofiler/data_readers/parquet_data.py
@@ -1,4 +1,8 @@
 """Contains class to save and load parquet data."""
+from io import BytesIO, StringIO
+from typing import Any, Dict, List, Optional, Union
+
+import pandas as pd
 import pyarrow.parquet as pq

 from . import data_utils
@@ -9,9 +13,14 @@
 class ParquetData(SpreadSheetDataMixin, BaseData):
     """SpreadsheetData class to save and load parquet data."""

-    data_type = "parquet"
+    data_type: str = "parquet"

-    def __init__(self, input_file_path=None, data=None, options=None):
+    def __init__(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[Union[pd.DataFrame, str]] = None,
+        options: Optional[Dict] = None,
+    ):
         """
         Initialize Data class for loading datasets of type PARQUET.

@@ -50,28 +59,36 @@ def __init__(self, input_file_path=None, data=None, options=None):
         # _selected_columns: columns being selected from the entire dataset
         self._data_formats["records"] = self._get_data_as_records
         self._data_formats["json"] = self._get_data_as_json
-        self._selected_data_format = options.get("data_format", "dataframe")
-        self._selected_columns = options.get("selected_columns", list())
+        self._selected_data_format: str = options.get("data_format", "dataframe")
+        self._selected_columns: List[str] = options.get("selected_columns", list())

         if data is not None:
             self._load_data(data)

     @property
-    def file_encoding(self):
+    def file_encoding(self) -> None:
         """Set file encoding to None since not detected for avro."""
         return None

+    @file_encoding.setter
+    def file_encoding(self, value: Any) -> None:
+        """Do nothing.
+
+        Required by mypy because the inherited self.file_encoding is read-write.
+        """
+        pass
+
     @property
-    def selected_columns(self):
+    def selected_columns(self) -> List[str]:
         """Return selected columns."""
         return self._selected_columns

     @property
-    def is_structured(self):
+    def is_structured(self) -> bool:
         """Determine compatibility with StructuredProfiler."""
         return self.data_format == "dataframe"

-    def _load_data_from_str(self, data_as_str):
+    def _load_data_from_str(self, data_as_str: str) -> pd.DataFrame:
         """Return data from string."""
         data_generator = data_utils.data_generator(data_as_str.splitlines())
         data, original_df_dtypes = data_utils.read_json_df(
@@ -80,7 +97,7 @@ def _load_data_from_str(self, data_as_str):
         self._original_df_dtypes = original_df_dtypes
         return data

-    def _load_data_from_file(self, input_file_path):
+    def _load_data_from_file(self, input_file_path: str) -> pd.DataFrame:
         """Return data from file."""
         data, original_df_dtypes = data_utils.read_parquet_df(
             input_file_path, self.selected_columns, read_in_string=True
@@ -88,21 +105,23 @@ def _load_data_from_file(self, input_file_path):
         self._original_df_dtypes = original_df_dtypes
         return data

-    def _get_data_as_records(self, data):
+    def _get_data_as_records(self, data: pd.DataFrame) -> List[str]:
         """Return data records."""
         # split into row samples separate by `\n`
         data = data.to_json(orient="records", lines=True)
         data = data.splitlines()
         return super(ParquetData, self)._get_data_as_records(data)

-    def _get_data_as_json(self, data):
+    def _get_data_as_json(self, data: pd.DataFrame) -> List[str]:
         """Return json data."""
         data = data.to_json(orient="records")
         chars_per_line = min(len(data), self.SAMPLES_PER_LINE_DEFAULT)
         return list(map("".join, zip(*[iter(data)] * chars_per_line)))

     @classmethod
-    def is_match(cls, file_path, options=None):
+    def is_match(
+        cls, file_path: Union[str, StringIO, BytesIO], options: Optional[Dict] = None
+    ) -> bool:
         """
         Test the given file to check if the file has valid Parquet format.

@@ -132,7 +151,12 @@ def is_match(cls, file_path, options=None):

         return is_valid_parquet

-    def reload(self, input_file_path=None, data=None, options=None):
+    def reload(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Any = None,
+        options: Optional[Dict] = None,
+    ) -> None:
         """
         Reload the data class with a new dataset.

@@ -148,4 +172,4 @@ def reload(self, input_file_path=None, data=None, options=None):
         :return: None
         """
         super(ParquetData, self).reload(input_file_path, data, options)
-        self.__init__(self.input_file_path, data, options)
+        self.__init__(self.input_file_path, data, options)  # type: ignore
diff --git a/dataprofiler/data_readers/text_data.py b/dataprofiler/data_readers/text_data.py
index 90f3c97cb..3452bcc94 100644
--- a/dataprofiler/data_readers/text_data.py
+++ b/dataprofiler/data_readers/text_data.py
@@ -3,6 +3,7 @@
 from __future__ import print_function

 from io import StringIO
+from typing import Dict, List, Optional, Union, cast

 from past.builtins import basestring

@@ -13,9 +14,14 @@
 class TextData(BaseData):
     """TextData class to save and load text files."""

-    data_type = "text"
+    data_type: str = "text"

-    def __init__(self, input_file_path=None, data=None, options=None):
+    def __init__(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[List[str]] = None,
+        options: Optional[Dict] = None,
+    ) -> None:
         """
         Initialize Data class for loading datasets of type TEXT.

@@ -56,32 +62,32 @@ def __init__(self, input_file_path=None, data=None, options=None):
         # _delimiter: delimiter used to decipher the csv input file
         # _selected_columns: columns being selected from the entire dataset
         self._data_formats["text"] = self._get_data_as_text
-        self._selected_data_format = options.get("data_format", "text")
-        self._samples_per_line = options.get("samples_per_line", int(5e9))
+        self._selected_data_format: str = options.get("data_format", "text")
+        self._samples_per_line: int = options.get("samples_per_line", int(5e9))

         if data is not None:
             self._load_data(data)

     @property
-    def samples_per_line(self):
+    def samples_per_line(self) -> int:
         """Return samples per line."""
         return self._samples_per_line

     @property
-    def is_structured(self):
+    def is_structured(self) -> bool:
         """Determine compatibility with StructuredProfiler."""
         return False

-    def _load_data(self, data=None):
+    def _load_data(self, data: Optional[List[str]] = None) -> None:
         """Load data."""
         if data is not None:
             self._data = data
         else:
             self._data = data_utils.read_text_as_list_of_strs(
-                self.input_file_path, self.file_encoding
+                cast(str, self.input_file_path), self.file_encoding
             )

-    def _get_data_as_text(self, data):
+    def _get_data_as_text(self, data: Union[str, List[str]]) -> List[str]:
         """Return data as text."""
         if (
             isinstance(data, list)
@@ -94,6 +100,7 @@ def _get_data_as_text(self, data):
                 "Data is not in a str or list of str format and cannot be "
                 "converted."
             )
+        data = cast(str, data)
         samples_per_line = min(max(len(data), 1), self.samples_per_line)
         data = [
             data[i * samples_per_line : (i + 1) * samples_per_line]
@@ -101,12 +108,12 @@ def _get_data_as_text(self, data):
         ]
         return data

-    def tokenize(self):
+    def tokenize(self) -> None:
         """Tokenize data."""
         raise NotImplementedError("Tokenizing does not currently exist for text data.")

     @classmethod
-    def is_match(cls, file_path, options=None):
+    def is_match(cls, file_path: str, options: Optional[Dict] = None) -> bool:
         """
         Return True if all are text files.

@@ -125,7 +132,12 @@ def is_match(cls, file_path, options=None):
             options = {"encoding": data_utils.detect_file_encoding(file_path)}
         return True

-    def reload(self, input_file_path=None, data=None, options=None):
+    def reload(
+        self,
+        input_file_path: Optional[str] = None,
+        data: Optional[List[str]] = None,
+        options: Optional[Dict] = None,
+    ) -> None:
         """
         Reload the data class with a new dataset.

@@ -141,4 +153,4 @@ def reload(self, input_file_path=None, data=None, options=None):
         :return: None
         """
         super(TextData, self).reload(input_file_path, data, options)
-        self.__init__(self.input_file_path, data, options)
+        self.__init__(self.input_file_path, data, options)  # type: ignore
diff --git a/dataprofiler/tests/data_readers/test_base_data.py b/dataprofiler/tests/data_readers/test_base_data.py
index e492db588..3f28d37c1 100644
--- a/dataprofiler/tests/data_readers/test_base_data.py
+++ b/dataprofiler/tests/data_readers/test_base_data.py
@@ -75,6 +75,7 @@ def test_can_apply_data_functions(self):
         class FakeDataClass:
             # matches the `data_type` value in BaseData for validating priority
             data_type = "FakeData"
+            options = {"not_empty": "data"}

             def func1(self):
                 return "success"
@@ -93,7 +94,7 @@ def func1(self):
             data.test

         # validate it will take BaseData attribute over the data attribute
-        self.assertIsNone(data.data_type)
+        self.assertFalse(data.options)

         # validate will auto call the data function if it doesn't exist in
         # BaseData
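
As a quick illustration of what the annotations above buy callers, the following minimal sketch (not part of the patch) exercises the typed CSVData reader so a checker such as mypy can verify the call site. The file name "example.csv" and the option values are illustrative assumptions; the import path and property types come from the diff.

# Hypothetical usage sketch of the typed CSVData API; "example.csv" and the
# option values are assumptions for illustration only.
from typing import Dict, List, Optional, Union

from dataprofiler.data_readers.csv_data import CSVData

options: Dict = {"delimiter": ",", "header": 0, "selected_columns": ["id", "name"]}
reader = CSVData(input_file_path="example.csv", options=options)

# The properties now carry the annotations added in this patch.
delimiter: Optional[str] = reader.delimiter
header: Optional[Union[str, int]] = reader.header
columns: List[str] = reader.selected_columns
is_structured: bool = reader.is_structured  # True when data_format == "dataframe"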