Skip to content

Commit 7b527fd

Browse files
committed
Bump jax commit
1 parent a9652b9 commit 7b527fd

File tree

2 files changed

+66
-64
lines changed

2 files changed

+66
-64
lines changed

src/enzyme_ad/jax/primitives.py

Lines changed: 63 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -241,64 +241,8 @@ def stablehlo_inject(self):
241241
def ad_level(self):
242242
return self.passes.count("enzyme-wrap")
243243

244-
245-
DefaultCPPPipeline = OldXLAPipeline() # NewXLAPipeline(None, True)
246-
DefaultJaXPipeline = JaXPipeline(
247-
"inline{default-pipeline=canonicalize max-iterations=4},canonicalize,cse,enzyme-hlo-unroll,canonicalize,cse,enzyme-hlo-opt{passses=24575},cse"
248-
)
249-
250-
251-
def pass_pipeline(options):
252-
if type(options) == type(""):
253-
return options
254-
else:
255-
return
256-
257-
258-
def resource_dir():
259-
import os
260-
261-
dn = os.path.dirname(enzyme_call.__file__)
262-
res = os.path.join(dn, "..", "..", "clang", "staging")
263-
return res
264-
265-
266-
def cflags():
267-
import platform
268-
import os
269-
270-
if platform.system() == "Darwin":
271-
res = (
272-
"-isysroot",
273-
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk",
274-
"-isystem",
275-
"/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1",
276-
"-internal-isystem",
277-
os.path.join(resource_dir(), "include"),
278-
"-internal-externc-isystem",
279-
"/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include",
280-
"-internal-externc-isystem",
281-
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include",
282-
"-fgnuc-version=4.2.1",
283-
)
284-
else:
285-
res = ()
286-
if os.getenv("ENABLE_GDBLISTENER") is not None:
287-
res = res + (
288-
"-debug-info-kind=standalone",
289-
"-dwarf-version=5",
290-
"-debugger-tuning=gdb",
291-
)
292-
return res
293-
294-
295-
def optimize_module(mod, pipeline=None):
296-
if pipeline is None:
297-
pipeline = """
298-
inline{default-pipeline=canonicalize max-iterations=4},
299-
canonicalize,cse,
300-
canonicalize,
301-
enzyme-hlo-generate-td{
244+
def hlo_opts():
245+
return """enzyme-hlo-generate-td{
302246
patterns=compare_op_canon<16>;
303247
broadcast_in_dim_op_canon<16>;
304248
convert_op_canon<16>;
@@ -452,6 +396,63 @@ def optimize_module(mod, pipeline=None):
452396
transform-interpreter,
453397
enzyme-hlo-remove-transform
454398
"""
399+
400+
DefaultCPPPipeline = OldXLAPipeline() # NewXLAPipeline(None, True)
401+
DefaultJaXPipeline = JaXPipeline(
402+
"inline{default-pipeline=canonicalize max-iterations=4},canonicalize,cse,enzyme-hlo-unroll,canonicalize,cse," + hlo_opts() + ", cse"
403+
)
404+
405+
406+
def pass_pipeline(options):
407+
if type(options) == type(""):
408+
return options
409+
else:
410+
return
411+
412+
413+
def resource_dir():
414+
import os
415+
416+
dn = os.path.dirname(enzyme_call.__file__)
417+
res = os.path.join(dn, "..", "..", "clang", "staging")
418+
return res
419+
420+
421+
def cflags():
422+
import platform
423+
import os
424+
425+
if platform.system() == "Darwin":
426+
res = (
427+
"-isysroot",
428+
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk",
429+
"-isystem",
430+
"/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1",
431+
"-internal-isystem",
432+
os.path.join(resource_dir(), "include"),
433+
"-internal-externc-isystem",
434+
"/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include",
435+
"-internal-externc-isystem",
436+
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include",
437+
"-fgnuc-version=4.2.1",
438+
)
439+
else:
440+
res = ()
441+
if os.getenv("ENABLE_GDBLISTENER") is not None:
442+
res = res + (
443+
"-debug-info-kind=standalone",
444+
"-dwarf-version=5",
445+
"-debugger-tuning=gdb",
446+
)
447+
return res
448+
449+
450+
def optimize_module(mod, pipeline=None):
451+
if pipeline is None:
452+
pipeline = """
453+
inline{default-pipeline=canonicalize max-iterations=4},
454+
canonicalize,cse,
455+
canonicalize,""" + hlo_opts()
455456
enzyme_call.optimize_module(mod, pipeline)
456457
return
457458

@@ -630,6 +631,7 @@ def maketup(ty):
630631
ty = ir.RankedTensorType(ty)
631632
tystr = ty.element_type.__str__()
632633
tystr = {
634+
"i1": "bool",
633635
"bf16": "bfloat16",
634636
"f32": "float",
635637
"f64": "double",
@@ -1169,10 +1171,10 @@ def make_zero(tan, prim):
11691171

11701172
outshapes = kwargs["out_shapes"]
11711173
ret_act_tup = ",".join(["enzyme_dup"] * len(outshapes))
1172-
afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize"
1174+
afterad = "arith-raise{stablehlo=true}, " + hlo_opts() + ", cse, canonicalize"
11731175
newpasses = (
11741176
"inline{default-pipeline=canonicalize max-iterations=4},"
1175-
+ "enzyme-hlo-opt,cse,enzyme-wrap{infn=main outfn= retTys="
1177+
+ hlo_opts() + ", cse,enzyme-wrap{infn=main outfn= retTys="
11761178
+ ret_act_tup
11771179
+ " argTys="
11781180
+ arg_act_tup
@@ -1401,7 +1403,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
14011403
newpasses = (
14021404
prev_passes
14031405
+ ad_pass
1404-
+ ",arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse"
1406+
+ ",arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, " + hlo_opts() + ", canonicalize, cse"
14051407
+ post_passes
14061408
)
14071409

workspace.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
JAX_COMMIT = "8815b236b656f494171131301d1d81e84cf4c67c"
2-
JAX_SHA256 = "188787b8ec366dcda5805f24dcbfab7a349aa780498aeb9c3b728c9cec0a7e7d"
1+
JAX_COMMIT = "493698e6e053641aa8c51bca657cbd763a3ced19"
2+
JAX_SHA256 = ""
33

44
ENZYME_COMMIT = "9acbc0a667ec8ae76407b5708758667a65ff15aa"
55
ENZYME_SHA256 = "287143133ccf9501a02f1bdab351c34adcab3bbfc8648b180ebd79d0e058b3af"
@@ -32,4 +32,4 @@ XLA_PATCHES = [
3232
"find . -type f -name BUILD -exec sed -i.bak1 's/\\/\\/third_party\\/py\\/enzyme_ad\\/\\.\\.\\./public/g' {} +",
3333
"find . -type f -name BUILD -exec sed -i.bak2 's/\\/\\/xla\\/mlir\\/memref:friends/\\/\\/visibility:public/g' {} +",
3434
"find xla/mlir -type f -name BUILD -exec sed -i.bak3 's/\\/\\/xla:internal/\\/\\/\\/\\/visibility:public/g' {} +"
35-
]
35+
]

0 commit comments

Comments
 (0)