Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/src/arrow/dataset/dataset_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,27 @@ inline bool operator==(const SubtreeImpl::Encoded& l, const SubtreeImpl::Encoded
l.partition_expression == r.partition_expression;
}

/// Get fragment scan options of the expected type.
/// \return Fragment scan options if provided on the scan options, else the default
/// options if set, else a default-constructed value. If options are provided
/// but of the wrong type, an error is returned.
template <typename T>
arrow::Result<std::shared_ptr<T>> GetFragmentScanOptions(
const std::string& type_name, ScanOptions* scan_options,
const std::shared_ptr<FragmentScanOptions>& default_options) {
auto source = default_options;
if (scan_options && scan_options->fragment_scan_options) {
source = scan_options->fragment_scan_options;
}
if (!source) {
return std::make_shared<T>();
}
if (source->type_name() != type_name) {
return Status::Invalid("FragmentScanOptions of type ", source->type_name(),
" were provided for scanning a fragment of type ", type_name);
}
return internal::checked_pointer_cast<T>(source);
}

} // namespace dataset
} // namespace arrow
5 changes: 5 additions & 0 deletions cpp/src/arrow/dataset/file_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class ARROW_DS_EXPORT FileSource {
/// \brief Base class for file format implementation
class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
public:
/// Options affecting how this format is scanned.
///
/// The options here can be overridden at scan time.
std::shared_ptr<FragmentScanOptions> default_fragment_scan_options;

virtual ~FileFormat() = default;

/// \brief The name identifying the kind of file format
Expand Down
24 changes: 13 additions & 11 deletions cpp/src/arrow/dataset/file_csv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,11 @@ static inline Result<csv::ConvertOptions> GetConvertOptions(
ARROW_ASSIGN_OR_RAISE(auto column_names,
GetColumnNames(format.parse_options, first_block, pool));

auto convert_options = csv::ConvertOptions::Defaults();
if (scan_options && scan_options->fragment_scan_options &&
scan_options->fragment_scan_options->type_name() == kCsvTypeName) {
auto csv_scan_options = internal::checked_pointer_cast<CsvFragmentScanOptions>(
scan_options->fragment_scan_options);
convert_options = csv_scan_options->convert_options;
}

ARROW_ASSIGN_OR_RAISE(
auto csv_scan_options,
GetFragmentScanOptions<CsvFragmentScanOptions>(
kCsvTypeName, scan_options.get(), format.default_fragment_scan_options));
auto convert_options = csv_scan_options->convert_options;
for (FieldRef ref : scan_options->MaterializedFields()) {
ARROW_ASSIGN_OR_RAISE(auto field, ref.GetOne(*scan_options->dataset_schema));

Expand All @@ -99,8 +96,13 @@ static inline Result<csv::ConvertOptions> GetConvertOptions(
return convert_options;
}

static inline csv::ReadOptions GetReadOptions(const CsvFileFormat& format) {
auto read_options = csv::ReadOptions::Defaults();
static inline Result<csv::ReadOptions> GetReadOptions(
const CsvFileFormat& format, const std::shared_ptr<ScanOptions>& scan_options) {
ARROW_ASSIGN_OR_RAISE(
auto csv_scan_options,
GetFragmentScanOptions<CsvFragmentScanOptions>(
kCsvTypeName, scan_options.get(), format.default_fragment_scan_options));
auto read_options = csv_scan_options->read_options;
// Multithreaded conversion of individual files would lead to excessive thread
// contention when ScanTasks are also executed in multiple threads, so we disable it
// here.
Expand All @@ -112,7 +114,7 @@ static inline Result<std::shared_ptr<csv::StreamingReader>> OpenReader(
const FileSource& source, const CsvFileFormat& format,
const std::shared_ptr<ScanOptions>& scan_options = nullptr,
MemoryPool* pool = default_memory_pool()) {
auto reader_options = GetReadOptions(format);
ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));

util::string_view first_block;
ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/arrow/dataset/file_csv.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override { return NULLPTR; }
};

class ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
public:
/// \brief Per-scan options for CSV fragments
struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
std::string type_name() const override { return kCsvTypeName; }

/// CSV conversion options
csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults();

/// CSV reading options
///
/// Note that use_threads is always ignored.
csv::ReadOptions read_options = csv::ReadOptions::Defaults();
};

} // namespace dataset
Expand Down
37 changes: 37 additions & 0 deletions cpp/src/arrow/dataset/file_csv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ bar)");
ASSERT_EQ(null_count, 1);
}

