Skip to content

8278114: New addnode ideal optimization: converting "x + x" into "x << 1" #6675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 1997, 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1997, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -789,16 +789,24 @@ Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
return NULL;
}

// Left input is an add of a constant?
// Left input is an add?
Node *add1 = in(1);
int add1_op = add1->Opcode();
if( add1_op == Op_AddI ) { // Left input is an add?
assert( add1 != add1->in(1), "dead loop in LShiftINode::Ideal" );
const TypeInt *t12 = phase->type(add1->in(2))->isa_int();
if( t12 && t12->is_con() ){ // Left input is an add of a con?
// Transform is legal, but check for profit. Avoid breaking 'i2s'
// and 'i2b' patterns which typically fold into 'StoreC/StoreB'.
if( con < 16 ) {

// Transform is legal, but check for profit. Avoid breaking 'i2s'
// and 'i2b' patterns which typically fold into 'StoreC/StoreB'.
if( con < 16 ) {
// Left input is an add of the same number?
if (add1->in(1) == add1->in(2)) {
// Convert "(x + x) << c0" into "x << (c0 + 1)"
return new LShiftINode(add1->in(1), phase->intcon(con + 1));
}

// Left input is an add of a constant?
const TypeInt *t12 = phase->type(add1->in(2))->isa_int();
if( t12 && t12->is_con() ){ // Left input is an add of a con?
// Compute X << con0
Node *lsh = phase->transform( new LShiftINode( add1->in(1), in(2) ) );
// Compute X<<con0 + (con1<<con0)
Expand Down Expand Up @@ -902,12 +910,20 @@ Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
return NULL;
}

// Left input is an add of a constant?
// Left input is an add?
Node *add1 = in(1);
int add1_op = add1->Opcode();
if( add1_op == Op_AddL ) { // Left input is an add?
// Avoid dead data cycles from dead loops
assert( add1 != add1->in(1), "dead loop in LShiftLNode::Ideal" );

// Left input is an add of the same number?
if (add1->in(1) == add1->in(2)) {
// Convert "(x + x) << c0" into "x << (c0 + 1)"
return new LShiftLNode(add1->in(1), phase->intcon(con + 1));
}

// Left input is an add of a constant?
const TypeLong *t12 = phase->type(add1->in(2))->isa_long();
if( t12 && t12->is_con() ){ // Left input is an add of a con?
// Compute X << con0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

package compiler.c2.irTests;

import jdk.test.lib.Asserts;
import compiler.lib.ir_framework.*;

/*
* @test
* @bug 8278114
* @summary Test that transformation from (x + x) << c to x << (c + 1) works as intended.
* @library /test/lib /
* @run driver compiler.c2.irTests.TestIRLShiftIdeal_XPlusX_LShiftC
*/
public class TestIRLShiftIdeal_XPlusX_LShiftC {

    // Arguments passed to testInt0 and testInt1.
    private static final int[] INT_IN = {
        -10, -2, -1, 0, 1, 2, 10,
        0x8000_0000, 0x7FFF_FFFF, 0x5678_1234,
    };

    // Expected results: INT_OUT[k][i] is what testInt<k> must return for INT_IN[i].
    private static final int[][] INT_OUT = {
        // testInt0: (x + x) << 3
        {
            -160, -32, -16, 0, 16, 32, 160,
            0x0000_0000, 0xFFFF_FFF0, 0x6781_2340,
        },

        // testInt1: (x + x) << 19
        {
            -10485760, -2097152, -1048576, 0, 1048576, 2097152, 10485760,
            0x0000_0000, 0xFFF0_0000, 0x2340_0000,
        },
    };

    // Arguments passed to testLong0 and testLong1.
    private static final long[] LONG_IN = {
        -10L, -2L, -1L, 0L, 1L, 2L, 10L,
        0x8000_0000_0000_0000L, 0x7FFF_FFFF_FFFF_FFFFL, 0x5678_1234_4321_8765L,
    };

    // Expected results: LONG_OUT[k][i] is what testLong<k> must return for LONG_IN[i].
    private static final long[][] LONG_OUT = {
        // testLong0: (x + x) << 3
        {
            -160L, -32L, -16L, 0L, 16L, 32L, 160L,
            0x0000_0000_0000_0000L, 0xFFFF_FFFF_FFFF_FFF0L, 0x6781_2344_3218_7650L,
        },

        // testLong1: (x + x) << 35
        {
            -687194767360L, -137438953472L, -68719476736L, 0L, 68719476736L, 137438953472L, 687194767360L,
            0x0000_0000_0000_0000L, 0xFFFF_FFF0_0000_0000L, 0x3218_7650_0000_0000L,
        },
    };

    /** Entry point: hands control to the IR test framework. */
    public static void main(String[] args) {
        TestFramework.run();
    }

    // Int case with a small shift amount: the add is expected to be folded
    // into the shift, i.e. (x + x) << 3 becomes x << 4.
    @Test
    @IR(failOn = {IRNode.ADD_I, IRNode.MUL_I})
    @IR(counts = {IRNode.LSHIFT_I, "1"})
    public int testInt0(int x) {
        return (x + x) << 3;
    }

    /** Runs testInt0 over every INT_IN value and compares against INT_OUT[0]. */
    @Run(test = "testInt0")
    public void checkTestInt0(RunInfo info) {
        assertC2Compiled(info);
        final int[] expected = INT_OUT[0];
        for (int idx = 0; idx < INT_IN.length; idx++) {
            final int actual = testInt0(INT_IN[idx]);
            Asserts.assertEquals(expected[idx], actual);
        }
    }

    // Int case with a large shift amount: the add must survive, because
    // LShiftINode::Ideal only applies the transformation when the shift
    // amount is less than 16, and 19 is not.
    @Test
    @IR(failOn = {IRNode.MUL_I})
    @IR(counts = {IRNode.LSHIFT_I, "1",
                  IRNode.ADD_I, "1"})
    public int testInt1(int x) {
        return (x + x) << 19;
    }

    /** Runs testInt1 over every INT_IN value and compares against INT_OUT[1]. */
    @Run(test = "testInt1")
    public void checkTestInt1(RunInfo info) {
        assertC2Compiled(info);
        final int[] expected = INT_OUT[1];
        for (int idx = 0; idx < INT_IN.length; idx++) {
            final int actual = testInt1(INT_IN[idx]);
            Asserts.assertEquals(expected[idx], actual);
        }
    }

    // Long case with a small shift amount: (x + x) << 3 becomes x << 4.
    @Test
    @IR(failOn = {IRNode.ADD_L, IRNode.MUL_L})
    @IR(counts = {IRNode.LSHIFT_L, "1"})
    public long testLong0(long x) {
        return (x + x) << 3;
    }

    /** Runs testLong0 over every LONG_IN value and compares against LONG_OUT[0]. */
    @Run(test = "testLong0")
    public void checkTestLong0(RunInfo info) {
        assertC2Compiled(info);
        final long[] expected = LONG_OUT[0];
        for (int idx = 0; idx < LONG_IN.length; idx++) {
            final long actual = testLong0(LONG_IN[idx]);
            Asserts.assertEquals(expected[idx], actual);
        }
    }

    // Long case with a large shift amount: (x + x) << 35 becomes x << 36.
    @Test
    @IR(failOn = {IRNode.ADD_L, IRNode.MUL_L})
    @IR(counts = {IRNode.LSHIFT_L, "1"})
    public long testLong1(long x) {
        return (x + x) << 35;
    }

    /** Runs testLong1 over every LONG_IN value and compares against LONG_OUT[1]. */
    @Run(test = "testLong1")
    public void checkTestLong1(RunInfo info) {
        assertC2Compiled(info);
        final long[] expected = LONG_OUT[1];
        for (int idx = 0; idx < LONG_IN.length; idx++) {
            final long actual = testLong1(LONG_IN[idx]);
            Asserts.assertEquals(expected[idx], actual);
        }
    }

    /**
     * Asserts that C2 compilation is enabled for this run and, once the
     * warm-up phase is over, that the test method has been compiled by C2.
     */
    private void assertC2Compiled(RunInfo info) {
        Asserts.assertTrue(info.isC2CompilationEnabled());
        if (info.isWarmUp()) {
            // Nothing more to verify while still warming up.
            return;
        }
        Asserts.assertTrue(info.isTestC2Compiled());
    }
}
4 changes: 3 additions & 1 deletion test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -141,6 +141,8 @@ public class IRNode {
public static final String ADD_L = START + "AddL" + MID + END;
public static final String SUB_I = START + "SubI" + MID + END;
public static final String SUB_L = START + "SubL" + MID + END;
public static final String MUL_I = START + "MulI" + MID + END;
public static final String MUL_L = START + "MulL" + MID + END;
public static final String CONV_I2L = START + "ConvI2L" + MID + END;

public static final String VECTOR_CAST_B2X = START + "VectorCastB2X" + MID + END;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

package org.openjdk.bench.vm.compiler;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.util.concurrent.TimeUnit;

/**
* Tests transformation from "(x + x) << c" to "x << (c + 1)".
* <p>
* With JMH versions older than 1.34, run this benchmark with {@code
* JAVA_OPTIONS=-Djmh.blackhole.mode=COMPILER} to force the use of
* compiler-mode blackholes. Since JMH 1.34 that mode is enabled by
* default, so the option is no longer necessary.
*/
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 20, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 20, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 3)
public class LShiftIdeal_XPlusX_LShiftC {

    // Operand for the int benchmarks.
    private int iFld = 4711;

    // Operand for the long benchmarks. The factors are long literals so the
    // product is computed in 64-bit arithmetic; with int literals the
    // multiplication would overflow int before being widened to long.
    private long lFld = 4711L * 4711L * 4711L;

    // Number of values consumed per benchmark invocation.
    private final int SIZE = 10;

    // Baseline: consumes the int field without any arithmetic.
    @Benchmark
    public void baselineInt(Blackhole bh) {
        for (int i = 0; i < SIZE; i++) {
            bh.consume(iFld);
        }
    }

    // Baseline: consumes the long field without any arithmetic.
    @Benchmark
    public void baselineLong(Blackhole bh) {
        for (int i = 0; i < SIZE; i++) {
            bh.consume(lFld);
        }
    }

    // Convert "(x + x) << 10" into "x << 11" for int.
    // (x << 11) >>> 11 is then further converted into zero-extends.
    @Benchmark
    public void testInt(Blackhole bh) {
        for (int i = 0; i < SIZE; i++) {
            bh.consume(((iFld + iFld) << 10) >>> 11);
        }
    }

    // Convert "(x + x) << 40" into "x << 41" for long.
    // (x << 41) >>> 41 is then further converted into zero-extends.
    @Benchmark
    public void testLong(Blackhole bh) {
        for (int i = 0; i < SIZE; i++) {
            bh.consume(((lFld + lFld) << 40) >>> 41);
        }
    }
}