PyTorch Tutorial

PyTorch: Training a Convnet from Scratch

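In this chapter, we focus on building a convnet from scratch, which here means using torch directly to create a small sample neural network. Note that the network built in these steps is fully connected, with a single hidden layer, rather than convolutional.
Step 1
Create the necessary class with its parameters, including the weights, which are initialized with random values.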
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-23
import torch
import torch.nn as nn

class Neural_Network(nn.Module):
   def __init__(self):
      super(Neural_Network, self).__init__()
      self.inputSize = 2
      self.outputSize = 1
      self.hiddenSize = 3
      # weights, initialized with random values
      self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 2 X 3 tensor
      self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor
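To make the weight shapes concrete, here is a minimal sketch that pushes a small batch through two matrices laid out like W1 and W2 in the class above. The batch X with three samples is a made-up example and is not part of the original listing.

```python
import torch

torch.manual_seed(0)  # only to make the random weights repeatable

input_size, hidden_size, output_size = 2, 3, 1

# Same layout as W1 and W2 in the class above
W1 = torch.randn(input_size, hidden_size)   # 2 x 3
W2 = torch.randn(hidden_size, output_size)  # 3 x 1

# Hypothetical batch: 3 samples with 2 features each
X = torch.rand(3, input_size)

hidden = torch.matmul(X, W1)       # (3 x 2) @ (2 x 3) -> 3 x 3
output = torch.matmul(hidden, W2)  # (3 x 3) @ (3 x 1) -> 3 x 1

print(W1.shape, W2.shape)
print(hidden.shape, output.shape)
```

The inner dimensions have to match at each multiplication, which is why W1 must be 2 x 3 (inputs by hidden units) and not the other way round.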
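Step 2
Create the forward pass using the sigmoid activation function, together with the backward pass that updates the weights.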
   # methods of the Neural_Network class
   def forward(self, X):
      # result is 3 X 3 for a batch of 3 samples; ".dot" does not broadcast in PyTorch
      self.z = torch.matmul(X, self.W1)
      self.z2 = self.sigmoid(self.z) # activation function
      self.z3 = torch.matmul(self.z2, self.W2)
      o = self.sigmoid(self.z3) # final activation function
      return o
   def sigmoid(self, s):
      return 1 / (1 + torch.exp(-s))
   def sigmoidPrime(self, s):
      # derivative of sigmoid, where s is already a sigmoid output
      return s * (1 - s)
   def backward(self, X, y, o):
      self.o_error = y - o # error in output
      self.o_delta = self.o_error * self.sigmoidPrime(o) # error scaled by the sigmoid derivative
      self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))
      self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
      # gradient-descent update with an implicit learning rate of 1
      self.W1 += torch.matmul(torch.t(X), self.z2_delta)
      self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)
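The sigmoidPrime method expects a value that has already passed through the sigmoid, which is why backward calls it on o and z2 rather than on the raw pre-activations. The following sketch checks this against PyTorch's autograd; the standalone function names and the test points in z are illustrative choices, not part of the original listing.

```python
import torch

def sigmoid(s):
    return 1 / (1 + torch.exp(-s))

def sigmoid_prime(s):
    # s is expected to be a sigmoid output, not the raw pre-activation
    return s * (1 - s)

# Arbitrary pre-activation values chosen for the check
z = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)
a = sigmoid(z)

# The gradient of sum(sigmoid(z)) with respect to z is the
# elementwise derivative of the sigmoid at each point
a.sum().backward()

manual = sigmoid_prime(a.detach())
print(torch.allclose(z.grad, manual, atol=1e-6))
```

Passing the raw pre-activation to sigmoid_prime instead would give a different and incorrect value, so the argument convention matters.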
Step 3
Create the training and prediction methods as shown below.
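   # methods of the Neural_Network class (continued)
   def train(self, X, y):
      # forward + backward pass for training
      # note: this name shadows nn.Module.train(mode), so model.eval() will fail on this class
      o = self.forward(X)
      self.backward(X, y, o)
   def saveWeights(self, model):
      # save the whole model object with PyTorch's built-in serialization
      torch.save(model, "NN")
      # you can reload the model with all the weights and so forth with:
      # torch.load("NN")
   def predict(self):
      # xPredicted must be defined as a tensor at module level before calling predict()
      print("Predicted data based on trained weights: ")
      print("Input (scaled): " + str(xPredicted))
      print("Output: " + str(self.forward(xPredicted)))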
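The three steps can be exercised together roughly as follows. This is a sketch under stated assumptions: the class is restated in compact form so the block runs on its own, and the training data, the scaling, the iteration count, and the mean squared error used for monitoring are all hypothetical choices that do not appear in the listings above.

```python
import torch
import torch.nn as nn

class Neural_Network(nn.Module):
    # Compact restatement of the class built in steps 1 to 3
    def __init__(self):
        super().__init__()
        self.W1 = torch.randn(2, 3)  # input -> hidden
        self.W2 = torch.randn(3, 1)  # hidden -> output

    def sigmoid(self, s):
        return 1 / (1 + torch.exp(-s))

    def sigmoidPrime(self, s):
        return s * (1 - s)

    def forward(self, X):
        self.z2 = self.sigmoid(torch.matmul(X, self.W1))
        return self.sigmoid(torch.matmul(self.z2, self.W2))

    def backward(self, X, y, o):
        o_delta = (y - o) * self.sigmoidPrime(o)
        z2_error = torch.matmul(o_delta, torch.t(self.W2))
        z2_delta = z2_error * self.sigmoidPrime(self.z2)
        self.W1 += torch.matmul(torch.t(X), z2_delta)
        self.W2 += torch.matmul(torch.t(self.z2), o_delta)

    def train(self, X, y):
        self.backward(X, y, self.forward(X))

torch.manual_seed(0)  # repeatable random weights

# Hypothetical data: 3 samples, 2 features and 1 target each
X = torch.tensor([[2.0, 9.0], [1.0, 5.0], [3.0, 6.0]])
y = torch.tensor([[92.0], [100.0], [89.0]])
xPredicted = torch.tensor([4.0, 8.0])

# Scale everything into [0, 1] so the sigmoid output can reach the targets
X_max, _ = torch.max(X, 0)
X = X / X_max
xPredicted = xPredicted / torch.max(xPredicted)
y = y / 100

NN = Neural_Network()
losses = []
for i in range(1000):  # assumed iteration count
    # Mean squared error, recorded before each update for monitoring only
    losses.append(torch.mean((y - NN.forward(X)) ** 2).item())
    NN.train(X, y)

print("first loss:", losses[0])
print("last loss:", losses[-1])
print("prediction:", NN.forward(xPredicted))
```

Comparing the first and last recorded loss gives a quick indication of whether the manual weight updates are moving the network toward the targets.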