1 Star 0 Fork 0

孙雷鸣/single_mult

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_and_save_models.py 2.06 KB
一键复制 编辑 原始数据 按行查看 历史
孙雷鸣 提交于 2024-04-10 05:58 . 1
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
X = torch.tensor([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
y = torch.tensor([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0])
class Perceptron(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x):
y = self.fc(x)
return y
class NetModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1, 2)
self.fc2 = nn.Linear(2, 2)
self.fc3 = nn.Linear(2, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
y = self.fc3(x)
return y
model1 = Perceptron()
model2 = NetModel()
loss_fn = nn.MSELoss()
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.01)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.01)
# 训练单层感知器
print("saved preception_model to preception_model.pth:")
for epoch in range(500):
y_pred = model1(X.unsqueeze(1))
loss = loss_fn(y_pred, y.unsqueeze(1))
optimizer1.zero_grad()
loss.backward()
optimizer1.step()
if (epoch+1) % 50 == 0:
print(f'Epoch [{epoch+1}/500], Loss: {loss.item():.4f}')
perceptron_path = 'perceptron.pth'
torch.save(model1.state_dict(), perceptron_path)
# 训练多层神经网络
print("saved net_model to net_model")
for epoch in range(500):
y_pred = model2(X.unsqueeze(1))
loss = loss_fn(y_pred, y.unsqueeze(1))
optimizer2.zero_grad()
loss.backward()
optimizer2.step()
if (epoch+1) % 50 == 0:
print(f'Epoch [{epoch+1}/500], Loss: {loss.item():.4f}')
net_model_path = 'net_model.pth'
torch.save(model2.state_dict(), net_model_path)
# 保存模型参数
# perceptron_path = 'perceptron.pth'
# net_model_path = 'net_model.pth'
# torch.save(model1.state_dict(), perceptron_path)
# torch.save(model2.state_dict(), net_model_path)
print(f"模型参数已保存到: {perceptron_path}{net_model_path}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/sun-leiming/single_mult.git
[email protected]:sun-leiming/single_mult.git
sun-leiming
single_mult
single_mult
master

搜索帮助