[XLA:CPU/GPU] Make tanh emission more like the other ops
No reason for tanh to be the odd one out.

PiperOrigin-RevId: 208072018
d0k authored and tensorflower-gardener committed Aug 9, 2018
1 parent 570b5dd commit 1b5edbf
Showing 6 changed files with 59 additions and 66 deletions.
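The refactor gives tanh the same shape as atan2: the shared ElementalIrEmitter::EmitFloatUnaryOp handles opcode dispatch and calls a virtual EmitTanh hook, which the CPU and GPU emitters override. A minimal sketch of that dispatch shape, with XLA's types replaced by simplified stand-ins (everything here is illustrative, not code from the repository; see the per-file diffs below for the real change):

#include <iostream>
#include <string>

// Stand-ins for PrimitiveType and llvm::Value*.
enum PrimitiveType { F16, F32, F64 };
struct Value { std::string ir; };

class ElementalIrEmitter {
 public:
  // Shared opcode dispatch: tanh now routes through a virtual hook,
  // exactly as atan2 already did.
  Value EmitFloatUnaryOp(const std::string& opcode, PrimitiveType type,
                         Value operand) {
    if (opcode == "tanh") return EmitTanh(type, operand);
    return Value{"<unimplemented " + opcode + ">"};
  }

 protected:
  // Base default mirrors the real Unimplemented("tanh") fallback.
  virtual Value EmitTanh(PrimitiveType, Value) {
    return Value{"<unimplemented tanh>"};
  }
};

class CpuEmitter : public ElementalIrEmitter {
 protected:
  // CPU-style override: emit a call to libm, as the real diff does below.
  Value EmitTanh(PrimitiveType type, Value v) override {
    const char* fn = (type == F64) ? "tanh" : "tanhf";
    return Value{std::string("call ") + fn + "(" + v.ir + ")"};
  }
};

int main() {
  CpuEmitter cpu;
  std::cout << cpu.EmitFloatUnaryOp("tanh", F32, Value{"%x"}).ir << "\n";
}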
74 changes: 33 additions & 41 deletions tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -30,47 +30,6 @@ limitations under the License.
namespace xla {
namespace cpu {

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kTanh: {
      PrimitiveType element_type = op->shape().element_type();
      bool cast_result_to_fp16 = false;
      string function_name;
      switch (element_type) {
        case F16:
          cast_result_to_fp16 = true;
          operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy());
          TF_FALLTHROUGH_INTENDED;
        case F32:
          function_name = "tanhf";
          break;
        case F64:
          function_name = "tanh";
          break;
        default:
          return Unimplemented("tanh");
      }
      // Create a function declaration.
      llvm::Function* function =
          llvm::cast<llvm::Function>(module_->getOrInsertFunction(
              llvm_ir::AsStringRef(function_name), operand_value->getType(),
              operand_value->getType()));
      function->setCallingConv(llvm::CallingConv::C);
      function->setDoesNotThrow();
      function->setDoesNotAccessMemory();
      // Create an instruction to call the function.
      llvm::Value* result = b_->CreateCall(function, operand_value);
      if (cast_result_to_fp16) {
        result = b_->CreateFPCast(result, b_->getHalfTy());
      }
      return result;
    }
    default:
      return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
  string function_name;
@@ -106,6 +65,39 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
  return result;
}

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
    PrimitiveType prim_type, llvm::Value* value) const {
  bool cast_result_to_fp16 = false;
  string function_name;
  switch (prim_type) {
    case F16:
      cast_result_to_fp16 = true;
      value = b_->CreateFPCast(value, b_->getFloatTy());
      TF_FALLTHROUGH_INTENDED;
    case F32:
      function_name = "tanhf";
      break;
    case F64:
      function_name = "tanh";
      break;
    default:
      return Unimplemented("tanh");
  }
  // Create a function declaration.
  llvm::Function* function = llvm::cast<llvm::Function>(
      module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
                                   value->getType(), value->getType()));
  function->setCallingConv(llvm::CallingConv::C);
  function->setDoesNotThrow();
  function->setDoesNotAccessMemory();
  // Create an instruction to call the function.
  llvm::Value* result = b_->CreateCall(function, value);
  if (cast_result_to_fp16) {
    result = b_->CreateFPCast(result, b_->getHalfTy());
  }
  return result;
}

llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) const {
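A note on the attribute calls in EmitTanh above: setDoesNotThrow() and setDoesNotAccessMemory() mark the declared tanhf/tanh as nounwind and readnone, which lets LLVM treat the call as a pure function and, for example, merge duplicate calls to it. A rough source-level analogue using GCC/Clang attributes (a sketch for intuition only; glibc's own <math.h> declaration may already carry similar attributes):

#include <cstdio>

// 'const' approximates LLVM's readnone, 'nothrow' approximates nounwind.
extern "C" float tanhf(float) __attribute__((const, nothrow));

int main() {
  float a = tanhf(0.5f);
  float b = tanhf(0.5f);  // With 'const', the compiler may reuse 'a' here.
  std::printf("%f %f\n", a, b);
}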
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
  StatusOr<llvm::Value*> EmitFloatUnaryOp(
      const HloInstruction* op, llvm::Value* operand_value) const override;
  StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                   llvm::Value* rhs) const override;
  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                  llvm::Value* value) const override;

  IrEmitter* ir_emitter_;
};
7 changes: 7 additions & 0 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -431,6 +431,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
@@ -1060,6 +1062,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) const {
  return Unimplemented("tanh");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) const {
  if (hlo->operand(0)->shape().element_type() != F32) {
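The base-class EmitTanh above just returns Unimplemented("tanh"), so a backend that never overrides the hook fails gracefully with an error status rather than crashing. A simplified model of how that StatusOr<llvm::Value*> error flow behaves (this struct is only a stand-in for the real tensorflow StatusOr type):

#include <iostream>
#include <string>

// Minimal stand-in for StatusOr<T>: either a value or an error message.
template <typename T>
struct StatusOr {
  bool ok;
  T value;
  std::string error;
  static StatusOr Ok(T v) { return {true, v, ""}; }
  static StatusOr Unimplemented(const std::string& what) {
    return {false, T{}, "unimplemented: " + what};
  }
};

// Models the base-class default above.
StatusOr<int> EmitTanh() { return StatusOr<int>::Unimplemented("tanh"); }

int main() {
  auto result = EmitTanh();
  if (!result.ok) std::cout << result.error << "\n";  // error propagates up
}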
3 changes: 3 additions & 0 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -122,6 +122,9 @@ class ElementalIrEmitter {
                                           llvm::Value* lhs,
                                           llvm::Value* rhs) const;

  virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                          llvm::Value* value) const;

  virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
                                                     llvm::Value* x) const;

31 changes: 11 additions & 20 deletions tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -272,27 +272,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
                               prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  switch (op->opcode()) {
    case HloOpcode::kTanh:
      // If we don't care much about precision, emit a fast approximation of
      // tanh.
      if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
        // Upcast F16 to F32 if necessary.
        llvm::Type* type =
            input_type == F16 ? b_->getFloatTy() : operand_value->getType();
        llvm::Value* input = b_->CreateFPCast(operand_value, type);
        llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
        return b_->CreateFPCast(fast_tanh, operand_value->getType());
      }
      return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
                                   output_type);
    default:
      return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
    PrimitiveType prim_type, llvm::Value* value) const {
  // If we don't care much about precision, emit a fast approximation of
  // tanh.
  if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
    // Upcast F16 to F32 if necessary.
    llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
    llvm::Value* input = b_->CreateFPCast(value, type);
    llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
    return b_->CreateFPCast(fast_tanh, value->getType());
  }
  return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
}

llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
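On the GPU path above, fast-math mode emits llvm_ir::EmitFastTanh instead of calling libdevice's __nv_tanh. That helper (defined outside this diff) emits a clamped rational approximation; the sketch below models the technique in plain C++ with the Eigen-style float coefficients such approximations typically use — treat the exact constants and the ±9 clamp as assumptions rather than XLA's verbatim values:

#include <algorithm>
#include <cmath>
#include <cstdio>

float FastTanh(float x) {
  // Outside this range tanh is ±1 to within float precision.
  x = std::min(9.0f, std::max(-9.0f, x));
  float x2 = x * x;
  // Numerator: odd polynomial p(x) = x * (a1 + a3*x^2 + ... + a13*x^12),
  // evaluated with Horner's scheme.
  float p = -2.76076847742355e-16f;
  p = p * x2 + 2.00018790482477e-13f;
  p = p * x2 + -8.60467152213735e-11f;
  p = p * x2 + 5.12229709037114e-08f;
  p = p * x2 + 1.48572235717979e-05f;
  p = p * x2 + 6.37261928875436e-04f;
  p = p * x2 + 4.89352455891786e-03f;
  p = p * x;
  // Denominator: even polynomial q(x) = b0 + b2*x^2 + b4*x^4 + b6*x^6.
  float q = 1.19825839466702e-06f;
  q = q * x2 + 1.18534705686654e-04f;
  q = q * x2 + 2.26843463243900e-03f;
  q = q * x2 + 4.89352518554385e-03f;
  return p / q;
}

int main() {
  // Compare the approximation against libm's tanh at a few points.
  for (float x : {-3.0f, -0.5f, 0.0f, 0.5f, 3.0f})
    std::printf("x=% .1f fast=% .6f libm=% .6f\n", x, FastTanh(x),
                static_cast<float>(std::tanh(x)));
}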
6 changes: 3 additions & 3 deletions tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
  StatusOr<llvm::Value*> EmitFloatUnaryOp(
      const HloInstruction* op, llvm::Value* operand_value) const override;

  StatusOr<llvm::Value*> EmitFloatBinaryOp(
      const HloInstruction* op, llvm::Value* lhs_value,
      llvm::Value* rhs_value) const override;
@@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
  StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                   llvm::Value* rhs) const override;

  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                  llvm::Value* value) const override;

  llvm::Value* EmitThreadId() const override;

 private:
