Skip to content

Commit

Permalink
Add fromColumn, rowsWhere, rowsAt & Column class
Browse files Browse the repository at this point in the history
  • Loading branch information
w2sv committed Sep 13, 2022
1 parent 9ca80f6 commit dd357ee
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 54 deletions.
1 change: 1 addition & 0 deletions lib/koala.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
library koala;

export 'src/dataframe.dart';
export 'src/column.dart';
100 changes: 100 additions & 0 deletions lib/src/column.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import 'package:collection/collection.dart';

import 'list_extensions/extended_list_base.dart';

typedef Mask = List<bool>;

extension MaskExtensions on Mask{
Mask operator&(Mask other) =>
IterableZip([this, other]).map((e) => e.first && e.last).toList();

Mask operator|(Mask other) =>
IterableZip([this, other]).map((e) => e.first | e.last).toList();

Mask operator^(Mask other) =>
IterableZip([this, other]).map((e) => e.first ^ e.last).toList();
}

class Column<E> extends ExtendedListBase<E>{
Column(List<E> records): super(records);

Column<T> cast<T>() =>
Column(super.cast<T>());

// ************* count *****************

/// Count number of occurrences of [element] of the column [colName].
int count(E object) =>
where((element) => element == object).length;

/// Count number of occurrences of values, corresponding to the column [colName],
/// equaling any element contained by [pool].
int countElementOccurrencesOf(Set<E?> pool) =>
where((element) => pool.contains(element)).length;

// ************* null freeing *************

Column<E> nullFreed({E? replaceWith = null}) =>
Column<E>(nullFreedIterable(replaceWith: replaceWith).toList());

Iterable<E> nullFreedIterable({E? replaceWith = null}) =>
replaceWith == null
? where((element) => element != null)
: map((e) => e ?? replaceWith);

// ****************** transformation ******************

List<num> cumSum() => _nullFreedNums().fold(
[],
(sums, element) =>
sums..add(sums.isEmpty ? element : sums.last + element));

// **************** accumulation ****************

double mean({bool treatNullsAsZeros = true}) =>
_nullFreedNums(treatNullsAsZeros: treatNullsAsZeros).average;

num max() =>
_nullFreedNums().max;

num min() =>
_nullFreedNums().min;

num sum() =>
_nullFreedNums().sum;

Iterable<num> _nullFreedNums({bool treatNullsAsZeros = false}) =>
cast<num?>().nullFreedIterable(replaceWith: treatNullsAsZeros ? 0.0 : null)
.cast<num>();

// ***************** masks *******************

Mask equals(E reference) =>
map((element) => element == reference).toList().cast<bool>();

Mask unequals(E reference) =>
map((element) => element != reference).toList().cast<bool>();

Mask isIn(Set<E> pool) =>
map((element) => pool.contains(element)).toList().cast<bool>();

Mask isNotIn(Set<E> pool) =>
map((element) => !pool.contains(element)).toList().cast<bool>();

Mask toMask(bool Function(E) test) =>
map(test).toList().cast<bool>();

// ****************** numerical column masks *********************

Mask operator<(num reference) =>
cast<num>().map((element) => element < reference).toList().cast<bool>();

Mask operator>(num reference) =>
cast<num>().map((element) => element > reference).toList().cast<bool>();

Mask operator<=(num reference) =>
cast<num>().map((element) => element <= reference).toList().cast<bool>();

Mask operator>=(num reference) =>
cast<num>().map((element) => element >= reference).toList().cast<bool>();
}
81 changes: 30 additions & 51 deletions lib/src/dataframe.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import 'package:collection/collection.dart';
import 'package:csv/csv.dart';
import 'package:jiffy/jiffy.dart';

import 'column.dart';
import 'list_extensions/extended_list_base.dart';
import 'list_extensions/position_tracking_list.dart';
import 'utils/list.dart';
import 'utils/iterable.dart';
import 'utils/list.dart';

typedef Record = Object?;
typedef RecordRowMap = Map<String, Record>;
Expand Down Expand Up @@ -44,14 +45,14 @@ class DataFrame extends ExtendedListBase<RecordRow> {
}
}

