for iteration, batch in enumerate(tqdm_dataloader): user_idx, pos_idx, neg_idx = batch rep, user_pool = self.model(graph) user = rep[user_idx] + user_pool[user_idx] pos = rep[self.model.n_user + pos_idx] neg = rep[self.model.n_user + neg_idx] pos_preds = self.model.predict(user, pos) neg_preds = self.model.predict(user, neg) loss, losses = self.criterion(pos_preds, neg_preds, user, pos, neg) self.optimizer.zero_grad() loss.backward() self.optimizer.step() epoch_losses = [x + y for x, y in zip(epoch_losses, losses)] tqdm_dataloader.set_description('Epoch {}, loss: {:.4f}'.format(self.epoch, loss.item())) if self.scheduler is not None: self.scheduler.step() epoch_losses = [sum(epoch_losses)] + epoch_losses return epoch_losses

时间: 2024-04-20 09:22:52 浏览: 153
在这段代码中,是一个训练过程中的一个 epoch 的逻辑。 首先,使用 `enumerate` 函数迭代 `tqdm_dataloader`,返回迭代次数和每个批次的数据。 然后,从 `batch` 中解包得到 `user_idx`、`pos_idx` 和 `neg_idx`,表示当前批次中的用户、正样本和负样本的索引。 接下来,通过调用 `self.model(graph)` 方法,传入 `graph` 对象,获取模型的表示向量 `rep` 和用户池(user pool)的表示向量 `user_pool`。 然后,根据索引从 `rep` 中取出对应的用户向量 `user`、正样本向量 `pos` 和负样本向量 `neg`。 通过调用 `self.model.predict(user, pos)` 和 `self.model.predict(user, neg)`,分别得到正样本和负样本的预测结果 `pos_preds` 和 `neg_preds`。 接下来,调用 `self.criterion(pos_preds, neg_preds, user, pos, neg)` 方法,传入预测结果和实际数据,计算损失值和各个损失函数的值。 然后,调用 `self.optimizer.zero_grad()` 方法将模型参数的梯度置零,以便进行下一次反向传播。 接着,调用 `loss.backward()` 方法进行反向传播计算梯度。 然后,调用 `self.optimizer.step()` 方法更新模型参数。 通过迭代更新 `epoch_losses` 列表,将当前批次的损失值累加到 `epoch_losses` 中。 在 `tqdm_dataloader` 的描述信息中显示当前 epoch 的编号和损失值。 如果存在学习率调度器(scheduler),则调用 `self.scheduler.step()` 方法更新学习率。 最后,将 `epoch_losses` 列表中的各个损失值相加,并将总和作为第一个元素添加到 `epoch_losses` 列表的开头。 最后,返回 `epoch_losses` 列表,它包含了当前 epoch 中各个损失函数的累计损失值。
阅读全文

相关推荐

