From 26bbb924e3b7cb76e5e065a3f0ca403d130cd9e5 Mon Sep 17 00:00:00 2001 From: dingpeifei Date: Tue, 19 Jan 2021 10:57:41 +0800 Subject: [PATCH 1/2] add unsortedsegmentsum and identity ut --- mindspore/core/base/core_ops.h | 2 +- mindspore/core/ops/identity.cc | 8 +-- mindspore/core/ops/identity.h | 2 +- mindspore/core/ops/unsorted_segment_sum.cc | 61 ++++++++++++++++- mindspore/core/ops/unsorted_segment_sum.h | 6 +- tests/ut/cpp/c_ops/test_c_ops_identity.cc | 63 ++++++++++++++++++ .../c_ops/test_c_ops_unsorted_segment_sum.cc | 65 +++++++++++++++++++ 7 files changed, 198 insertions(+), 9 deletions(-) create mode 100644 tests/ut/cpp/c_ops/test_c_ops_identity.cc create mode 100644 tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 55a0ba1f2a..5a90862e68 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -388,7 +388,7 @@ inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared("Depend"); inline const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -inline const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +inline const PrimitivePtr kPrimIdentity = std::make_shared("Identity"); inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); diff --git a/mindspore/core/ops/identity.cc b/mindspore/core/ops/identity.cc index 4bd5894409..ad6095f0ed 100644 --- a/mindspore/core/ops/identity.cc +++ b/mindspore/core/ops/identity.cc @@ -17,9 +17,9 @@ #include #include #include -#include "ops/identity.h" +#include "c_ops/identity.h" #include "utils/check_convert_utils.h" -#include "ops/op_utils.h" +#include "c_ops/op_utils.h" namespace mindspore { namespace ops { @@ -33,10 +33,10 @@ AbstractBasePtr IdentityInfer(const abstract::AnalysisEnginePtr &, const Primiti // Infer type auto x_type = input_args[0]->BuildType()->cast()->element(); std::set valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; - CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name); + CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); std::set common_valid_types_ = common_valid_types; common_valid_types_.insert({kNumberTypeBool}); - CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, common_valid_types_, prim_name); + CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), common_valid_types_, prim_name); // Infer shape auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); diff --git a/mindspore/core/ops/identity.h b/mindspore/core/ops/identity.h index 51d2e7e689..f1c38ffc18 100644 --- a/mindspore/core/ops/identity.h +++ b/mindspore/core/ops/identity.h @@ -18,7 +18,7 @@ #define MINDSPORE_CORE_C_OPS_IDENTITY_H_ #include #include -#include "ops/primitive_c.h" +#include "c_ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 7d1ba03294..4db2580808 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -13,18 +13,75 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "ops/unsorted_segment_sum.h" +#include "c_ops/unsorted_segment_sum.h" #include #include #include #include #include -#include "ops/op_utils.h" +#include "c_ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" namespace mindspore { namespace ops { +AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto unsortedsegmentsum_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(unsortedsegmentsum_prim); + auto prim_name = unsortedsegmentsum_prim->name(); + + // Infer type + auto x_type = input_args[0]->BuildType()->cast()->element(); + auto segment_ids_type = input_args[1]->BuildType()->cast()->element(); + auto num_segments_type = input_args[2]->BuildType(); + auto num_segments_v = 4; + std::set valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; + CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), valid_x_type, prim_name); + std::set valid_segment_ids_type = {TypeIdToType(kObjectTypeTensorType)}; + CheckAndConvertUtils::CheckSubClass("segment_ids", input_args[1]->BuildType(), valid_segment_ids_type, prim_name); + + // Infer shape + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name); + auto shp = x_shape; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name); + CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(), + prim_name); + + if ((x_shape.end() != find(x_shape.begin(), x_shape.end(), -1)) && + (segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) { + int64_t size = segment_ids_shape.size(); + for (int64_t i = 0; i < size; ++i) { + CheckAndConvertUtils::Check("segment_ids_shp", segment_ids_shape[i], kEqual, "x_shape", x_shape[i], prim_name); + } + } + + const std::set valid_segments_types = {TypeIdToType(kObjectTypeTensorType)}; + for (const auto &valid_segments_type : valid_segments_types) { + if (IsIdentidityOrSubclass(num_segments_type, valid_segments_type)) { + const std::set valid_num_segments_types = {kNumberTypeInt32, kNumberTypeInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types, + prim_name); + shp = {-1}; + } else { + CheckAndConvertUtils::CheckInteger("num_segments", num_segments_v, kGreaterThan, 0, prim_name); + shp = {num_segments_v}; + } + } + + int64_t size_segment_ids_shp = segment_ids_shape.size(); + int64_t size_x_shpe = x_shape.size(); + for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) { + shp.emplace_back(x_shape[i]); + } + + return std::make_shared(x_type, shp); +} +REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum, UnsortedSegmentSumInfer); REGISTER_PRIMITIVE_C(kNameUnsortedSegmentSum, UnsortedSegmentSum); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/unsorted_segment_sum.h b/mindspore/core/ops/unsorted_segment_sum.h index 89d6149846..d54de0aa85 100644 --- a/mindspore/core/ops/unsorted_segment_sum.h +++ b/mindspore/core/ops/unsorted_segment_sum.h @@ -21,7 +21,7 @@ #include #include #include -#include "ops/primitive_c.h" +#include "c_ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" @@ -37,6 +37,10 @@ class UnsortedSegmentSum : public PrimitiveC { MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC); void Init() {} }; + +AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimUnsortedSegmentSumPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/tests/ut/cpp/c_ops/test_c_ops_identity.cc b/tests/ut/cpp/c_ops/test_c_ops_identity.cc new file mode 100644 index 0000000000..6f9038536e --- /dev/null +++ b/tests/ut/cpp/c_ops/test_c_ops_identity.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "c_ops/identity.h" +#include "ir/dtype/type.h" +#include "ir/value.h" +#include "abstract/dshape.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +class TestIdentity : public UT::Common { + public: + TestIdentity() {} + void SetUp() {} + void TearDown() {} +}; + +TEST_F(TestIdentity, test_cops_identity) { + auto identity = std::make_shared(); + auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeInt64, std::vector{1, 2, 3, 4}); + MS_EXCEPTION_IF_NULL(tensor_x); + auto abstract = identity->Infer({tensor_x->ToAbstract()}); + MS_EXCEPTION_IF_NULL(abstract); + EXPECT_EQ(abstract->isa(), true); + auto shape_ptr = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + EXPECT_EQ(shape_ptr->isa(), true); + auto shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_vec = shape->shape(); + auto type = abstract->BuildType(); + MS_EXCEPTION_IF_NULL(type); + EXPECT_EQ(type->isa(), true); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto data_type = tensor_type->element(); + MS_EXCEPTION_IF_NULL(data_type); + EXPECT_EQ(data_type->type_id(), kNumberTypeInt64); + EXPECT_EQ(shape_vec.size(), 4); + EXPECT_EQ(shape_vec[0], 1); + EXPECT_EQ(shape_vec[1], 2); + EXPECT_EQ(shape_vec[2], 3); + EXPECT_EQ(shape_vec[3], 4); +} + +} // namespace ops +} // namespace mindspore diff --git a/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc new file mode 100644 index 0000000000..b5b74bf9d5 --- /dev/null +++ b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "c_ops/unsorted_segment_sum.h" +#include "ir/dtype/type.h" +#include "ir/value.h" +#include "abstract/dshape.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +class TestUnsortedSegmentSum : public UT::Common { + public: + TestUnsortedSegmentSum() {} + void SetUp() {} + void TearDown() {} +}; + +TEST_F(TestUnsortedSegmentSum, test_cops_unsortedsegmentsum) { + auto unsortedsegmentsum = std::make_shared(); + auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector{1, 4}); + auto segment_ids = TensorConstructUtils::CreateOnesTensor(kNumberTypeInt32, std::vector{1, 4}); + auto num_segments = MakeValue(4); + MS_EXCEPTION_IF_NULL(tensor_x); + MS_EXCEPTION_IF_NULL(segment_ids); + MS_EXCEPTION_IF_NULL(num_segments); + auto abstract = unsortedsegmentsum->Infer({tensor_x->ToAbstract(), segment_ids->ToAbstract(), num_segments->ToAbstract()}); + MS_EXCEPTION_IF_NULL(abstract); + EXPECT_EQ(abstract->isa(), true); + auto shape_ptr = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + EXPECT_EQ(shape_ptr->isa(), true); + auto shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_vec = shape->shape(); + auto type = abstract->BuildType(); + MS_EXCEPTION_IF_NULL(type); + EXPECT_EQ(type->isa(), true); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto data_type = tensor_type->element(); + MS_EXCEPTION_IF_NULL(data_type); + EXPECT_EQ(data_type->type_id(), kNumberTypeFloat32); + EXPECT_EQ(shape_vec.size(), 1); + EXPECT_EQ(shape_vec[0], 4); + +} + +} // namespace ops +} // namespace mindspore -- Gitee From 644d8e5001ad17c52dbc3335fa39c7bfbd1ba5df Mon Sep 17 00:00:00 2001 From: dingpeifei Date: Tue, 19 Jan 2021 11:11:26 +0800 Subject: [PATCH 2/2] add unsortedsegmentsum and identity ut --- mindspore/core/ops/identity.cc | 4 ++-- mindspore/core/ops/identity.h | 2 +- mindspore/core/ops/unsorted_segment_sum.cc | 4 ++-- mindspore/core/ops/unsorted_segment_sum.h | 2 +- tests/ut/cpp/c_ops/test_c_ops_identity.cc | 2 +- tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mindspore/core/ops/identity.cc b/mindspore/core/ops/identity.cc index ad6095f0ed..c43e0f7f04 100644 --- a/mindspore/core/ops/identity.cc +++ b/mindspore/core/ops/identity.cc @@ -17,9 +17,9 @@ #include #include #include -#include "c_ops/identity.h" +#include "ops/identity.h" #include "utils/check_convert_utils.h" -#include "c_ops/op_utils.h" +#include "ops/op_utils.h" namespace mindspore { namespace ops { diff --git a/mindspore/core/ops/identity.h b/mindspore/core/ops/identity.h index f1c38ffc18..51d2e7e689 100644 --- a/mindspore/core/ops/identity.h +++ b/mindspore/core/ops/identity.h @@ -18,7 +18,7 @@ #define MINDSPORE_CORE_C_OPS_IDENTITY_H_ #include #include -#include "c_ops/primitive_c.h" +#include "ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 4db2580808..3cfaf3093d 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "c_ops/unsorted_segment_sum.h" +#include "ops/unsorted_segment_sum.h" #include #include #include #include #include -#include "c_ops/op_utils.h" +#include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" diff --git a/mindspore/core/ops/unsorted_segment_sum.h b/mindspore/core/ops/unsorted_segment_sum.h index d54de0aa85..7e47c3197c 100644 --- a/mindspore/core/ops/unsorted_segment_sum.h +++ b/mindspore/core/ops/unsorted_segment_sum.h @@ -21,7 +21,7 @@ #include #include #include -#include "c_ops/primitive_c.h" +#include "ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" diff --git a/tests/ut/cpp/c_ops/test_c_ops_identity.cc b/tests/ut/cpp/c_ops/test_c_ops_identity.cc index 6f9038536e..b189502f8b 100644 --- a/tests/ut/cpp/c_ops/test_c_ops_identity.cc +++ b/tests/ut/cpp/c_ops/test_c_ops_identity.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "c_ops/identity.h" +#include "ops/identity.h" #include "ir/dtype/type.h" #include "ir/value.h" #include "abstract/dshape.h" diff --git a/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc index b5b74bf9d5..a4e0c7827a 100644 --- a/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc +++ b/tests/ut/cpp/c_ops/test_c_ops_unsorted_segment_sum.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "c_ops/unsorted_segment_sum.h" +#include "ops/unsorted_segment_sum.h" #include "ir/dtype/type.h" #include "ir/value.h" #include "abstract/dshape.h" -- Gitee