Skip to content

Commit d82cd7c

Browse files
authored
Add bindings for RowImpl in time series SequentialTransformerBase (#3875)
* Add bindings for rowimp in time series. * Add bindings for rowimp in time series. * Remove assert.
1 parent f0f34ac commit d82cd7c

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ private protected virtual void CloneCore(TState state)
273273
internal TState StateRef { get; set; }
274274

275275
public int StateRefCount;
276-
277276
/// <summary>
278277
/// The main constructor for the sequential transform
279278
/// </summary>
@@ -492,7 +491,7 @@ DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema
492491
var active = RowCursorUtils.FromColumnsToPredicate(activeColumns, OutputSchema);
493492
var getters = _mapper.CreateGetters(input, active, out Action disposer);
494493
var pingers = _mapper.CreatePinger(input, active, out Action pingerDisposer);
495-
return new RowImpl(_bindings.Schema, input, getters, pingers, disposer + pingerDisposer);
494+
return new RowImpl(_bindings, input, getters, pingers, disposer + pingerDisposer);
496495
}
497496
}
498497

@@ -504,23 +503,24 @@ private sealed class RowImpl : StatefulRow
504503
private readonly Action<long> _pinger;
505504
private readonly Action _disposer;
506505
private bool _disposed;
506+
private readonly ColumnBindings _bindings;
507507

508508
public override DataViewSchema Schema => _schema;
509509

510510
public override long Position => _input.Position;
511511

512512
public override long Batch => _input.Batch;
513513

514-
public RowImpl(DataViewSchema schema, DataViewRow input, Delegate[] getters, Action<long> pinger, Action disposer)
514+
public RowImpl(ColumnBindings bindings, DataViewRow input, Delegate[] getters, Action<long> pinger, Action disposer)
515515
{
516-
Contracts.CheckValue(schema, nameof(schema));
516+
Contracts.CheckValue(bindings, nameof(bindings));
517517
Contracts.CheckValue(input, nameof(input));
518-
Contracts.Check(Utils.Size(getters) == schema.Count);
519-
_schema = schema;
518+
_schema = bindings.Schema;
520519
_input = input;
521520
_getters = getters ?? new Delegate[0];
522521
_pinger = pinger;
523522
_disposer = disposer;
523+
_bindings = bindings;
524524
}
525525

526526
protected override void Dispose(bool disposing)
@@ -538,9 +538,13 @@ public override ValueGetter<DataViewRowId> GetIdGetter()
538538

539539
public override ValueGetter<T> GetGetter<T>(DataViewSchema.Column column)
540540
{
541-
Contracts.CheckParam(column.Index < _getters.Length, nameof(column), "Invalid col value in GetGetter");
541+
bool isSrc;
542+
int index = _bindings.MapColumnIndex(out isSrc, column.Index);
543+
if (isSrc)
544+
return _input.GetGetter<T>(_input.Schema[index]);
545+
Contracts.CheckParam(index < _getters.Length, nameof(column), "Invalid col value in GetGetter");
542546
Contracts.Check(IsColumnActive(column));
543-
var fn = _getters[column.Index] as ValueGetter<T>;
547+
var fn = _getters[index] as ValueGetter<T>;
544548
if (fn == null)
545549
throw Contracts.Except("Unexpected TValue in GetGetter");
546550
return fn;
@@ -554,8 +558,9 @@ public override Action<long> GetPinger() =>
554558
/// </summary>
555559
public override bool IsColumnActive(DataViewSchema.Column column)
556560
{
557-
Contracts.Check(column.Index < _getters.Length);
558-
return _getters[column.Index] != null;
561+
int index = _bindings.MapColumnIndex(out bool isSrc, column.Index);
562+
Contracts.Check(index < _getters.Length);
563+
return _getters[index] != null;
559564
}
560565
}
561566

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine()
249249

250250
// Pipeline.
251251
var pipeline = ml.Transforms.Text.FeaturizeText("Text_Featurized", "Text")
252+
.Append(ml.Transforms.Conversion.ConvertType("Value", "Value", DataKind.Single))
252253
.Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Options()
253254
{
254255
Confidence = 95,

0 commit comments

Comments
 (0)