param = {'num_leaves': 31, 'min_data_in_leaf': 20, 'objective': 'binary', 'learning_rate': 0.06, "boosting": "gbdt", "metric": 'None', "verbosity": -1} trn_data = lgb.Dataset(trn, trn_label) val_data = lgb.Dataset(val, val_label) num_round = 666 # clf = lgb.train(param, trn_data, num_round, valid_sets=[trn_data, val_data], verbose_eval=100, # early_stopping_rounds=300, feval=win_score_eval) clf = lgb.train(param, trn_data, num_round) # oof_lgb = clf.predict(val, num_iteration=clf.best_iteration) test_lgb = clf.predict(test, num_iteration=clf.best_iteration)thresh_hold = 0.5 oof_test_final = test_lgb >= thresh_hold print(metrics.accuracy_score(test_label, oof_test_final)) print(metrics.confusion_matrix(test_label, oof_test_final)) tp = np.sum(((oof_test_final == 1) & (test_label == 1))) pp = np.sum(oof_test_final == 1) print('accuracy1:%.3f'% (tp/(pp)))test_postive_idx = np.argwhere(oof_test_final == True).reshape(-1) # test_postive_idx = list(range(len(oof_test_final))) test_all_idx = np.argwhere(np.array(test_data_idx)).reshape(-1) stock_info['trade_date_id'] = stock_info['trade_date'].map(date_map) stock_info['trade_date_id'] = stock_info['trade_date_id'] + 1tmp_col = ['ts_code', 'trade_date', 'trade_date_id', 'open', 'high', 'low', 'close', 'ma5', 'ma13', 'ma21', 'label_final', 'name'] stock_info.iloc[test_all_idx[test_postive_idx]] tmp_df = stock_info[tmp_col].iloc[test_all_idx[test_postive_idx]].reset_index() tmp_df['label_prob'] = test_lgb[test_postive_idx] tmp_df['is_limit_up'] = tmp_df['close'] == tmp_df['high'] buy_df = tmp_df[(tmp_df['is_limit_up']==False)].reset_index() buy_df.drop(['index', 'level_0'], axis=1, inplace=True)buy_df['buy_flag'] = 1 stock_info_copy['sell_flag'] = 0tmp_idx = (index_df['trade_date'] == test_date_min+1) close1 = index_df[tmp_idx]['close'].values[0] test_date_max = 20220829 tmp_idx = (index_df['trade_date'] == test_date_max) close2 = index_df[tmp_idx]['close'].values[0]tmp_idx = (stock_info_copy['trade_date'] >= test_date_min) & (stock_info_copy['trade_date'] <= test_date_max) tmp_df = stock_info_copy[tmp_idx].reset_index(drop=True)from imp import reload import Account reload(Account) money_init = 200000 account = Account.Account(money_init, max_hold_period=20, stop_loss_rate=-0.07, stop_profit_rate=0.12) account.BackTest(buy_df, tmp_df, index_df, buy_price='open')tmp_df2 = buy_df[['ts_code', 'trade_date', 'label_prob', 'label_final']] tmp_df2 = tmp_df2.rename(columns={'trade_date':'buy_date'}) tmp_df = account.info tmp_df['buy_date'] = tmp_df['buy_date'].apply(lambda x: int(x)) tmp_df = tmp_df.merge(tmp_df2, on=['ts_code', 'buy_date'], how='left')最终的tmp_df是什么?tmp_df[tmp_df['label_final']==1]又选取了什么股票?

最新推荐

recommend-type

Keras框架中的epoch、bacth、batch size、iteration使用介绍

在Keras框架中,训练深度学习模型时,四个关键概念是epoch、batch、batch size以及iteration。理解这些术语对于优化模型的训练过程至关重要。 1. **Epoch** - Epoch是训练过程中的一个完整周期,意味着数据集中的...
recommend-type

Microsoft Edge Flash Player 浏览器插件

不是重庆代理的flash国内版,支持edge浏览器旧版flash功能
recommend-type

python-批量重命名脚本

一段轻量级批量重命名脚本,常被用来清理下载/抓取得到的大量文件名不规范的资料。脚本支持正则搜索-替换、递归子目录、以及“演习模式”(dry-run)——先看看即将发生什么,再决定是否执行。保存为 rename.py 直接运行即可。
recommend-type

redis python

redis
recommend-type

Java算法:二叉树的前中后序遍历实现

