diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 55a0ba1f2ad65fee774f8e699185cd633535d479..5a90862e68889a741ffcabce90d267489b79baed 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -388,7 +388,7 @@ inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared("Depend"); inline const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -inline const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +inline const PrimitivePtr kPrimIdentity = std::make_shared("Identity"); inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); diff --git a/mindspore/core/ops/identity.cc b/mindspore/core/ops/identity.cc index 4bd589440909ee9b5d090fc4d0cf77c5867233cd..c43e0f7f04adefa1eeaab5de5d266743a491cbbc 100644 --- a/mindspore/core/ops/identity.cc +++ b/mindspore/core/ops/identity.cc @@ -33,10 +33,10 @@ AbstractBasePtr IdentityInfer(const abstract::AnalysisEnginePtr &, const Primiti // Infer type auto x_type = input_args[0]->BuildType()->cast()->element(); std::set valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; - CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name); + CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); std::set common_valid_types_ = common_valid_types; common_valid_types_.insert({kNumberTypeBool}); - CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, common_valid_types_, prim_name); + CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), common_valid_types_, prim_name); // Infer shape auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 7d1ba03294dcaf54ba29de573dd269bb26392609..3cfaf3093df2a7dcf9b7b2fb93affec90ee4ce7b 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -25,6 +25,63 @@ namespace mindspore { namespace ops { +AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto unsortedsegmentsum_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(unsortedsegmentsum_prim); + auto prim_name = unsortedsegmentsum_prim->name(); + + // Infer type + auto x_type = input_args[0]->BuildType()->cast()->element(); + auto segment_ids_type = input_args[1]->BuildType()->cast()->element(); + auto num_segments_type = input_args[2]->BuildType(); + auto num_segments_v = 4; + std::set valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; + CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), valid_x_type, prim_name); + std::set valid_segment_ids_type = {TypeIdToType(kObjectTypeTensorType)}; + CheckAndConvertUtils::CheckSubClass("segment_ids", input_args[1]->BuildType(), valid_segment_ids_type, prim_name); + + // Infer shape + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name); + auto shp = x_shape; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name); + CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(), + prim_name); + + if ((x_shape.end() != find(x_shape.begin(), x_shape.end(), -1)) && + (segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) { + int64_t size = segment_ids_shape.size(); + for (int64_t i = 0; i < size; ++i) { + CheckAndConvertUtils::Check("segment_ids_shp", segment_ids_shape[i], kEqual, "x_shape", x_shape[i], prim_name); + } + } + + const std::set valid_segments_types = {TypeIdToType(kObjectTypeTensorType)}; + for (const auto &valid_segments_type : valid_segments_types) { + if (IsIdentidityOrSubclass(num_segments_type, valid_segments_type)) { + const std::set valid_num_segments_types = {kNumberTypeInt32, kNumberTypeInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types, + prim_name); + shp = {-1}; + } else { + CheckAndConvertUtils::CheckInteger("num_segments", num_segments_v, kGreaterThan, 0, prim_name); + shp = {num_segments_v}; + } + } + + int64_t size_segment_ids_shp = segment_ids_shape.size(); + int64_t size_x_shpe = x_shape.size(); + for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) { + shp.emplace_back(x_shape[i]); + } + + return std::make_shared(x_type, shp); +} +REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum, UnsortedSegmentSumInfer); REGISTER_PRIMITIVE_C(kNameUnsortedSegmentSum, UnsortedSegmentSum); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/unsorted_segment_sum.h b/mindspore/core/ops/unsorted_segment_sum.h index 89d61498461a4b738a6b20fabd74534f2abc8fbe..7e47c3197c60729b2638d9a602e12ada4101cb67 100644 --- a/mindspore/core/ops/unsorted_segment_sum.h +++ b/mindspore/core/ops/unsorted_segment_sum.h @@ -37,6 +37,10 @@ class UnsortedSegmentSum : public PrimitiveC { MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC); void Init() {} }; + +AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimUnsortedSegmentSumPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/tests/ut/cpp/c_ops/test_c_ops_identity.cc b/tests/ut/cpp/c_ops/test_c_ops_identity.cc new file mode 100644 index 0000000000000000000000000000000000000000..b189502f8b685452608cc79f82fb118788d118c1 --- /dev/null +++ b/tests/ut/cpp/c_ops/test_c_ops_identity.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "ops/identity.h" +#include "ir/dtype/type.h" +#include "ir/value.h" +#include "abstract/dshape.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +class TestIdentity : public UT::Common { + public: + TestIdentity() {} + void SetUp() {} + void TearDown() {} +}; + +TEST_F(TestIdentity, test_cops_identity) { + auto identity = std::make_shared(); + auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeInt64, std::vector{1, 2, 3, 4}); + MS_EXCEPTION_IF_NULL(tensor_x); + auto abstract = identity->Infer({tensor_x->ToAbstract()}); + MS_EXCEPTION_IF_NULL(abstract); + EXPECT_EQ(abstract->isa(), true); + auto shape_ptr = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + EXPECT_EQ(shape_ptr->isa(), true); + auto shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_vec = shape->shape(); + auto type = abstract->BuildType(); + MS_EXCEPTION_IF_NULL(type); + EXPECT_EQ(type->isa(), true); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto data_type = tensor_type->element(); + MS_EXCEPTION_IF_NULL(data_type); + EXPECT_EQ(data_type->type_id(), kNumberTypeInt64); + EXPECT_EQ(shape_vec.size(), 4); + EXPECT_EQ(shape_vec[0], 1); + EXPECT_EQ(shape_vec[1], 2); + EXPECT_EQ(shape_vec[2], 3); + EXPECT_EQ(shape_vec[3], 4); +} + +} // namespace ops +} // namespace mindspore diff --git a/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4e0c7827a113c0972f1412e0fb1c78bd45aa901 --- /dev/null +++ b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "ops/unsorted_segment_sum.h" +#include "ir/dtype/type.h" +#include "ir/value.h" +#include "abstract/dshape.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +class TestUnsortedSegmentSum : public UT::Common { + public: + TestUnsortedSegmentSum() {} + void SetUp() {} + void TearDown() {} +}; + +TEST_F(TestUnsortedSegmentSum, test_cops_unsortedsegmentsum) { + auto unsortedsegmentsum = std::make_shared(); + auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector{1, 4}); + auto segment_ids = TensorConstructUtils::CreateOnesTensor(kNumberTypeInt32, std::vector{1, 4}); + auto num_segments = MakeValue(4); + MS_EXCEPTION_IF_NULL(tensor_x); + MS_EXCEPTION_IF_NULL(segment_ids); + MS_EXCEPTION_IF_NULL(num_segments); + auto abstract = unsortedsegmentsum->Infer({tensor_x->ToAbstract(), segment_ids->ToAbstract(), num_segments->ToAbstract()}); + MS_EXCEPTION_IF_NULL(abstract); + EXPECT_EQ(abstract->isa(), true); + auto shape_ptr = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + EXPECT_EQ(shape_ptr->isa(), true); + auto shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_vec = shape->shape(); + auto type = abstract->BuildType(); + MS_EXCEPTION_IF_NULL(type); + EXPECT_EQ(type->isa(), true); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto data_type = tensor_type->element(); + MS_EXCEPTION_IF_NULL(data_type); + EXPECT_EQ(data_type->type_id(), kNumberTypeFloat32); + EXPECT_EQ(shape_vec.size(), 1); + EXPECT_EQ(shape_vec[0], 4); + +} + +} // namespace ops +} // namespace mindspore