fw.train(model[args.model_name])
In train.py, fw.train() is called to start training the model.
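For context, here is a minimal sketch of what the surrounding setup in train.py looks like (the argument parsing and the model registry names are assumptions, not quoted from the repository):

import argparse
import config
import framework
import models

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='OneRel')  # assumed flag name
args = parser.parse_args()

con = config.Config(args)              # hyper-parameters, data paths, max_epoch, ...
fw = framework.Framework(con)          # exposes the train() / test() methods discussed below
model = {'OneRel': models.RelModel}    # model-name -> model-class registry (assumed)
fw.train(model[args.model_name])       # note: the model *class* is passed, not an instance

The train() method itself lives in the framework module: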
def train(self, model_pattern):
    # instantiate the original (un-wrapped) model and move it to the GPU
    ori_model = model_pattern(self.config)
    ori_model.cuda()
    # Adam optimizer over the trainable parameters only
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, ori_model.parameters()), lr=self.config.learning_rate)
    # choose whether to train on multiple GPUs
    if self.config.multi_gpu:
        model = nn.DataParallel(ori_model)
    else:
        model = ori_model
    # define the loss function
    def cal_loss(target, predict, mask):
        # self.loss_function is a cross-entropy loss that returns per-position values,
        # so the mask below can be applied element-wise
        loss = self.loss_function(predict, target)
        # multiply by the 0/1 mask and average only over positions where the mask is 1,
        # so that padded positions do not contribute to the loss
        loss = torch.sum(loss * mask) / torch.sum(mask)
        return loss
    # make sure the checkpoint dir exists
    if not os.path.exists(self.config.checkpoint_dir):
        os.mkdir(self.config.checkpoint_dir)
    # make sure the log dir exists
    if not os.path.exists(self.config.log_dir):
        os.mkdir(self.config.log_dir)
    # training data loader
    train_data_loader = data_loader.get_loader(self.config, prefix=self.config.train_prefix, num_workers=2)
    # dev data loader
    test_data_loader = data_loader.get_loader(self.config, prefix=self.config.dev_prefix, is_test=True)
    # put the model into training mode
    model.train()
    global_step = 0
    loss_sum = 0
    best_f1_score = 0
    best_precision = 0
    best_recall = 0
    best_epoch = 0
    init_time = time.time()
    start_time = time.time()
    # the training loop
    for epoch in range(self.config.max_epoch):
        train_data_prefetcher = data_loader.DataPreFetcher(train_data_loader)
        data = train_data_prefetcher.next()
        epoch_start_time = time.time()
        while data is not None:
            # model(data) runs the forward pass: the batch is pushed through the network and
            # pred_triple_matrix holds the predicted relation-specific tag matrix for this batch
            pred_triple_matrix = model(data)
            # arguments: target, predict, mask
            triple_loss = cal_loss(data['triple_matrix'], pred_triple_matrix, data['loss_mask'])
            optimizer.zero_grad()
            triple_loss.backward()
            optimizer.step()
            global_step += 1
            loss_sum += triple_loss.item()
            if global_step % self.config.period == 0:
                cur_loss = loss_sum / self.config.period
                elapsed = time.time() - start_time
                self.logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                             format(epoch, global_step, elapsed * 1000 / self.config.period, cur_loss))
                loss_sum = 0
                start_time = time.time()
            data = train_data_prefetcher.next()
        print("total time {}".format(time.time() - epoch_start_time))
        # if epoch > 20 and (epoch + 1) % self.config.test_epoch == 0:
        if epoch == 0:  # evaluation is triggered only on the first epoch here; the commented-out
                        # condition above would instead evaluate every test_epoch epochs
            eval_start_time = time.time()
            model.eval()
            # call the test function
            precision, recall, f1_score = self.test(test_data_loader, model, current_f1=best_f1_score, output=self.config.result_save_name)
            self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.3f}, precision: {:4.3f}, recall: {:4.3f}'.
                         format(epoch, time.time() - eval_start_time, f1_score, precision, recall))
            if f1_score > best_f1_score:
                best_f1_score = f1_score
                best_epoch = epoch
                best_precision = precision
                best_recall = recall
                self.logging("saving the model, epoch: {:3d}, precision: {:4.3f}, recall: {:4.3f}, best f1: {:4.3f}".
                             format(best_epoch, best_precision, best_recall, best_f1_score))
                # save the best model
                path = os.path.join(self.config.checkpoint_dir, self.config.model_save_name)
                if not self.config.debug:
                    torch.save(ori_model.state_dict(), path)
            model.train()
        # manually release the unused GPU cache
        torch.cuda.empty_cache()
    self.logging("finish training")
    self.logging("best epoch: {:3d}, precision: {:4.3f}, recall: {:4.3}, best f1: {:4.3f}, total time: {:5.2f}s".
                 format(best_epoch, best_precision, best_recall, best_f1_score, time.time() - init_time))
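To make the masked loss in cal_loss() concrete, here is a small self-contained sketch. It assumes self.loss_function is an element-wise cross-entropy such as nn.CrossEntropyLoss(reduction='none'); the shapes are illustrative and not taken from the repository:

import torch
import torch.nn as nn

loss_function = nn.CrossEntropyLoss(reduction='none')   # one loss value per position, no averaging

# 6 token-pair positions, 4 tag classes (A, HB-TB, HB-TE, HE-TE); the last two positions are padding
predict = torch.randn(6, 4)                     # raw logits per position
target = torch.tensor([0, 1, 2, 3, 0, 0])       # gold tag ids
mask = torch.tensor([1., 1., 1., 1., 0., 0.])   # 1 = real position, 0 = padding

loss = loss_function(predict, target)                    # shape (6,), one value per position
masked_loss = torch.sum(loss * mask) / torch.sum(mask)   # average over the 4 real positions only
print(masked_loss.item())

Back on the data side, the loaders used above are created by data_loader.get_loader():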
train_data_loader = data_loader.get_loader(self.config, prefix = self.config.train_prefix, num_workers = 2)
This is where a key part of OneRel comes in: preparing and formatting the data fed into the model. get_loader() in data_loader.py builds the PyTorch DataLoader:
def get_loader(config, prefix, is_test=False, num_workers=0, collate_fn=re_collate_fn):
    dataset = REDataset(config, prefix, is_test, tokenizer)
    if not is_test:
        # training / dev: shuffled mini-batches of size config.batch_size
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    else:
        # test: one sample per batch, no shuffling
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return data_loader
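re_collate_fn itself is not reproduced in this walkthrough. Conceptually it has to batch variable-length samples, so a padding collate function along the following lines is what one would expect; everything here (including the field names, apart from 'triple_matrix' and 'loss_mask', which the training loop reads) is a hypothetical sketch, not the repository code:

import torch

def padding_collate_fn(batch):
    # drop samples where __getitem__ returned None (no locatable triples)
    batch = [b for b in batch if b is not None]
    # sort by text_len (index 3 of the tuple returned by REDataset), longest first
    batch.sort(key=lambda x: x[3], reverse=True)
    token_ids, masks, loss_masks, text_lens, triple_matrices, triples, tokens = zip(*batch)
    cur_batch = len(batch)
    max_len = max(text_lens)
    rel_num = triple_matrices[0].shape[0]

    batch_token_ids = torch.zeros(cur_batch, max_len, dtype=torch.long)
    batch_masks = torch.zeros(cur_batch, max_len, dtype=torch.long)
    batch_loss_masks = torch.zeros(cur_batch, max_len, max_len, dtype=torch.long)
    batch_triple_matrix = torch.zeros(cur_batch, rel_num, max_len, max_len, dtype=torch.long)

    for i in range(cur_batch):
        n = len(token_ids[i])
        batch_token_ids[i, :n].copy_(torch.from_numpy(token_ids[i]))
        batch_masks[i, :n].copy_(torch.from_numpy(masks[i]))
        m = loss_masks[i].shape[0]
        batch_loss_masks[i, :m, :m].copy_(torch.from_numpy(loss_masks[i]).long())
        t = triple_matrices[i].shape[1]
        batch_triple_matrix[i, :, :t, :t].copy_(torch.from_numpy(triple_matrices[i]).long())

    return {'token_ids': batch_token_ids,          # padded input ids
            'mask': batch_masks,                   # 1 for real tokens, 0 for padding
            'loss_mask': batch_loss_masks,         # read as data['loss_mask'] in cal_loss()
            'triple_matrix': batch_triple_matrix,  # read as data['triple_matrix'] in cal_loss()
            'triples': triples,
            'tokens': tokens}

Back in get_loader(), the dataset handed to the DataLoader is built by: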
dataset = REDataset(config, prefix, is_test, tokenizer)
This constructs an instance of the REDataset class.
# custom dataset: REDataset subclasses PyTorch's Dataset
class REDataset(Dataset):
    def __init__(self, config, prefix, is_test, tokenizer):
        self.config = config
        # prefix is the dataset file name that was passed in (the train or dev prefix)
        self.prefix = prefix
        self.is_test = is_test
        self.tokenizer = tokenizer
        # load the raw JSON data into self.json_data (only the first 12 samples in debug mode)
        if self.config.debug:
            self.json_data = json.load(open(os.path.join(self.config.data_path, prefix + '.json')))[:12]
        else:
            self.json_data = json.load(open(os.path.join(self.config.data_path, prefix + '.json')))
        # relation -> id mapping; index [1] selects the rel2id dict from rel2id.json
        self.rel2id = json.load(open(os.path.join(self.config.data_path, 'rel2id.json')))[1]
        # tag -> id mapping used to label the relation matrix: {"A": 0, "HB-TB": 1, "HB-TE": 2, "HE-TE": 3}
        self.tag2id = json.load(open('data/tag2id.json'))[1]

    def __len__(self):
        return len(self.json_data)
    def __getitem__(self, idx):
        # fetch the JSON record at index idx and extract its text
        ins_json_data = self.json_data[idx]
        text = ins_json_data['text']
        # text = ' '.join(text.split()[:self.config.max_len])
        # text = basicTokenizer.tokenize(text)
        # text = " ".join(text[:self.config.max_len])
        # tokenize the text and truncate if the token count exceeds the configured maximum;
        # text_len stores the length of the (possibly truncated) token sequence
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) > self.config.bert_max_len:
            tokens = tokens[: self.config.bert_max_len]
        text_len = len(tokens)
        # training / dev samples
        if not self.is_test:
            # Build s2ro_map: subject span -> list of (object head, object tail, relation id).
            # For every triple in triple_list, tokenize the subject and object strings, locate
            # their spans in the tokenized text, and record them in s2ro_map.
            s2ro_map = {}
            for triple in ins_json_data['triple_list']:
                # positions 0 and 2 are the subject and object entities, position 1 is the relation
                triple = (self.tokenizer.tokenize(triple[0])[1:-1],
                          triple[1], self.tokenizer.tokenize(triple[2])[1:-1])
                sub_head_idx = find_head_idx(tokens, triple[0])
                obj_head_idx = find_head_idx(tokens, triple[2])
                # keep the triple only if both entities can be located in tokens
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
                    if sub not in s2ro_map:
                        s2ro_map[sub] = []
                    s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel2id[triple[1]]))
            if s2ro_map:
                # token_ids: integer ids of the encoded tokens
                # segment_ids: sentence A/B segment ids; all zeros for a single-sentence input
                token_ids, segment_ids = self.tokenizer.encode(first=text)
                # masks starts out as segment_ids (all zeros here); after the +1 below it becomes
                # an all-ones mask over the real tokens, so that padding introduced later by the
                # collate function can be told apart from real positions
                masks = segment_ids
                if len(token_ids) > text_len:
                    token_ids = token_ids[:text_len]
                    masks = masks[:text_len]
                mask_length = len(masks)
                # convert token_ids and masks to numpy arrays; add 1 to every element of masks
                token_ids = np.array(token_ids)
                masks = np.array(masks) + 1
                # loss_masks: an all-ones array of shape (mask_length, mask_length) marking which
                # token-pair positions contribute to the loss
                loss_masks = np.ones((mask_length, mask_length))
                # triple_matrix: a zero array of shape (rel_num, text_len, text_len)
                triple_matrix = np.zeros((self.config.rel_num, text_len, text_len))
                # fill in the OneRel tags for every (subject, relation, object) triple
                for s in s2ro_map:
                    sub_head = s[0]
                    sub_tail = s[1]
                    for ro in s2ro_map.get((sub_head, sub_tail), []):
                        obj_head, obj_tail, relation = ro
                        triple_matrix[relation][sub_head][obj_head] = self.tag2id['HB-TB']
                        triple_matrix[relation][sub_head][obj_tail] = self.tag2id['HB-TE']
                        triple_matrix[relation][sub_tail][obj_tail] = self.tag2id['HE-TE']
                return token_ids, masks, loss_masks, text_len, triple_matrix, ins_json_data['triple_list'], tokens
            else:
                print(ins_json_data)
                return None
        # test samples
        else:
            token_ids, masks = self.tokenizer.encode(first=text)
            if len(token_ids) > text_len:
                token_ids = token_ids[:text_len]
                masks = masks[:text_len]
            token_ids = np.array(token_ids)
            masks = np.array(masks) + 1
            mask_length = len(masks)
            loss_masks = np.array(masks) + 1
            triple_matrix = np.zeros((self.config.rel_num, text_len, text_len))
            return token_ids, masks, loss_masks, text_len, triple_matrix, ins_json_data['triple_list'], tokens
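The helper find_head_idx used in __getitem__ is not shown above; it simply returns the index where an entity's token span starts inside the sentence tokens (or -1 if it cannot be found). The sketch below reimplements it (the repository's version may differ in detail) and walks one made-up triple through the tagging scheme, which makes the meaning of HB-TB / HB-TE / HE-TE concrete:

import numpy as np

def find_head_idx(source, target):
    # return the start index of the sub-list `target` inside `source`, or -1 if absent
    target_len = len(target)
    for i in range(len(source) - target_len + 1):
        if source[i: i + target_len] == target:
            return i
    return -1

tag2id = {"A": 0, "HB-TB": 1, "HB-TE": 2, "HE-TE": 3}

# Toy example (tokens and relation id are invented for illustration):
tokens = ['[CLS]', 'new', 'york', 'is', 'located', 'in', 'the', 'united', 'states', '[SEP]']
sub = ['new', 'york']        # subject span -> tokens 1..2
obj = ['united', 'states']   # object span  -> tokens 7..8
rel_num, relation = 1, 0     # pretend there is a single relation, id 0

sub_head = find_head_idx(tokens, sub)   # 1
sub_tail = sub_head + len(sub) - 1      # 2
obj_head = find_head_idx(tokens, obj)   # 7
obj_tail = obj_head + len(obj) - 1      # 8

triple_matrix = np.zeros((rel_num, len(tokens), len(tokens)))
triple_matrix[relation][sub_head][obj_head] = tag2id['HB-TB']  # cell (1, 7): subject head meets object head
triple_matrix[relation][sub_head][obj_tail] = tag2id['HB-TE']  # cell (1, 8): subject head meets object tail
triple_matrix[relation][sub_tail][obj_tail] = tag2id['HE-TE']  # cell (2, 8): subject tail meets object tail
# every other cell stays 0 ("A"): no link between that token pair under this relation

Decoding at test time reverses this: finding an HB-TB, HB-TE and HE-TE cell in the same relation slice is enough to recover both entity spans and the relation they participate in.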