/// Build a dataframe from a list of [rowMaps], e.g.
/// Builds a dataframe from a list of [rowMaps], e.g.
/// [{'col1': 420, 'col2': 69},
/// {'col1': 666, 'col2': 1470}]
DataFrame.fromRowMaps(List<RecordRowMap> rowMaps)
: this._columnNames = PositionTrackingList(rowMaps.first.keys.toList()),
super(rowMaps.map((e) => e.values.toList()).toList());

/// Returns an empty dataframe
/// Returns an empty dataframe.
DataFrame.empty()
: this._columnNames = PositionTrackingList([]),
super([]);
Expand Down Expand Up @@ -226,29 +227,44 @@ class DataFrame extends ExtendedListBase<RecordRow> {
}
}

List<int> get shape => [length, nColumns];

// ************* data access *****************

/// Enables (typed) column access.
///
/// If [start] and/or [end] are specified the column will be sliced respectively,
/// after which [includeRecord], if specified, may determine which elements are to be included.
List<T> call<T>(String colName,
Column<T> call<T>(String colName,
{int start = 0, int? end, bool Function(T)? includeRecord}) {
Iterable<T> column =
sublist(start, end).map((row) => row[columnIndex(colName)]).cast<T>();
if (includeRecord != null) {
column = column.where(includeRecord);
}
return column.toList();
return Column(column.toList());
}

DataFrame fromColumns(List<String> columnNames) =>
DataFrame._copied(
PositionTrackingList(columnNames),
columnNames.map((e) => this(e)).transposed()
);

/// Returns an iterable over the column data.
Iterable<RecordCol> columnIterable() => _columnNames.map((e) => this(e));
Iterable<Column> columns() =>
_columnNames.map((e) => this(e));

/// Grab a (typed) record sitting at dataframe[rowIndex][colName].
T record<T>(int rowIndex, String colName) =>
this[rowIndex][columnIndex(colName)] as T;

DataFrame rowsAt(Iterable<int> indices) =>
DataFrame._copied(_columnNames, indices.map((e) => this[e]).toList());

DataFrame rowsWhere(List<bool> mask) =>
DataFrame._copied(_columnNames, applyMask(mask).toList());

// **************** manipulation ******************

/// Add a new column to the end of the dataframe. The [records] have to be of the same length
Expand Down Expand Up @@ -282,7 +298,7 @@ class DataFrame extends ExtendedListBase<RecordRow> {
/// Transform the values corresponding to [name] as per [transformElement] in-place.
void transformColumn(
String name, dynamic Function(dynamic element) transformElement) {
this(name).asMap().forEach((i, element) {
this(name).forEachIndexed((i, element) {
this[i][columnIndex(name)] = transformElement(element);
});
}
Expand All @@ -292,7 +308,7 @@ class DataFrame extends ExtendedListBase<RecordRow> {
void addRowFromMap(RecordRowMap rowMap) =>
add([for (final name in _columnNames) rowMap[name]]);

// ************ alternate representations *************
// ************ map representations *************

/// Returns a list of {columnName: value} Map-representations for each row.
List<RecordRowMap> rowMaps() =>
Expand Down Expand Up @@ -339,7 +355,7 @@ class DataFrame extends ExtendedListBase<RecordRow> {
DataFrame sortedBy(String colName,
{bool ascending = true,
bool nullsFirst = true,
CompareRecords? compareRecords}) =>
Comparator<Record>? compareRecords}) =>
DataFrame._copied(
_columnNames,
_sort(colName,
Expand All @@ -355,7 +371,7 @@ class DataFrame extends ExtendedListBase<RecordRow> {
void sortBy(String colName,
{bool ascending = true,
bool nullsFirst = true,
CompareRecords? compareRecords}) =>
Comparator<Record>? compareRecords}) =>
_sort(colName,
inPlace: true,
ascending: ascending,
Expand All @@ -366,15 +382,15 @@ class DataFrame extends ExtendedListBase<RecordRow> {
{required bool inPlace,
required bool ascending,
required bool nullsFirst,
required CompareRecords? compareRecords}) {
required Comparator<Record>? compareRecords}) {
final index = columnIndex(colName);
return (inPlace ? this : copy2D(this))
..sort((a, b) => _compareRecords(
a[index], b[index], ascending, nullsFirst, compareRecords));
}

static int _compareRecords(Record a, Record b, bool ascending,
bool nullsFirst, CompareRecords? compare) {
bool nullsFirst, Comparator<Record>? compare) {
// return compare result if function given
if (compare != null) return compare(a, b);

Expand Down Expand Up @@ -413,7 +429,7 @@ class DataFrame extends ExtendedListBase<RecordRow> {
final indexColumnDelimiter = ' | ';
final consecutiveElementDelimiter = ' ';

final List<int> columnWidths = columnIterable()
final List<int> columnWidths = columns()
.mapIndexed((index, col) =>
(col + [columnNames[index]]).map((el) => el.toString().length).max)
.toList();
Expand All @@ -432,41 +448,4 @@ class DataFrame extends ExtendedListBase<RecordRow> {
.join(consecutiveElementDelimiter))
]).map((e) => e.join(indexColumnDelimiter)).join('\n');
}
}

extension RecordColumnExtensions<T> on List<T> {
/// Count number of occurrences of [element] of the column [colName].
int count(T object) => where((element) => element == object).length;

/// Count number of occurrences of values, corresponding to the column [colName],
/// equaling any element contained by [pool].
int countElementOccurrencesOf(Set<T> pool) =>
where((element) => pool.contains(element)).length;

Iterable<T> withoutNulls({T? nullReplacement = null}) =>
nullReplacement == null
? where((element) => element != null)
: map((e) => e ?? nullReplacement);
}

extension NumericalRecordColumnExtensions on List<num?> {
List<double> cumSum() => _nullPurgedDoubles().fold(
[],
(sums, element) =>
sums..add(sums.isEmpty ? element : sums.last + element));

double mean({bool treatNullsAsZeros = true}) {
final nullPurged = _nullPurgedDoubles(treatNullsAsZeros: treatNullsAsZeros);
return nullPurged.sum / nullPurged.length;
}

List<double> _nullPurgedDoubles({bool treatNullsAsZeros = false}) =>
(withoutNulls(nullReplacement: treatNullsAsZeros ? 0.0 : null))
.map((e) => e!.toDouble())
.toList();
}

/// A function that compares two objects for sorting. It will return -1 if a
/// should be ordered before b, 0 if a and b are equal wrt to ordering, and 1
/// if a should be ordered after b.
typedef CompareRecords = int Function(Record a, Record b);
}
10 changes: 10 additions & 0 deletions lib/src/utils/iterable.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
import 'package:collection/collection.dart';

extension IterableExtensions<E> on Iterable<E> {
List<E> toFixedLengthList() => toList(growable: false);

Iterable<E> applyMask(List<bool> mask) =>
whereIndexed((index, _) => mask[index]);
}

extension IterableIterableExtensions<E> on Iterable<Iterable<E>>{
List<List<E>> transposed() =>
IterableZip(this).map((e) => e.toList()).toList();
}
8 changes: 5 additions & 3 deletions test/df_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ void main() {
expect(df<int?>('col1').runtimeType.toString(), 'List<int?>');

// columnIterable
expect(df.columnIterable().toList(), [[1, 1, null], [2, 1, 8]]);
expect(df.columns().toList(), [[1, 1, null], [2, 1, 8]]);

df.rowsWhere((df('col1') > 6) & (df('col2') <= 5));

// record
expect(df.record(2, 'col1'), null);
Expand Down Expand Up @@ -220,8 +222,8 @@ void main() {
expect(df('col1').count(null), 1);
expect(df('col2').countElementOccurrencesOf({1, 2}), 2);

expect(df('col1').withoutNulls(), [1, 1]);
expect(df('col1').withoutNulls(nullReplacement: 69), [1, 1, 69]);
expect(df('col1').nullFreed(), [1, 1]);
expect(df('col1').nullFreed(replaceWith: 69), [1, 1, 69]);
});

test('sorting', () async {
Expand Down
9 changes: 9 additions & 0 deletions test/iterable_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import 'package:koala/src/utils/iterable.dart';
import 'package:test/expect.dart';
import 'package:test/scaffolding.dart';

void main(){
test('transposed', (){
expect([[1, 2, 3], [2, 3, 4]].transposed(), [[1, 2], [2, 3], [3, 4]]);
});
}

0 comments on commit dd357ee

Please sign in to comment.