@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
436436                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
437437                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
438438                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
439-                       LLVM::SqrtOp>();
439+                       LLVM::SincosOp, LLVM:: SqrtOp>();
440440
441441  //  TODO: Remove once we support replacing non-root ops.
442442  target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
466466  });
467467}
468468
469+ struct  SincosOpLowering  : public  ConvertOpToLLVMPattern <math::SincosOp> {
470+   using  ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
471+ 
472+   LogicalResult
473+   matchAndRewrite (math::SincosOp op, OpAdaptor adaptor,
474+                   ConversionPatternRewriter &rewriter) const  override  {
475+     Location loc = op.getLoc ();
476+     Value input = adaptor.getOperand ();
477+     Type inputType = input.getType ();
478+     auto  convertedInput = maybeExt (input, rewriter);
479+     auto  computeType = convertedInput.getType ();
480+ 
481+     StringRef sincosFunc;
482+     if  (isa<Float32Type>(computeType)) {
483+       const  arith::FastMathFlags flag = op.getFastmath ();
484+       const  bool  useApprox =
485+           mlir::arith::bitEnumContainsAny (flag, arith::FastMathFlags::afn);
486+       sincosFunc = useApprox ? " __nv_fast_sincosf" " __nv_sincosf" 
487+     } else  if  (isa<Float64Type>(computeType)) {
488+       sincosFunc = " __nv_sincos" 
489+     } else  {
490+       return  rewriter.notifyMatchFailure (op,
491+                                          " unsupported operand type for sincos" 
492+     }
493+ 
494+     auto  ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
495+ 
496+     Value sinPtr, cosPtr;
497+     {
498+       OpBuilder::InsertionGuard guard (rewriter);
499+       auto  *scope =
500+           op->getParentWithTrait <mlir::OpTrait::AutomaticAllocationScope>();
501+       assert (scope && " Expected op to be inside automatic allocation scope" 
502+       rewriter.setInsertionPointToStart (&scope->getRegion (0 ).front ());
503+       auto  one = rewriter.create <LLVM::ConstantOp>(
504+           loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (1 ));
505+       sinPtr =
506+           rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
507+       cosPtr =
508+           rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
509+     }
510+ 
511+     createSincosCall (rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
512+                      op);
513+ 
514+     auto  sinResult = rewriter.create <LLVM::LoadOp>(loc, computeType, sinPtr);
515+     auto  cosResult = rewriter.create <LLVM::LoadOp>(loc, computeType, cosPtr);
516+ 
517+     rewriter.replaceOp (op, {maybeTrunc (sinResult, inputType, rewriter),
518+                             maybeTrunc (cosResult, inputType, rewriter)});
519+     return  success ();
520+   }
521+ 
522+ private: 
523+   Value maybeExt (Value operand, PatternRewriter &rewriter) const  {
524+     if  (isa<Float16Type, BFloat16Type>(operand.getType ()))
525+       return  rewriter.create <LLVM::FPExtOp>(
526+           operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
527+     return  operand;
528+   }
529+ 
530+   Value maybeTrunc (Value operand, Type type, PatternRewriter &rewriter) const  {
531+     if  (operand.getType () != type)
532+       return  rewriter.create <LLVM::FPTruncOp>(operand.getLoc (), type, operand);
533+     return  operand;
534+   }
535+ 
536+   void  createSincosCall (ConversionPatternRewriter &rewriter, Location loc,
537+                         StringRef funcName, Value input, Value sinPtr,
538+                         Value cosPtr, Operation *op) const  {
539+     auto  voidType = LLVM::LLVMVoidType::get (rewriter.getContext ());
540+     auto  ptrType = sinPtr.getType ();
541+ 
542+     SmallVector<Type> operandTypes = {input.getType (), ptrType, ptrType};
543+     auto  funcType = LLVM::LLVMFunctionType::get (voidType, operandTypes);
544+ 
545+     auto  funcAttr = StringAttr::get (op->getContext (), funcName);
546+     auto  funcOp =
547+         SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
548+ 
549+     if  (!funcOp) {
550+       auto  parentFunc = op->getParentOfType <FunctionOpInterface>();
551+       assert (parentFunc && " expected there to be a parent function" 
552+       OpBuilder b (parentFunc);
553+ 
554+       auto  globalloc = loc->findInstanceOfOrUnknown <FileLineColLoc>();
555+       funcOp = LLVM::LLVMFuncOp::create (b, globalloc, funcName, funcType);
556+     }
557+ 
558+     SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
559+     rewriter.create <LLVM::CallOp>(loc, funcOp, callOperands);
560+   }
561+ };
562+ 
469563template  <typename  OpTy>
470564static  void  populateOpPatterns (const  LLVMTypeConverter &converter,
471565                               RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
589683                                  " __nv_tan" " __nv_fast_tanf" 
590684  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, " __nv_tanhf" 
591685                                   " __nv_tanh" 
686+ 
687+   //  Custom pattern for sincos since it returns two values
688+   patterns.add <SincosOpLowering>(converter, benefit);
592689}
593690
594691void  mlir::populateGpuToNVVMConversionPatterns (
0 commit comments