TEST_P(TestCsvFileFormat, CustomReadOptions) {
auto source = GetFileSource(R"(header_skipped
str
foo
MYNULL
N/A
bar)");
SetSchema({field("str", utf8())});
auto defaults = std::make_shared<CsvFragmentScanOptions>();
defaults->read_options.skip_rows = 1;
format_->default_fragment_scan_options = defaults;
ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
AssertSchemaEqual(opts_->dataset_schema, physical_schema);

{
int64_t rows = 0;
for (auto maybe_batch : Batches(fragment.get())) {
ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
rows += batch->GetColumnByName("str")->length();
}
ASSERT_EQ(rows, 4);
}
{
// These options completely override the default ones
auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
fragment_scan_options->read_options.block_size = 1 << 22;
opts_->fragment_scan_options = fragment_scan_options;
int64_t rows = 0;
for (auto maybe_batch : Batches(fragment.get())) {
ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
rows += batch->GetColumnByName("header_skipped")->length();
}
ASSERT_EQ(rows, 5);
}
}

TEST_P(TestCsvFileFormat, ScanRecordBatchReaderWithVirtualColumn) {
auto source = GetFileSource(R"(f64
1.0
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/dataset/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ Status ScannerBuilder::Pool(MemoryPool* pool) {
return Status::OK();
}

Status ScannerBuilder::FragmentScanOptions(
std::shared_ptr<dataset::FragmentScanOptions> fragment_scan_options) {
scan_options_->fragment_scan_options = std::move(fragment_scan_options);
return Status::OK();
}

Result<std::shared_ptr<Scanner>> ScannerBuilder::Finish() {
if (!scan_options_->projection.IsBound()) {
RETURN_NOT_OK(Project(scan_options_->dataset_schema->field_names()));
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/dataset/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ class ARROW_DS_EXPORT ScannerBuilder {
/// \brief Set the pool from which materialized and scanned arrays will be allocated.
Status Pool(MemoryPool* pool);

/// \brief Set fragment-specific scan options.
Status FragmentScanOptions(std::shared_ptr<FragmentScanOptions> fragment_scan_options);
Copy link
Member

@jorisvandenbossche jorisvandenbossche Mar 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just ScanOptions instead of FragmentScanOptions might be more descriptive? (I find the "fragment" in it a bit confusing) Because it's not that this can be set for each fragment. It's the same for all fragments for one scan.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but we already have a ScanOptions of course ;)
Then maybe FormatScanOptions? Since it are format-specific options?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair, but that does collide with ScanOptions itself, unless you mean just the naming of the builder method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see…hmm, but if we had a hypothetical FlightFragment, we'd still want to have scan options specific to that fragment, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be "FlightScanOptios" then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think the slight niggle I have there is that there wouldn't be a corresponding 'Flight(File)Format'. Maybe 'PerScanOptions'? But it's not a big deal and FormatScanOptions is OK with me too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK, I see now that the "Fragment" in the name was meant to mean the general "fragment type", while I interpreted it as the "single fragment".
Anyway, this is more a nitpicky remark, not too important ;)


/// \brief Return the constructed now-immutable Scanner object
Result<std::shared_ptr<Scanner>> Finish();

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/dataset/type_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class Fragment;
using FragmentIterator = Iterator<std::shared_ptr<Fragment>>;
using FragmentVector = std::vector<std::shared_ptr<Fragment>>;

class FragmentScanOptions;

class FileSource;
class FileFormat;
class FileFragment;
Expand All @@ -58,6 +60,7 @@ struct FileSystemDatasetWriteOptions;
class InMemoryDataset;

class CsvFileFormat;
struct CsvFragmentScanOptions;

class IpcFileFormat;
class IpcFileWriter;
Expand Down
17 changes: 17 additions & 0 deletions python/pyarrow/_csv.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,26 @@ from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport _Weakrefable


cdef class ConvertOptions(_Weakrefable):
cdef:
CCSVConvertOptions options

@staticmethod
cdef ConvertOptions wrap(CCSVConvertOptions options)


cdef class ParseOptions(_Weakrefable):
cdef:
CCSVParseOptions options

@staticmethod
cdef ParseOptions wrap(CCSVParseOptions options)


cdef class ReadOptions(_Weakrefable):
cdef:
CCSVReadOptions options
public object encoding

@staticmethod
cdef ReadOptions wrap(CCSVReadOptions options)
88 changes: 82 additions & 6 deletions python/pyarrow/_csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ cdef class ReadOptions(_Weakrefable):
The character encoding of the CSV data. Columns that cannot
decode using this encoding can still be read as Binary.
"""
cdef:
CCSVReadOptions options
public object encoding

# Avoid mistakingly creating attributes
__slots__ = ()
Expand Down Expand Up @@ -161,6 +158,40 @@ cdef class ReadOptions(_Weakrefable):
def autogenerate_column_names(self, value):
self.options.autogenerate_column_names = value

def equals(self, ReadOptions other):
return (
self.use_threads == other.use_threads and
self.block_size == other.block_size and
self.skip_rows == other.skip_rows and
self.column_names == other.column_names and
self.autogenerate_column_names ==
other.autogenerate_column_names and
self.encoding == other.encoding
)

@staticmethod
cdef ReadOptions wrap(CCSVReadOptions options):
out = ReadOptions()
out.options = options
out.encoding = 'utf8' # No way to know this
return out

def __getstate__(self):
return (self.use_threads, self.block_size, self.skip_rows,
self.column_names, self.autogenerate_column_names,
self.encoding)

def __setstate__(self, state):
(self.use_threads, self.block_size, self.skip_rows,
self.column_names, self.autogenerate_column_names,
self.encoding) = state

def __eq__(self, other):
try:
return self.equals(other)
except TypeError:
return False


cdef class ParseOptions(_Weakrefable):
"""
Expand Down Expand Up @@ -320,6 +351,12 @@ cdef class ParseOptions(_Weakrefable):
self.escape_char, self.newlines_in_values,
self.ignore_empty_lines) = state

def __eq__(self, other):
try:
return self.equals(other)
except TypeError:
return False


cdef class _ISO8601(_Weakrefable):
"""
Expand Down Expand Up @@ -391,9 +428,6 @@ cdef class ConvertOptions(_Weakrefable):
`column_types`, or null by default).
This option is ignored if `include_columns` is empty.
"""
cdef:
CCSVConvertOptions options

# Avoid mistakingly creating attributes
__slots__ = ()

Expand Down Expand Up @@ -603,6 +637,48 @@ cdef class ConvertOptions(_Weakrefable):

self.options.timestamp_parsers = move(c_parsers)

@staticmethod
cdef ConvertOptions wrap(CCSVConvertOptions options):
out = ConvertOptions()
out.options = options
return out

def equals(self, ConvertOptions other):
return (
self.check_utf8 == other.check_utf8 and
self.column_types == other.column_types and
self.null_values == other.null_values and
self.true_values == other.true_values and
self.false_values == other.false_values and
self.timestamp_parsers == other.timestamp_parsers and
self.strings_can_be_null == other.strings_can_be_null and
self.auto_dict_encode == other.auto_dict_encode and
self.auto_dict_max_cardinality ==
other.auto_dict_max_cardinality and
self.include_columns == other.include_columns and
self.include_missing_columns == other.include_missing_columns
)

def __getstate__(self):
return (self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.timestamp_parsers,
self.strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns)

def __setstate__(self, state):
(self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.timestamp_parsers,
self.strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns) = state

def __eq__(self, other):
try:
return self.equals(other)
except TypeError:
return False


cdef _get_reader(input_file, ReadOptions read_options,
shared_ptr[CInputStream]* out):
Expand Down
Loading