在深入探讨如何用Java实现二叉树及其三种基本遍历(前序遍历、中序遍历和后序遍历)之前,我们需要了解一些基础知识。 首先,二叉树是一种被广泛使用的数据结构,它具有以下特性: 1. 每个节点最多有两个子节点,分别是左子节点和右子节点。 2. 左子树和右子树都是二叉树。 3. 每个节点都包含三个部分:值、左子节点的引用和右子节点的引用。 4. 二叉树的遍历通常用于访问树中的每个节点,且访问的顺序可以是前序、中序和后序。 接下来,我们将详细介绍如何用Java来构建这样一个树结构,并实现这些遍历方式。 ### Java实现二叉树结构 要实现二叉树结构,我们首先需要一个节点类(Node.java),该类将包含节点值以及指向左右子节点的引用。其次,我们需要一个树类(Tree.java),它将包含根节点,并提供方法来构建树以及执行不同的遍历。 #### Node.java ```java public class Node { int value; Node left; Node right; public Node(int value) { this.value = value; left = null; right = null; } } ``` #### Tree.java ```java import java.util.Stack; public class Tree { private Node root; public Tree() { root = null; } // 这里可以添加插入、删除等方法 // ... // 前序遍历 public void preOrderTraversal(Node node) { if (node != null) { System.out.print(node.value + " "); preOrderTraversal(node.left); preOrderTraversal(node.right); } } // 中序遍历 public void inOrderTraversal(Node node) { if (node != null) { inOrderTraversal(node.left); System.out.print(node.value + " "); inOrderTraversal(node.right); } } // 后序遍历 public void postOrderTraversal(Node node) { if (node != null) { postOrderTraversal(node.left); postOrderTraversal(node.right); System.out.print(node.value + " "); } } // 迭代形式的前序遍历 public void preOrderTraversalIterative() { Stack<Node> stack = new Stack<>(); stack.push(root); while (!stack.isEmpty()) { Node node = stack.pop(); System.out.print(node.value + " "); if (node.right != null) { stack.push(node.right); } if (node.left != null) { stack.push(node.left); } } System.out.println(); } // 迭代形式的中序遍历 public void inOrderTraversalIterative() { Stack<Node> stack = new Stack<>(); Node current = root; while (current != null || !stack.isEmpty()) { while (current != null) { stack.push(current); current = current.left; } current = stack.pop(); System.out.print(current.value + " "); current = current.right; } System.out.println(); } // 迭代形式的后序遍历 public void postOrderTraversalIterative() { Stack<Node> stack = new Stack<>(); Stack<Node> output = new Stack<>(); stack.push(root); while (!stack.isEmpty()) { Node node = stack.pop(); output.push(node); if (node.left != null) { stack.push(node.left); } if (node.right != null) { stack.push(node.right); } } while (!output.isEmpty()) { System.out.print(output.pop().value + " "); } System.out.println(); } } ``` ### Java实现的二叉树遍历详细解析 #### 前序遍历(Pre-order Traversal) 前序遍历是先访问根节点,然后递归地前序遍历左子树,接着递归地前序遍历右子树。遍历的顺序是:根 -> 左 -> 右。 #### 中序遍历(In-order Traversal) 中序遍历是先递归地中序遍历左子树,然后访问根节点,最后递归地中序遍历右子树。对于二叉搜索树来说,中序遍历可以按从小到大的顺序访问所有节点。遍历的顺序是:左 -> 根 -> 右。 #### 后序遍历(Post-order Traversal) 后序遍历是先递归地后序遍历左子树,然后递归地后序遍历右子树,最后访问根节点。遍历的顺序是:左 -> 右 -> 根。 ### 迭代形式的遍历 在上述`Tree.java`类中,我们还实现了迭代形式的遍历,通过使用栈来模拟递归过程。这种方法在处理大型树结构时,可以避免递归导致的栈溢出问题,并且可以提高效率。 ### 总结 通过上述代码和解释,我们可以看到,使用Java实现二叉树及其遍历方法相对直接。核心在于理解二叉树节点的结构和递归逻辑,以及如何使用栈来模拟递归过程。在实践中,了解并掌握这些基本算法对于解决复杂问题是非常有用的。此外,理解这些基本概念后,可以进一步探索更高级的二叉树算法,如平衡二叉树(AVL树)、红黑树等。
recommend-type

【性能测试基准】:为RK3588选择合适的NVMe性能测试工具指南

# 1. NVMe性能测试基础 ## 1.1 NVMe协议简介 NVMe,全称为Non-Volatile Memory Express,是专为固态驱动器设计的逻辑设备接口规范。与传统的SATA接口相比,NVMe通过使用PCI Express(PCIe)总线,大大提高了存储设备的数据吞吐量和IOPS(每秒输入输出操作次数),特别适合于高速的固态存储设备。
recommend-type

grant usage on ** to bcc@*%

