免费发布信息网站大全有哪些,o2o与网站建设,asp做素材网站,室内设计联盟免费下载文章目录 前言一、原始程序---计算原型#xff0c;开始训练#xff0c;计算损失二、每一行代码的详细解释2.1 粗略分析2.2 每一行代码详细分析 前言
承接系列4#xff0c;此部分属于原型类中的计算原型#xff0c;开始训练#xff0c;计算损失函数。 一、原始程序—计算原… 文章目录 前言一、原始程序---计算原型开始训练计算损失二、每一行代码的详细解释2.1 粗略分析2.2 每一行代码详细分析 前言
承接系列4此部分属于原型类中的计算原型开始训练计算损失函数。 一、原始程序—计算原型开始训练计算损失
def compute_center(self,data_set): #data_set是一个numpy对象是某一个支持集计算支持集对应的中心的点center 0for i in range(self.Ns):data np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])data Variable(torch.from_numpy(data))data self.model(data)[0] #将查询点嵌入另一个空间if i 0:center dataelse:center datacenter / self.Nsreturn centerdef train(self,labels_data,class_number): #网络的训练#Select class indices for episodeclass_index list(range(class_number))random.shuffle(class_index)choss_class_index class_index[:self.Nc]#选20个类sample {xc:[],xq:[]}for label in choss_class_index:D_set labels_data[label]#从D_set随机取支持集和查询集support_set,query_set self.randomSample(D_set)#计算中心点self.center[label] self.compute_center(support_set)#将中心和查询集存储在list中sample[xc].append(self.center[label]) #listsample[xq].append(query_set)#优化器optimizer torch.optim.Adam(self.model.parameters(),lr0.001)optimizer.zero_grad()protonets_loss self.loss(sample)protonets_loss.backward()optimizer.step()def loss(self,sample): #自定义lossloss_1 autograd.Variable(torch.FloatTensor([0]))for i in range(self.Nc):query_dataSet sample[xq][i]for n in range(self.Nq):data np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])data Variable(torch.from_numpy(data))data self.model(data)[0] #将查询点嵌入另一个空间#查询点与每个中心点逐个计算欧氏距离predict 0for j in range(self.Nc):center_j sample[xc][j]if j 0:predict eucli_tensor(data,center_j)else:predict torch.cat((predict, eucli_tensor(data,center_j)), 0)#为loss叠加loss_1 -1*F.log_softmax(predict,dim0)[i]loss_1 / self.Nq*self.Ncreturn loss_1二、每一行代码的详细解释
2.1 粗略分析
第一个函数 compute_center(self,data_set) 用于计算支持集中心点的坐标。输入参数 data_set 是一个 numpy 对象代表支持集。该函数中用了一个 for 循环遍历了每一个支持集中的样本将其嵌入到另一个空间并计算其总和来求得所有样本的中心点。最后返回计算出的中心点的坐标。
第二个函数 train(self,labels_data,class_number) 是网络的训练函数。其中 labels_data 是标签数据class_number 是类别数。首先从 class_number 中随机选取出 Nc 个类对于每个选出来的类从其标签数据 D_set 中随机选取出支持集和查询集并将支持集传入 compute_center() 函数计算中心点。接着将计算出的中心点和查询集存储在样本字典 sample 中。最后使用 Adam 优化器对模型进行优化并计算损失调用了 loss 函数将反向传播得到的梯度更新到模型中。
第三个函数def loss(self,sample)是一个自定义的损失函数它的作用是计算样本的损失值。在这个损失函数中使用了欧氏距离和softmax函数。
2.2 每一行代码详细分析
def compute_center(self,data_set): - 这是一个方法用于计算给定数据集支持集的中心点。
2-4. center 0 - 初始化中心点的变量为0。
5-8. for i in range(self.Ns): - 遍历数据集中的每个数据点。
9-14. 这部分代码将数据集中的每个数据点重塑为适应模型输入的形状并将其转换为PyTorch的Variable。然后使用模型将查询点嵌入另一个空间。
if i 0: - 如果这是第一个数据点则将查询点设置为中心点。
16-19. 否则将查询点添加到中心点。
center / self.Ns - 计算中心点这是所有数据点的平均值。
return center - 返回计算得到的中心点。
接下来是 train 方法
23-24. 从给定的标签数据中选择类别索引并随机洗牌。选择特定数量的类别self.Nc。
25-30. 对于所选类别中的每一个从其数据中随机选择支持集和查询集。
31-33. 使用 compute_center 方法计算每个类的中心点并将其存储在列表中。同时将查询集也存储在列表中。
34-37. 初始化优化器这里使用Adam优化算法学习率设置为0.001。然后清空梯度缓存。
38-42. 计算损失函数值该损失函数是根据自定义的损失函数计算的。然后进行反向传播以计算梯度。
optimizer.step() - 使用优化器更新模型的参数。
最后是自定义的损失函数 loss
45-46. 初始化一个张量 loss_1 为0它用于累计损失值。
47-52. 对于每个类别self.Nc遍历查询集中的每个数据点。对于每个查询点将其嵌入到另一个空间中并计算它与每个中心点之间的欧氏距离。
53-57. 将所有的距离组合在一起并使用softmax函数将其转换为概率值。然后对于每个查询点累加其与所有中心点的负对数似然损失值。
loss_1 / self.Nq*self.Nc - 将损失值除以查询集中的数据点数量和类别数量以获得平均损失值。
return loss_1 - 返回计算得到的损失值。