代码拉取完成,页面将自动刷新
package main
import (
"encoding/csv"
"fmt"
"io"
"math"
"os"
"reflect"
"sort"
"strings"
)
// main loads the data set, builds a decision tree from the training rows,
// prints the tree, and then prints the classification accuracy measured on
// the test rows.
func main() {
	trainDataSet, testDataSet, features := loadDataSet(810)
	// loadDataSet signals failure by returning nil slices. Building a tree
	// from an empty training set would panic, so stop with a clear message.
	if len(trainDataSet) == 0 {
		fmt.Fprintln(os.Stderr, "Error: no training data loaded")
		os.Exit(1)
	}

	var remainLabels []string
	tree := createTree(trainDataSet, features, remainLabels)
	fmt.Println(tree)

	total := 0
	correctNum := 0
	for _, row := range testDataSet {
		// The last column is the true class; everything before it is the
		// feature vector handed to the classifier.
		last := len(row) - 1
		predicted := classify(tree, features, row[:last])
		if strings.Compare(predicted, row[last]) == 0 {
			correctNum++
		}
		total++
	}

	// Without test rows the ratio below would be 0/0 and print NaN.
	if total == 0 {
		fmt.Fprintln(os.Stderr, "Error: test set is empty, accuracy is undefined")
		return
	}
	rate := float64(correctNum) / float64(total) * 100
	fmt.Printf("测试集正确率:%.2f%%\n", rate)
}
// loadDataSet reads "titanic.csv" from the working directory and splits its
// data rows into a training set and a test set.
//
// Every returned row has three columns, in this order: CSV column 4, CSV
// column 11 (an empty value is replaced by "S"), and CSV column 1 as the
// class label. The third result holds the header names of the two feature
// columns (columns 4 and 11).
// NOTE(review): in the usual Kaggle Titanic layout these columns are Sex,
// Embarked and Survived — confirm against the actual file.
//
// Data rows whose zero-based position is <= trainScale go to the training
// set (so it receives up to trainScale+1 rows); all later rows go to the
// test set.
//
// On any failure (file cannot be opened, header missing or too short, CSV
// parse error) a message is printed and all three results are nil.
func loadDataSet(trainScale int) ([][]string, [][]string, []string) {
	const (
		labelCol    = 1  // class label, stored last in each output row
		featureColA = 4  // first feature column
		featureColB = 11 // second feature column; blank values become "S"
	)

	file, err := os.Open("titanic.csv")
	if err != nil {
		fmt.Println("Error:", err)
		return nil, nil, nil
	}
	defer func() {
		if cerr := file.Close(); cerr != nil {
			fmt.Println("Error:", cerr)
		}
	}()

	reader := csv.NewReader(file)

	// The header supplies the feature names. An empty file yields io.EOF here.
	header, err := reader.Read()
	if err != nil {
		fmt.Println("Error:", err)
		return nil, nil, nil
	}
	if len(header) <= featureColB {
		fmt.Printf("Error: header has %d columns, need at least %d\n", len(header), featureColB+1)
		return nil, nil, nil
	}
	features := []string{header[featureColA], header[featureColB]}

	var trainDataSet [][]string
	var testDataSet [][]string
	for curr := 0; ; curr++ {
		// With the default FieldsPerRecord, csv.Reader rejects any record whose
		// field count differs from the header's, so the indexes below are in
		// range once the header check above has passed.
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			fmt.Println("Error:", err)
			return nil, nil, nil
		}

		second := record[featureColB]
		if second == "" {
			second = "S"
		}
		row := []string{record[featureColA], second, record[labelCol]}

		if curr <= trainScale {
			trainDataSet = append(trainDataSet, row)
		} else {
			testDataSet = append(testDataSet, row)
		}
	}
	return trainDataSet, testDataSet, features
}
// calcEnt returns the empirical entropy, in bits, of the class labels in
// data. The class label of a row is its last element. An empty data set has
// entropy 0.
func calcEnt(data [][]string) float64 {
	// Tally how often each class label occurs.
	freq := make(map[string]int)
	for _, row := range data {
		freq[row[len(row)-1]]++
	}

	rows := float64(len(data))
	entropy := 0.0
	for _, occurrences := range freq {
		p := float64(occurrences) / rows
		entropy -= p * math.Log2(p)
	}
	return entropy
}
// splitDataSet selects the rows of dataSet whose column axis equals value
// and returns copies of them with that column removed. The input rows are
// never modified. The result is nil when no row matches.
func splitDataSet(dataSet [][]string, axis int, value string) [][]string {
	var matched [][]string
	for _, row := range dataSet {
		if strings.Compare(row[axis], value) != 0 {
			continue
		}
		// Assemble the reduced row in fresh storage so the caller's row is
		// left untouched.
		reduced := make([]string, 0, len(row)-1)
		reduced = append(reduced, row[:axis]...)
		reduced = append(reduced, row[axis+1:]...)
		matched = append(matched, reduced)
	}
	return matched
}
// chooseBestFeature returns the index of the feature column with the largest
// information gain over dataSet. Every column except the last (the class
// label) is treated as a feature.
//
// It returns -1 when dataSet is empty or when no feature has a strictly
// positive gain. On equal gains the lowest index wins. The gain of every
// feature is printed as a side effect.
func chooseBestFeature(dataSet [][]string) int {
	// Nothing to evaluate; dataSet[0] below would panic.
	if len(dataSet) == 0 {
		return -1
	}

	featureNum := len(dataSet[0]) - 1
	baseEntropy := calcEnt(dataSet)
	rows := float64(len(dataSet))

	bestInfoGain := 0.0
	bestFeatureIdx := -1
	for i := 0; i < featureNum; i++ {
		// Collect the values of column i so its distinct values can be listed.
		column := make([]string, 0, len(dataSet))
		for _, row := range dataSet {
			column = append(column, row[i])
		}

		// Conditional entropy: entropy of each value's subset, weighted by
		// the share of rows that fall into it.
		newEntropy := 0.0
		for _, v := range distinct(column) {
			subDataSet := splitDataSet(dataSet, i, v.(string))
			prob := float64(len(subDataSet)) / rows
			newEntropy += prob * calcEnt(subDataSet)
		}

		infoGain := baseEntropy - newEntropy
		fmt.Printf("特征%d的增益为%.3f\n", i, infoGain)
		if infoGain > bestInfoGain {
			bestInfoGain = infoGain
			bestFeatureIdx = i
		}
	}
	return bestFeatureIdx
}
// vote returns the value that occurs most often in classList.
//
// When several values share the highest count the lexicographically smallest
// one is returned, so the result does not depend on map iteration order. An
// empty classList yields "".
func vote(classList []string) string {
	if len(classList) == 0 {
		return ""
	}

	counts := make(map[string]int)
	for _, label := range classList {
		counts[label]++
	}

	winner := ""
	winnerCount := 0
	for label, n := range counts {
		// Every stored count is >= 1, so the first entry always replaces the
		// initial placeholder; after that, ties go to the smaller label.
		if n > winnerCount || (n == winnerCount && label < winner) {
			winner, winnerCount = label, n
		}
	}
	return winner
}
// createTree recursively builds a decision tree from dataSet, whose rows hold
// feature values followed by the class label in the last column. labels[i]
// names feature column i.
//
// The tree is a nested map:
//   - a leaf is {classLabel: nil};
//   - an internal node is {featureName: {featureValue: subtree, ...}}.
//
// Recursion stops with a leaf when all rows share one class, when no feature
// columns remain, or when no remaining feature has positive information gain
// (the last two cases use the majority class). An empty dataSet yields an
// empty map.
//
// remainFeatures accumulates the chosen feature names along the current path
// and is passed down the recursion; it is not read by this function.
func createTree(dataSet [][]string, labels []string, remainFeatures []string) map[string]interface{} {
	if len(dataSet) == 0 {
		return map[string]interface{}{}
	}

	// Class label of every row.
	classList := make([]string, 0, len(dataSet))
	for _, row := range dataSet {
		classList = append(classList, row[len(row)-1])
	}

	// All rows agree on the class: nothing left to split.
	if count(classList, classList[0]) == len(classList) {
		return map[string]interface{}{classList[0]: nil}
	}
	// Only the label column remains: fall back to the majority class.
	if len(dataSet[0]) == 1 {
		return map[string]interface{}{vote(classList): nil}
	}

	bestFeatIdx := chooseBestFeature(dataSet)
	// chooseBestFeature returns -1 when no feature separates the classes;
	// indexing labels with it would panic, so emit a majority leaf instead.
	if bestFeatIdx < 0 || bestFeatIdx >= len(labels) {
		return map[string]interface{}{vote(classList): nil}
	}
	bestFeatLabel := labels[bestFeatIdx]
	remainFeatures = append(remainFeatures, bestFeatLabel)

	// Labels for the sub-trees: the same list without the chosen feature,
	// built in fresh storage so the caller's slice is untouched.
	subLabels := make([]string, 0, len(labels)-1)
	subLabels = append(subLabels, labels[:bestFeatIdx]...)
	subLabels = append(subLabels, labels[bestFeatIdx+1:]...)

	// Values the chosen feature takes in this data set.
	featValues := make([]string, 0, len(dataSet))
	for _, row := range dataSet {
		featValues = append(featValues, row[bestFeatIdx])
	}

	// One branch per distinct value of the chosen feature.
	branches := make(map[string]interface{})
	for _, v := range distinct(featValues) {
		value := v.(string)
		subset := splitDataSet(dataSet, bestFeatIdx, value)
		branches[value] = createTree(subset, subLabels, remainFeatures)
	}
	return map[string]interface{}{bestFeatLabel: branches}
}
// classify walks tree for the feature vector testVec and returns the class
// label it reaches. features[i] names the value stored in testVec[i].
//
// tree uses the nested-map layout {featureName: {featureValue: subtree}},
// with a leaf written as {classLabel: nil}.
//
// It returns "" when no prediction can be made: the tree is empty or
// malformed, the node's feature name is not in features, testVec is too
// short, or the tree has no branch for the observed feature value.
func classify(tree map[string]interface{}, features []string, testVec []string) string {
	// A node whose value is nil is a leaf; its key is the class label.
	var firstStr string
	for k, v := range tree {
		if v == nil {
			return k
		}
		firstStr = k
	}

	// Checked assertion: an empty or malformed tree yields "" instead of a panic.
	root, ok := tree[firstStr].(map[string]interface{})
	if !ok {
		return ""
	}

	// Locate the feature this node tests. An unknown name must not silently
	// fall back to column 0.
	featIdx := -1
	for i, name := range features {
		if name == firstStr {
			featIdx = i
			break
		}
	}
	if featIdx < 0 || featIdx >= len(testVec) {
		return ""
	}

	value := testVec[featIdx]
	child, found := root[value]
	if !found {
		// Feature value not covered by this tree.
		return ""
	}
	if child == nil {
		// Kept from the earlier behaviour: a nil branch returns the feature
		// value itself.
		return value
	}
	subtree, ok := child.(map[string]interface{})
	if !ok {
		return ""
	}
	return classify(subtree, features, testVec)
}
// index returns the position of the first element of target equal to value.
// When value is absent it returns 0, which callers cannot tell apart from a
// match at position 0.
func index(target []string, value string) int {
	for pos := 0; pos < len(target); pos++ {
		if strings.Compare(target[pos], value) != 0 {
			continue
		}
		return pos
	}
	return 0
}
// count reports how many elements of target are equal to value.
func count(target []string, value string) int {
	matches := 0
	for i := range target {
		if strings.Compare(target[i], value) != 0 {
			continue
		}
		matches++
	}
	return matches
}
// duplicate returns the elements of a with runs of consecutive equal
// elements collapsed to one, compared with reflect.DeepEqual. On sorted
// input this yields the distinct elements.
//
// a must be a slice, array or string. Any other value, including nil,
// yields a nil result instead of a reflection panic. An empty input also
// yields nil.
func duplicate(a interface{}) (ret []interface{}) {
	va := reflect.ValueOf(a)
	switch va.Kind() {
	case reflect.Slice, reflect.Array, reflect.String:
		// Both Len and Index are defined for these kinds.
	default:
		return nil
	}

	var prev interface{}
	for i := 0; i < va.Len(); i++ {
		cur := va.Index(i).Interface()
		// Keep the element unless it repeats its immediate predecessor.
		if i == 0 || !reflect.DeepEqual(prev, cur) {
			ret = append(ret, cur)
		}
		prev = cur
	}
	return ret
}
// distinct returns the unique strings of val in ascending order, each boxed
// as an interface{} holding a string. val itself is left unmodified. An empty
// input yields nil.
func distinct(val []string) []interface{} {
	// Sort a private copy: sorting val directly would reorder the caller's
	// slice as a side effect.
	sorted := make([]string, len(val))
	copy(sorted, val)
	sort.Strings(sorted)

	var unique []interface{}
	for i, s := range sorted {
		// After sorting, equal values are adjacent; keep the first of each run.
		if i > 0 && s == sorted[i-1] {
			continue
		}
		unique = append(unique, s)
	}
	return unique
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。