【实体关系抽取】之——OneRel代码学习笔记(二)

fw.train(model[args.model_name])

在train.py中调用fw.train()进行模型的训练

    def train(self, model_pattern):
        # 对ori_model进行定义
        ori_model = model_pattern(self.config)
        ori_model.cuda()
        # 定义Adam优化器
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, ori_model.parameters()), lr = self.config.learning_rate)

        # 选择是否使用多gpu训练:
        if self.config.multi_gpu:
            model = nn.DataParallel(ori_model)
        else:
            model = ori_model

        # 定义损失函数
        def cal_loss(target, predict, mask):
            # 这里self.loss_function为交叉熵损失函数
            loss = self.loss_function(predict, target)
            # 这行代码将损失值乘以掩码mask,然后对乘积进行求和。mask的值为1或0,这样做的目的是在计算损失时只考虑需要计算损失的样本。
            loss = torch.sum(loss * mask) / torch.sum(mask)
            return loss

        # check the check_point dir
        if not os.path.exists(self.config.checkpoint_dir):
            os.mkdir(self.config.checkpoint_dir)

        # check the log dir
        if not os.path.exists(self.config.log_dir):
            os.mkdir(self.config.log_dir)

        # training data 对数据进行处理
        train_data_loader = data_loader.get_loader(self.config, prefix = self.config.train_prefix, num_workers = 2)
        # dev data
        test_data_loader = data_loader.get_loader(self.config, prefix=self.config.dev_prefix, is_test=True)

        # 将模型设置为训练模式
        model.train()
        global_step = 0
        loss_sum = 0

        best_f1_score = 0
        best_precision = 0
        best_recall = 0

        best_epoch = 0
        init_time = time.time()
        start_time = time.time()

        # the training loop
        for epoch in range(self.config.max_epoch):
            train_data_prefetcher = data_loader.DataPreFetcher(train_data_loader)
            data = train_data_prefetcher.next()
            epoch_start_time = time.time()

            while data is not None:

                '''
                通过model(data)的调用,输入数据data会被传递给模型的前向传播函数,然后模型会根据输入数据进行计算和预测。
                模型的前向传播函数会将数据经过模型的层次结构,进行一系列的计算和变换,最终得到模型对输入数据的预测结果。
                pred_triple_matrix会保存模型对输入数据预测结果的输出。
                '''
                pred_triple_matrix = model(data)

                # target,predict,mask
                triple_loss = cal_loss(data['triple_matrix'], pred_triple_matrix, data['loss_mask'])

                optimizer.zero_grad()
                triple_loss.backward()
                optimizer.step()

                global_step += 1
                loss_sum += triple_loss.item()

                if global_step % self.config.period == 0:
                    cur_loss = loss_sum / self.config.period
                    elapsed = time.time() - start_time
                    self.logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                                 format(epoch, global_step, elapsed * 1000 / self.config.period, cur_loss))
                    loss_sum = 0
                    start_time = time.time()

                data = train_data_prefetcher.next()
            print("total time {}".format(time.time() - epoch_start_time))

            # if epoch > 20 and (epoch + 1) % self.config.test_epoch == 0:
            if epoch == 0:
                eval_start_time = time.time()
                model.eval()
                # call the test function
                precision, recall, f1_score = self.test(test_data_loader, model, current_f1=best_f1_score, output=self.config.result_save_name)
                
                self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.3f}, precision: {:4.3f}, recall: {:4.3f}'.
                             format(epoch, time.time() - eval_start_time, f1_score, precision, recall))

                if f1_score > best_f1_score:
                    best_f1_score = f1_score
                    best_epoch = epoch
                    best_precision = precision
                    best_recall = recall
                    self.logging("saving the model, epoch: {:3d}, precision: {:4.3f}, recall: {:4.3f}, best f1: {:4.3f}".
                                 format(best_epoch, best_precision, best_recall, best_f1_score))
                    # save the best model
                    path = os.path.join(self.config.checkpoint_dir, self.config.model_save_name)
                    if not self.config.debug:
                        torch.save(ori_model.state_dict(), path)

                model.train()

            # manually release the unused cache
            torch.cuda.empty_cache()

        self.logging("finish training")
        self.logging("best epoch: {:3d}, precision: {:4.3f}, recall: {:4.3}, best f1: {:4.3f}, total time: {:5.2f}s".
                     format(best_epoch, best_precision, best_recall, best_f1_score, time.time() - init_time))

train_data_loader = data_loader.get_loader(self.config, prefix = self.config.train_prefix, num_workers = 2)

现在进入OneRel模型的关键部分,对输入模型的数据的格式的处理

def get_loader(config, prefix, is_test=False, num_workers=0, collate_fn=re_collate_fn):
    dataset = REDataset(config, prefix, is_test, tokenizer)
    if not is_test:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    else:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return data_loader
dataset = REDataset(config, prefix, is_test, tokenizer)

调用REDataset类,

