-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathprecision_policy.py
89 lines (78 loc) · 2.52 KB
/
precision_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Enums for scalar precisions and compute/storage precision policies
from enum import Enum, auto
import jax.numpy as jnp
import warp as wp
class Precision(Enum):
    """Scalar numeric types, each resolvable to a Warp dtype and a JAX dtype.

    Member values are assigned with ``auto()``. Use the ``wp_dtype`` and
    ``jax_dtype`` properties to obtain the backend-specific type objects.
    """

    FP64 = auto()
    FP32 = auto()
    FP16 = auto()
    UINT8 = auto()
    BOOL = auto()

    @property
    def wp_dtype(self):
        """Return the Warp scalar type matching this precision.

        Raises:
            ValueError: if this member has no Warp mapping.
        """
        warp_types = {
            Precision.FP64: wp.float64,
            Precision.FP32: wp.float32,
            Precision.FP16: wp.float16,
            Precision.UINT8: wp.uint8,
            Precision.BOOL: wp.bool,
        }
        if self not in warp_types:
            raise ValueError("Invalid precision")
        return warp_types[self]

    @property
    def jax_dtype(self):
        """Return the ``jax.numpy`` scalar type matching this precision.

        Raises:
            ValueError: if this member has no JAX mapping.
        """
        jax_types = {
            Precision.FP64: jnp.float64,
            Precision.FP32: jnp.float32,
            Precision.FP16: jnp.float16,
            Precision.UINT8: jnp.uint8,
            Precision.BOOL: jnp.bool_,
        }
        if self not in jax_types:
            raise ValueError("Invalid precision")
        return jax_types[self]
class PrecisionPolicy(Enum):
    """Pairing of a compute precision with a storage precision.

    Member names read as ``<compute><store>``: for example ``FP64FP32``
    computes in ``Precision.FP64`` and stores in ``Precision.FP32``.
    """

    FP64FP64 = auto()
    FP64FP32 = auto()
    FP64FP16 = auto()
    FP32FP32 = auto()
    FP32FP16 = auto()

    @property
    def compute_precision(self):
        """Return the ``Precision`` used for computation under this policy.

        Raises:
            ValueError: if this member has no compute-precision mapping.
        """
        compute_for = {
            PrecisionPolicy.FP64FP64: Precision.FP64,
            PrecisionPolicy.FP64FP32: Precision.FP64,
            PrecisionPolicy.FP64FP16: Precision.FP64,
            PrecisionPolicy.FP32FP32: Precision.FP32,
            PrecisionPolicy.FP32FP16: Precision.FP32,
        }
        if self not in compute_for:
            raise ValueError("Invalid precision policy")
        return compute_for[self]

    @property
    def store_precision(self):
        """Return the ``Precision`` used for storage under this policy.

        Raises:
            ValueError: if this member has no store-precision mapping.
        """
        store_for = {
            PrecisionPolicy.FP64FP64: Precision.FP64,
            PrecisionPolicy.FP64FP32: Precision.FP32,
            PrecisionPolicy.FP64FP16: Precision.FP16,
            PrecisionPolicy.FP32FP32: Precision.FP32,
            PrecisionPolicy.FP32FP16: Precision.FP16,
        }
        if self not in store_for:
            raise ValueError("Invalid precision policy")
        return store_for[self]

    def cast_to_compute_jax(self, array):
        """Convert ``array`` with ``jnp.array`` to the compute precision's JAX dtype.

        NOTE(review): float64 output presumably depends on JAX x64 mode being
        enabled — confirm how callers configure JAX.
        """
        return jnp.array(array, dtype=self.compute_precision.jax_dtype)

    def cast_to_store_jax(self, array):
        """Convert ``array`` with ``jnp.array`` to the store precision's JAX dtype."""
        return jnp.array(array, dtype=self.store_precision.jax_dtype)