diff --git a/paddle/fluid/operators/complex_op.h b/paddle/fluid/operators/complex_op.h
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..3dd5ea9f7e83dbfaa353378cfee10231c445c222 100644
--- a/paddle/fluid/operators/complex_op.h
+++ b/paddle/fluid/operators/complex_op.h
@@ -0,0 +1,111 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/platform/complex.h"
+
+namespace paddle {
+namespace operators {
+
+// functors to use with ElementwiseComputeEx
+template <typename T>
+struct RealAndImagToComplexFunctor {
+  inline HOSTDEVICE platform::complex<T> operator()(const T x, const T y) {
+    return platform::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ImagAndRealToComplexFunctor {
+  inline HOSTDEVICE platform::complex<T> operator()(const T y, const T x) {
+    return platform::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ComplexGradForRealFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y,
+                                 const platform::complex<T> out,
+                                 const platform::complex<T> dout) {
+    return dout.real;
+  }
+};
+
+template <typename T>
+struct ComplexGradForImagFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y,
+                                 const platform::complex<T> out,
+                                 const platform::complex<T> dout) {
+    return dout.imag;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ComplexKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* x = ctx.Input<framework::Tensor>("X");
+    const auto* y = ctx.Input<framework::Tensor>("Y");
+    auto* z = ctx.Output<framework::Tensor>("Out");
+
+    using C = platform::complex<T>;
+    z->mutable_data<C>(ctx.GetPlace());
+
+// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
+// facility functions
+#if defined(__NVCC__) || defined(__HIPCC__)
+    ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
+        ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
+#else
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    if (x_dims.size() >= y_dims.size()) {
+      ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
+          ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
+    } else {
+      ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
+          ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
+    }
+#endif
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ComplexGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    using C = platform::complex<T>;
+
+    // skip out in a hacky way
+    auto* out = dout;
+    ElemwiseGradCompute<DeviceContext, T, ComplexGradForRealFunctor<T>,
+                        ComplexGradForImagFunctor<T>, C>(
+        ctx, *x, *y, *out, *dout, /*axis*/ -1, dx, dy,
+        ComplexGradForRealFunctor<T>(), ComplexGradForImagFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/label_smooth_op.cu b/paddle/fluid/operators/label_smooth_op.cu
index 2e7d1de3bd756d02fa37ee14de15879c01b64385..2c7a08de0f65b8b9ccbd03db5c96d669ae284180 100644
--- a/paddle/fluid/operators/label_smooth_op.cu
+++ b/paddle/fluid/operators/label_smooth_op.cu
@@ -28,7 +28,7 @@ struct LabelSmoothFunctor {
     label_dim = static_cast<T>(label_dim_data);
   }
 
-  __device__ __forceinline__ T operator()(const T& x) const {
+  __device__ __forceinline__ T operator()(const T x) const {
     return (static_cast<T>(1 - epsilon) * x +
             static_cast<T>(epsilon / label_dim));
   }
@@ -42,7 +42,7 @@ struct LabelSmoothGradFunctor {
     epsilon = static_cast<T>(epsilon_data);
   }
 
-  __device__ __forceinline__ T operator()(const T& x) const {
+  __device__ __forceinline__ T operator()(const T x) const {
     return static_cast<T>(1 - epsilon) * x;
   }
 };
diff --git a/paddle/fluid/operators/lgamma_op.cu b/paddle/fluid/operators/lgamma_op.cu
index baf86c99b5678dfb5475c7217f8be17f5bccd505..da40518d9b4b2ca708e614238ce4d63b4d3a1e2f 100644
--- a/paddle/fluid/operators/lgamma_op.cu
+++ b/paddle/fluid/operators/lgamma_op.cu
@@ -21,7 +21,7 @@ namespace operators {
 
 template <typename T>
 struct CudaLgammaFunctor {
-  __device__ __forceinline__ T operator()(const T& x) const {
+  __device__ __forceinline__ T operator()(const T x) const {
     return Eigen::numext::lgamma(x);
   }
 };
diff --git a/paddle/fluid/operators/matrix_rank_op.h b/paddle/fluid/operators/matrix_rank_op.h
index 7fa74368332d0ab582919238ed6b0db6a32252ec..c3d99a21b72358df5dedc7741072a7913de174af 100644
--- a/paddle/fluid/operators/matrix_rank_op.h
+++ b/paddle/fluid/operators/matrix_rank_op.h
@@ -48,17 +48,17 @@ static DDim RemoveLastDim(const DDim& dim) {
 
 template <typename T>
 struct GreaterThanFunctor {
-  HOSTDEVICE int operator()(const T& a, const T& b) const { return a > b; }
+  HOSTDEVICE int operator()(const T a, const T b) const { return a > b; }
 };
 
 template <typename T>
 struct LessThanFunctor {
-  HOSTDEVICE int operator()(const T& a, const T& b) const { return a < b; }
+  HOSTDEVICE int operator()(const T a, const T b) const { return a < b; }
 };
 
 template <typename T>
 struct GreaterElementFunctor {
-  HOSTDEVICE T operator()(const T& a, const T& b) const {
+  HOSTDEVICE T operator()(const T a, const T b) const {
     if (a > b) {
       return a;
     } else {
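
A note on the contract the complex_op.h kernels implement: the forward kernel assembles out = x + i*y elementwise (swapping to ImagAndRealToComplexFunctor on the CPU path when Y has the higher rank, since that path broadcasts the lower-rank operand over the higher-rank one), and the grad kernel routes real(dout) to X@GRAD and imag(dout) to Y@GRAD. A minimal scalar sketch of that contract, using std::complex as a stand-in for platform::complex (illustration only, not Paddle code):

#include <cassert>
#include <complex>

int main() {
  float x = 3.0f, y = -2.0f;
  std::complex<float> out(x, y);          // forward: RealAndImagToComplexFunctor
  std::complex<float> dout(0.5f, 0.25f);  // some upstream gradient
  float dx = dout.real();                 // ComplexGradForRealFunctor -> dout.real
  float dy = dout.imag();                 // ComplexGradForImagFunctor -> dout.imag
  assert(out == std::complex<float>(3.0f, -2.0f));
  assert(dx == 0.5f && dy == 0.25f);
  return 0;
}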
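
The remaining three files make one mechanical change: the device functors now take their element arguments as `const T` by value rather than `const T&`. For a scalar T a register copy is at least as cheap as a reference, and it matches a launcher that stages each element in a register before applying the functor. A sketch of that calling pattern, where ApplyElementwise and LabelSmoothLike are hypothetical names, not Paddle's elementwise machinery:

#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
struct LabelSmoothLike {  // same shape as the updated LabelSmoothFunctor
  T epsilon;
  T label_dim;
  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(1 - epsilon) * x + static_cast<T>(epsilon / label_dim);
  }
};

template <typename T, typename Functor>
__global__ void ApplyElementwise(const T* in, T* out, int n, Functor fn) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T v = in[i];     // each element is staged in a register,
    out[i] = fn(v);  // so the functor receives a by-value copy
  }
}

int main() {
  const int n = 4;
  float h_in[n] = {0.f, 1.f, 2.f, 3.f}, h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  ApplyElementwise<<<1, n>>>(d_in, d_out, n, LabelSmoothLike<float>{0.1f, 4.f});
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%f\n", h_out[i]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}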