Skip to content

Commit 03f11ab

Browse files
committed
Bump xla and fix ci
1 parent 476c60d commit 03f11ab

File tree

7 files changed

+22
-18
lines changed

7 files changed

+22
-18
lines changed

.buildkite/pipeline.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ steps:
77
- x86_64
88
os:
99
- macos
10+
python:
11+
- "3.10"
1012
agents:
1113
queue: "juliaecosystem"
1214
os: "{{matrix.os}}"
@@ -71,7 +73,7 @@ steps:
7173
python -m pip install --user numpy wheel
7274
mkdir -p .baztmp
7375
rm -f bazel-bin/*.whl
74-
bazel --output_user_root=`pwd`/.baztmp build :enzyme_ad
76+
HERMETIC_PYTHON_VERSION=${{matrix.python}} bazel --output_user_root=`pwd`/.baztmp build :enzyme_ad
7577
cp bazel-bin/*.whl .
7678
python -m pip install *.whl
7779
python -m pip install "jax[cpu]"

.buildkite/secure_pipeline.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ steps:
77
- x86_64
88
os:
99
- macos
10+
python:
11+
- "3.10"
12+
- "3.11"
13+
- "3.12"
14+
- "3.13"
1015
agents:
1116
queue: "juliaecosystem"
1217
os: "{{matrix.os}}"
@@ -59,18 +64,15 @@ steps:
5964
wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-windows-x86_64.exe
6065
fi
6166
mv bazel* .local/bin/bazel.exe
62-
start /wait "" Miniconda3*.exe /InstallationType=JustMe /RegisterPython=0 /S /D=`pwd`/conda
63-
rm Miniconda*.exe
6467
fi
6568
python -m ensurepip --upgrade
6669
python -m pip install --user numpy wheel
6770
mkdir baztmp
6871
export TAG=`echo $BUILDKITE_TAG | cut -c2-`
6972
sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD
70-
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
73+
HERMETIC_PYTHON_VERSION=${{matrix.python}} bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
7174
cp bazel-bin/*.whl .
7275
python -m pip install *.whl
73-
cd test && python -m pip install "jax[cpu]" && python test.py && cd ..
7476
python -m pip install --user twine
7577
python -m twine upload *.whl
7678
artifact_paths:

src/enzyme_ad/jax/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ cc_library(
258258
"@llvm-project//mlir:Pass",
259259
"@llvm-project//mlir:Support",
260260
"@llvm-project//mlir:TensorDialect",
261-
"@llvm-project//mlir:Transforms",
262261
"@llvm-project//mlir:TransformUtils",
263262
"@stablehlo//:reference_ops",
264263
"@stablehlo//:stablehlo_ops",

src/enzyme_ad/jax/Implementations/HLODerivatives.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def PadToSliceStride : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
143143

144144
def ResultDotDim : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op.getDotDimensionNumbersAttr()">;
145145
def ResultDotPrec : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op.getPrecisionConfigAttr()">;
146-
146+
def ResultDotAlg : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op.getAlgorithmAttr()">;
147147

148148
def ShadowLHSDotDim : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
149149
auto existingattr = op.getDotDimensionNumbersAttr();
@@ -572,10 +572,10 @@ def : HLODerivative<"DivOp", (Op $x, $y),
572572

573573
def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs),
574574
[
575-
(Transpose (TypeOf $lhs), (DotGeneral (ShadowLHSDotRes), (DiffeRet), $rhs, (ShadowLHSDotDim), (ResultDotPrec)), (ShadowLHSTranspose)),
576-
(Transpose (TypeOf $rhs), (DotGeneral (ShadowRHSDotRes), $lhs, (DiffeRet), (ShadowRHSDotDim), (ResultDotPrec)), (ShadowRHSTranspose))
575+
(Transpose (TypeOf $lhs), (DotGeneral (ShadowLHSDotRes), (DiffeRet), $rhs, (ShadowLHSDotDim), (ResultDotPrec), (ResultDotAlg)), (ShadowLHSTranspose)),
576+
(Transpose (TypeOf $rhs), (DotGeneral (ShadowRHSDotRes), $lhs, (DiffeRet), (ShadowRHSDotDim), (ResultDotPrec), (ResultDotAlg)), (ShadowRHSTranspose))
577577
],
578-
(Add (SelectIfActive $lhs, (DotGeneral (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (DotGeneral (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)))
578+
(Add (SelectIfActive $lhs, (DotGeneral (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec), (ResultDotAlg)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (DotGeneral (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec), (ResultDotAlg)), (HLOConstantFP<"0">)))
579579
>;
580580

581581
def : HLOInactiveOp<"DynamicIotaOp">;

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3590,13 +3590,13 @@ struct TransposeDotReorder
35903590
if (!convert) {
35913591
rewriter.replaceOpWithNewOp<stablehlo::DotGeneralOp>(
35923592
op, op.getType(), dot.getRhs(), dot.getLhs(), ndim,
3593-
dot.getPrecisionConfigAttr());
3593+
dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr());
35943594
} else {
35953595
auto middot = rewriter.create<stablehlo::DotGeneralOp>(
35963596
op.getLoc(),
35973597
RankedTensorType::get(op.getType().getShape(),
35983598
dot.getType().getElementType()),
3599-
dot.getRhs(), dot.getLhs(), ndim, dot.getPrecisionConfigAttr());
3599+
dot.getRhs(), dot.getLhs(), ndim, dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr());
36003600
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, op.getType(),
36013601
middot);
36023602
}
@@ -3808,7 +3808,7 @@ struct DotTranspose : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
38083808
rewriter.replaceOpWithNewOp<stablehlo::DotGeneralOp>(
38093809
dot, dot.getType(), lhs_trans ? lhs_trans.getOperand() : dot.getLhs(),
38103810
rhs_trans ? rhs_trans.getOperand() : dot.getRhs(), ndim,
3811-
dot.getPrecisionConfigAttr());
3811+
dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr());
38123812
return success();
38133813
}
38143814
};
@@ -4712,7 +4712,8 @@ struct PadDotGeneral : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
47124712
RankedTensorType::get(resultShape, op.getType().getElementType()),
47134713
otherIsLHS ? nextOtherArg : pad.getOperand(),
47144714
otherIsLHS ? pad.getOperand() : nextOtherArg,
4715-
op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr());
4715+
op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
4716+
op.getAlgorithmAttr(), op.getAlgorithmAttr());
47164717

47174718
if (!resultDimsToPad.empty()) {
47184719
SmallVector<int64_t> low(op.getType().getShape().size(), 0);

src/enzyme_ad/jax/enzyme_call.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ class CpuKernel {
902902
pyargv, mode, lang, xla_runtime, pass_pipeline);
903903

904904
if (!JIT) {
905-
DL = std::make_unique<llvm::DataLayout>(mod.get());
905+
DL = std::make_unique<llvm::DataLayout>(mod->getDataLayoutStr());
906906
auto tJIT =
907907
llvm::orc::LLJITBuilder()
908908
.setDataLayout(*DL.get())

workspace.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
JAX_COMMIT = "734ebd570891ceaf8c7104e12256a1edfe942b14"
2-
JAX_SHA256 = "ef7b578520f7fdc189ac2579f2929c81d2a6bd07b7a1e1cfa0017111fc0d12ca"
1+
JAX_COMMIT = "810a91968a853b4ae15aa5c5282e5673136bb980"
2+
JAX_SHA256 = ""
33

4-
ENZYME_COMMIT = "6e8dd3e3faff7e766a9e957bf4068d6fcda54539"
4+
ENZYME_COMMIT = "cc65abdb55e5e2e142d773e0508e1083d4b3ac52"
55
ENZYME_SHA256 = ""
66

77
XLA_PATCHES = [

0 commit comments

Comments
 (0)