# 完成自定义数据集,新建REDataset类继承torch的Dataset类
class REDataset(Dataset):
    def __init__(self, config, prefix, is_test, tokenizer):
        self.config = config
        # prefix就是传过来的文件名称
        self.prefix = prefix
        self.is_test = is_test
        self.tokenizer = tokenizer

        # 首先读取json文件的数据存储到self.json_data中
        if self.config.debug:
            self.json_data = json.load(open(os.path.join(self.config.data_path, prefix + '.json')))[:12]
        else:
            self.json_data = json.load(open(os.path.join(self.config.data_path, prefix + '.json')))
        
        # 读取关系到id的对应,存到self.rel2id中
        self.rel2id = json.load(open(os.path.join(self.config.data_path, 'rel2id.json')))[1]
        # 读取tag到id的对应,存到self.tag2id中,为关系矩阵的标记做准备     {"A": 0,"HB-TB": 1,"HB-TE": 2,"HE-TE": 3}
        self.tag2id = json.load(open('data/tag2id.json'))[1]

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, idx):
        # __getitem__函数接受一个索引idx作为输入,并获取与该索引相关联的JSON数据。然后从JSON数据中提取文本信息。
        ins_json_data = self.json_data[idx]
        text = ins_json_data['text']
        # text = ' '.join(text.split()[:self.config.max_len])
        # text = basicTokenizer.tokenize(text)
        # text = " ".join(text[:self.config.max_len])
        # 使用分词器对文本进行分词。如果分词后的token数超过了在config中定义的最大长度,就进行截断。变量text_len存储了分词后的文本长度。
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) > self.config.bert_max_len:
            tokens = tokens[: self.config.bert_max_len]
        # 对text进行分词后都存在tokens中
        text_len = len(tokens)

        # 训练
        if not self.is_test:
            # 如果代码不处于测试模式,它会创建一个字典s2ro_map来存subject-object关系映射。 map[subject] = object+relation编号
            # 它遍历JSON数据中的triple_list,对主题和客体字符串进行分词,找到它们在分词后的文本中的相应索引,并将它们存储在s2ro_map字典中。
            s2ro_map = {}
            for triple in ins_json_data['triple_list']:
                # 0和2为头尾实体,1为关系
                triple = (self.tokenizer.tokenize(triple[0])[1:-1],
                         triple[1], self.tokenizer.tokenize(triple[2])[1:-1])
                sub_head_idx = find_head_idx(tokens, triple[0])
                obj_head_idx = find_head_idx(tokens, triple[2])
                # 若头尾实体均可以在tokens中找到
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
                    if sub not in s2ro_map:
                        s2ro_map[sub] = []
                    s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel2id[triple[1]]))
            
            if s2ro_map:
                # token_ids:一个整数列表,表示经过编码的文本的标记或令牌。
                # segment_ids:一个整数列表,用于指示文本的不同部分,例如句子的开始和结束。在一些语言模型中,这部分可能用于表示上下文的分段。
                token_ids, segment_ids = self.tokenizer.encode(first=text)
                # masks 作为掩码用来指示输入文本的哪些部分是真实的文本内容,哪些部分是填充或无效的。通过将 segment_ids 赋值给 masks,可以将 segment_ids 中的信息直接应用于掩码。
                # 通常,segment_ids 用于处理输入中的不同句子或段落,其中每个句子或段落都被分配一个唯一的 segment_id。在这种情况下,masks 在处理文本时可以通过将填充部分标记为 0 或无效来忽略它们。
                masks = segment_ids
                if len(token_ids) > text_len:
                    token_ids = token_ids[:text_len]
                    masks = masks[:text_len]
                mask_length = len(masks)
                # 将token_ids和masks转化为numpy数组,对masks中的每个元素加1
                token_ids = np.array(token_ids)
                masks = np.array(masks) + 1
                # 创建一个shape为(mask_length, mask_length)的全1数组loss_masks。
                loss_masks = np.ones((mask_length, mask_length))
                # 创建一个shape为(self.config.rel_num, text_len, text_len)、元素都为0的triple_matrix数组
                triple_matrix = np.zeros((self.config.rel_num, text_len, text_len))
                
                for s in s2ro_map:
                    sub_head = s[0]
                    sub_tail = s[1]
                    for ro in s2ro_map.get((sub_head, sub_tail), []):
                        obj_head, obj_tail, relation = ro
                        triple_matrix[relation][sub_head][obj_head] = self.tag2id['HB-TB']
                        triple_matrix[relation][sub_head][obj_tail] = self.tag2id['HB-TE']
                        triple_matrix[relation][sub_tail][obj_tail] = self.tag2id['HE-TE']
                
                return token_ids, masks, loss_masks, text_len, triple_matrix, ins_json_data['triple_list'], tokens
            else:
                print(ins_json_data)
                return None

        # 测试
        else:
            token_ids, masks = self.tokenizer.encode(first=text)
            if len(token_ids) > text_len:
                token_ids = token_ids[:text_len]
                masks = masks[:text_len]
            token_ids = np.array(token_ids)
            masks = np.array(masks) + 1
            mask_length = len(masks)
            loss_masks = np.array(masks) + 1
            triple_matrix = np.zeros((self.config.rel_num, text_len, text_len))
            return token_ids, masks, loss_masks, text_len, triple_matrix, ins_json_data['triple_list'], tokens
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值