3 Star 10 Fork 6

zhanghao/mmradar_hand_gesture_recognize

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
SVM.py 3.21 KB
一键复制 编辑 原始数据 按行查看 历史
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 24 19:59:16 2019
读取数据,并进行svm训练
@author: EDMOND
"""
import numpy as np
from sklearn import svm
from scipy import signal
import pickle
paths = ['left', 'push&pull', 'right', 'zhua']
# 获取数据,并保存在4类矩阵中
left_sample = np.zeros((600, 448))
right_sample = np.zeros((600, 448))
push_pull_sample = np.zeros((600, 448))
catch_sample = np.zeros((600, 448))
b, a = signal.butter(8, [0.05,0.4], 'bandpass')
# 读取根目录下4种信号 把信号滤波 STFT ,将信号STFT的值放到四类样本数组中
for i in range(600):
for path in paths:
if path == 'left':
left_data = np.loadtxt(path + '/' + str(i) + '.txt');
data_filted = signal.filtfilt(b, a, left_data) # data为要过滤的信号
f, t, Zxx = signal.stft(data_filted, nperseg=30)
Zxx=abs(Zxx)
left_sample[i, :] = Zxx.reshape(1,448)
elif path == 'right':
right_data = np.loadtxt(path + '/' + str(i) + '.txt')
data_filted = signal.filtfilt(b, a, right_data) # data为要过滤的信号
f, t, Zxx = signal.stft(data_filted, nperseg=30)
Zxx=abs(Zxx)
right_sample[i, :] = Zxx.reshape(1,448)
elif path == 'push&pull':
push_data = np.loadtxt(path + '/' + str(i) + '.txt')
data_filted = signal.filtfilt(b, a, push_data) # data为要过滤的信号
f, t, Zxx = signal.stft(data_filted, nperseg=30)
Zxx=abs(Zxx)
push_pull_sample[i, :] = Zxx.reshape(1,448)
elif path == 'zhua':
zhua_data = np.loadtxt(path + '/' + str(i) + '.txt')
data_filted = signal.filtfilt(b, a, zhua_data) # data为要过滤的信号
f, t, Zxx = signal.stft(data_filted, nperseg=30)
Zxx=abs(Zxx)
catch_sample[i, :] = Zxx.reshape(1,448)
# 将训练集合并,作为数据矩阵
sample = np.vstack((left_sample, right_sample, push_pull_sample, catch_sample))
#标签
label = np.vstack((np.ones((600, 1)), 2 * np.ones((600, 1)), 3 * np.ones((600, 1)), 4 * np.ones((600, 1))))
# 支持向量机,ovo的形式(每两类组成一个向量机),多项式形式核函数,多项式度为2,其他取默认值
CLF = svm.SVC(decision_function_shape='ovr', max_iter=-1, cache_size=12000, kernel='poly', degree=3,C=10,coef0=10,gamma='auto');
# 划分训练集测试集
index = np.array([range(0, 100), range(600, 700), range(1200, 1300), range(1800, 1900)]).reshape(400, )
test_sample = sample[index, :]#测试集 数据
test_label = label[index, :]#测试集 标签
train_sample = np.delete(sample, index, axis=0) #训练集 数据
train_label = np.delete(label, index, axis=0)#训练集 标签
# 训练svm
CLF.fit(train_sample, train_label.reshape((2000,)))
with open('C:/Users/姚奕成/OneDrive/桌面/传感技术作业/data/clf.pickle','wb')as f: #python路径要用反斜杠
pickle.dump(CLF,f) #将模型dump进f里面
# 使用测试集做预测
pred = CLF.predict(test_sample)
# 和测试集标签比较,得出准确率
accr = np.mean(pred == test_label.reshape(400, ))
print(accr )
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/haobaobao/mmradar_hand_gesture_recognize.git
[email protected]:haobaobao/mmradar_hand_gesture_recognize.git
haobaobao
mmradar_hand_gesture_recognize
mmradar_hand_gesture_recognize
master

搜索帮助