<think>我们正在处理一个关于MySQL权限授予的问题。用户的具体需求是:为数据库用户'bcc'@'%'授予对所有数据库的使用权限(USAGE)。根据引用[1]和引用[4],我们知道在MySQL中,使用GRANT语句来授予权限。注意:USAGE权限实际上是一个“无权限”的权限,它仅仅表示用户存在,但没有任何实际权限(除了连接数据库)。如果用户只想允许用户连接数据库而不做任何操作,那么授予USAGE是合适的。但是,用户要求的是“使用权限”,我们需要确认用户是否真的只需要USAGE权限,还是需要其他权限?根据问题描述,用户明确说“使用权限”,并且指定了USAGE(在问题中提到了grantusa
recommend-type

Nokia手机通用密码计算器:解锁神器

根据给定的文件信息,我们可以了解到一个关于诺基亚(Nokia)手机解锁密码生成工具的知识点。在这个场景中,文件标题“Nokia手机密码计算器”表明了这是一个专门用于生成Nokia手机解锁密码的应用程序。描述中提到的“输入手机串号,就可得到10位通用密码,用于解锁手机”说明了该工具的使用方法和功能。 知识点详解如下: 1. Nokia手机串号的含义: 串号(Serial Number),也称为序列号,是每部手机独一无二的标识,通常印在手机的电池槽内或者在手机的设置信息中可以查看。它对于手机的售后维修、技术支持以及身份识别等方面具有重要意义。串号通常由15位数字组成,能够提供制造商、型号、生产日期和制造地点等相关信息。 2. Nokia手机密码计算器的工作原理: Nokia手机密码计算器通过特定的算法将手机的串号转换成一个10位的数字密码。这个密码是为了帮助用户在忘记手机的PIN码(个人识别码)、PUK码(PIN解锁码)或者某些情况下手机被锁定时,能够解锁手机。 3. 通用密码与安全性: 这种“通用密码”是基于一定算法生成的,不是随机的。它通常适用于老型号的Nokia手机,因为这些手机在设计时通常会采用固定的算法来生成密码。然而,随着科技的发展和安全需求的提高,现代手机通常不会提供此类算法生成的通用密码,以防止未经授权的解锁尝试。 4. Nokia手机的安全机制: 老型号的Nokia手机在设计时,通常会考虑到用户可能忘记密码的情况。为了保证用户在这种情况下的手机依然能够被解锁使用,制造商设置了一套安全机制,即通用密码系统。但这同时也带来了潜在的安全风险,因为如果算法被破解,那么任何知道串号的人都可能解锁这部手机。 5. MasterCode.exe文件的作用: 文件列表中的“MasterCode.exe”很可能就是上述“Nokia手机密码计算器”的可执行文件。用户需要运行这个程序,并按照程序的指示输入手机的串号,程序便会根据内部的算法计算出用于解锁的密码。 6. 注意事项和法律风险: 尽管此类工具在技术上帮助了用户,但必须强调的是,使用此类解锁工具或破解手机可能会违反相关的法律法规,特别是如果手机并非属于解锁者本人。在大多数国家,未经授权解锁手机都是违法的,尤其是在手机是通过运营商签订合约购买的情况下。因此,用户在尝试使用通用密码解锁手机前,应确保了解当地的法律法规,并且只在合法和合理的范围内使用此类工具。 7. 替代解锁方法: 对于现代智能手机,如果用户忘记了解锁密码,通常需要通过官方的客户服务来解决,例如联系手机制造商的客服或到指定的维修点进行解锁。一些手机还提供了账号解锁的功能,比如Apple的“查找我的iPhone”功能,以及Google的账号解锁选项。 总结来说,Nokia手机密码计算器是一个基于特定算法的实用工具,可帮助用户在忘记密码时解锁其Nokia手机。然而,用户在使用此类工具时应谨慎,并且必须遵守当地的法律法规。
recommend-type

【固态硬盘寿命延长】:RK3588平台NVMe维护技巧大公开

# 1. 固态硬盘寿命延长的基础知识 ## 1.1 固态硬盘的基本概念 固态硬盘(SSD)是现代计算设备中不可或缺的存储设备之一。与传统的机械硬盘(HDD)相比,SSD拥有更快的读写速度、更小的体积和更低的功耗。但是,SSD也有其生命周期限制,主要受限于NAND闪存的写入次数。 ## 1.2 SSD的写入次数和寿命 每块SSD中的NAND闪存单元都有有限的写入次数。这意味着,随着时间的推移,SSD的
recommend-type

缺省参数是什么

<think>首先,用户的问题是:“我想理解编程中缺省参数的概念和使用方法请问编程语言中缺省参数的定义及示例”。用户想了解缺省参数在编程中的概念和使用方法。我有系统级指令:-所有行内数学表达式必须使用$...$格式。-独立公式必须使用$$...$$格式并单独成段。-LaTeX语法正确。-使用中文回答。-生成相关问题。-回答中引用的段落末尾自然地添加引用标识,例如[^1]。用户可见层指令:-回答结构清晰,帮助用户逐步解决问题。-尽量保证回答真实可靠。参考站内引用:-引用[1]:缺省参数是声明或定义函数时为函数的参数指定的一个默认值。在调用该函数时,如果没有指定实参则采用该默认值,否则使用指定的实