diff --git a/src/Microsoft.Data.Analysis/DataFrame.Join.cs b/src/Microsoft.Data.Analysis/DataFrame.Join.cs index d5a1278371..381268dee2 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.Join.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.Join.cs @@ -252,9 +252,9 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ // Hash the column with the smaller RowCount long leftRowCount = Rows.Count; long rightRowCount = other.Rows.Count; - DataFrame longerDataFrame = leftRowCount <= rightRowCount ? other : this; - DataFrame shorterDataFrame = ReferenceEquals(longerDataFrame, this) ? other : this; - DataFrameColumn hashColumn = (leftRowCount <= rightRowCount) ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn]; + + var leftColumnIsSmaller = (leftRowCount <= rightRowCount); + DataFrameColumn hashColumn = leftColumnIsSmaller ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn]; DataFrameColumn otherColumn = ReferenceEquals(hashColumn, Columns[leftJoinColumn]) ? other.Columns[rightJoinColumn] : Columns[leftJoinColumn]; Dictionary> multimap = hashColumn.GroupColumnValues(); @@ -270,23 +270,21 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ { if (hashColumn[row] == null) { - leftRowIndices.Append(row); - rightRowIndices.Append(i); + leftRowIndices.Append(leftColumnIsSmaller ? row : i); + rightRowIndices.Append(leftColumnIsSmaller ? i : row); } } else { if (hashColumn[row] != null) { - leftRowIndices.Append(row); - rightRowIndices.Append(i); + leftRowIndices.Append(leftColumnIsSmaller ? row : i); + rightRowIndices.Append(leftColumnIsSmaller ? i : row); } } } } } - leftDataFrame = shorterDataFrame; - rightDataFrame = longerDataFrame; } else if (joinAlgorithm == JoinAlgorithm.FullOuter) { @@ -366,4 +364,5 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ } } + } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 300babbffb..72072fd533 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1579,6 +1579,25 @@ public void TestSample() Assert.Throws(()=> df.Sample(13)); } + [Theory] + [InlineData(1, 2)] + [InlineData(2, 1)] + public void TestDataCorrectnessForInnerMerge(int leftCount, int rightCount) + { + DataFrame left = MakeDataFrameWithNumericColumns(leftCount, false); + DataFrameColumn leftStringColumn = new StringDataFrameColumn("String", Enumerable.Range(0, leftCount).Select(x => "Left")); + left.Columns.Insert(left.Columns.Count, leftStringColumn); + + DataFrame right = MakeDataFrameWithNumericColumns(rightCount, false); + DataFrameColumn rightStringColumn = new StringDataFrameColumn("String", Enumerable.Range(0, rightCount).Select(x => "Right")); + right.Columns.Insert(right.Columns.Count, rightStringColumn); + + DataFrame merge = left.Merge(right, "Int", "Int", joinAlgorithm: JoinAlgorithm.Inner); + + Assert.Equal("Left", (string)merge.Columns["String_left"][0]); + Assert.Equal("Right", (string)merge.Columns["String_right"][0]); + } + [Fact] public void TestMerge() {