
```python
pred_y_logits = model(t_x, t_cc, t_cd, u)
pred_y = torch.argmax(torch.softmax(pred_y_logits, dim=2), axis=2)
loss = F.mse_loss(pred_y.to(torch.float64), t_y)
```

How can I make `pred_y` retain its gradient?

Posted: 2025-08-21 12:27:02 · AIGC
Solving the problem of `pred_y` not retaining a gradient in PyTorch comes down to the **gradient propagation mechanism** and the **properties of the tensor operations** involved. The key points and solutions are laid out below.

---

### 1. Root cause analysis

1. **`torch.argmax` severs the gradient**
   `torch.argmax` returns **discrete index values** (an integer tensor). The operation is non-differentiable, so it **cuts the computation graph** and leaves `pred_y` without a gradient.
2. **Wrong tensor dtype or attributes**
   - If `pred_y` is cast to an integer type (e.g. with `.long()`), it loses gradient tracking.
   - Autograd only tracks floating-point tensors on a computation path with `requires_grad=True`.
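To see exactly where the graph is cut, it helps to inspect `requires_grad`/`grad_fn` before and after the `argmax`. A minimal sketch (the shapes here are made up for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, 3, requires_grad=True)  # hypothetical [batch, seq_len, num_classes]
probs = torch.softmax(logits, dim=2)
print(probs.requires_grad, probs.grad_fn)  # True <SoftmaxBackward0 ...>: still in the graph

pred = torch.argmax(probs, dim=2)  # integer indices: the graph is cut here
print(pred.requires_grad, pred.grad_fn)  # False None: no path back to logits

loss = F.mse_loss(pred.to(torch.float64), torch.zeros(4, 10, dtype=torch.float64))
# loss.backward()  # RuntimeError: element 0 of tensors does not require grad
```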
---

### 2. Solutions

#### Method 1: avoid the non-differentiable op (recommended)

During training, **skip `argmax`** and compute the loss directly on the model's raw floating-point output:

```python
# Model output is a float tensor (gradient is retained)
pred_y = model(inputs)  # shape: [batch_size, num_classes]

# Compute the MSE loss on the raw output (the target must also be float)
loss = F.mse_loss(pred_y, target.float())
loss.backward()
```

#### Method 2: gradient approximation (special cases)

If a discrete operation is genuinely required (e.g. in a classification pipeline), the gradient can be **approximated**:

1. **Gumbel-Softmax trick**
   Introduce a temperature parameter τ and approximate the discrete distribution in a differentiable way:

```python
logits = model(inputs)  # unnormalized logits from the model
temperature = 1.0
samples = F.gumbel_softmax(logits, tau=temperature, hard=True)
loss = F.mse_loss(samples, target)
```

2. **Straight-Through Estimator (STE)**
   Use the hard `argmax` result in the forward pass and bypass the non-differentiable step in the backward pass. Note that integer outputs of a custom `Function` are automatically marked non-differentiable, and the backward must return a gradient matching the input's shape, so the cleanest formulation returns a float one-hot from `forward`:

```python
class STEArgmaxOneHot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits):
        idx = torch.argmax(logits, dim=1)
        # Return a float one-hot so the output stays inside the autograd system
        return F.one_hot(idx, num_classes=logits.size(1)).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # straight-through: pretend forward was the identity

# Usage
one_hot = STEArgmaxOneHot.apply(logits)  # hard one-hot; gradients flow back to logits
loss = F.mse_loss(one_hot, target)
```
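For the snippet in the question specifically, taking `argmax` over class logits suggests a per-position classification task. If that reading is right (this assumes `pred_y_logits` has shape `[batch, seq_len, num_classes]` and `t_y` holds integer class indices), cross-entropy on the raw logits is usually a better fit than MSE on predicted indices, and it keeps gradients intact by construction:

```python
import torch.nn.functional as F

# Assumed shapes: pred_y_logits [batch, seq_len, num_classes], t_y [batch, seq_len] (long)
loss = F.cross_entropy(
    pred_y_logits.permute(0, 2, 1),  # cross_entropy expects [batch, num_classes, ...]
    t_y.long(),
)
loss.backward()  # gradients flow through the logits
```

With this setup, `torch.argmax` is reserved for evaluation and metrics, where no gradient is required.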
---

### 3. Key checkpoints

| Check | How to verify | Fix |
|-------|---------------|-----|
| **Gradient attribute** | `print(pred_y.requires_grad)` | Ensure the path producing `pred_y` uses only differentiable ops |
| **Tensor dtype** | `print(pred_y.dtype)` | Convert to float: `pred_y = pred_y.float()` |
| **Shape match** | `print(pred_y.shape, target.shape)` | Adjust with `view()` or `squeeze()` |
| **Optimizer wiring** | `print(optimizer.param_groups)` | Confirm the model parameters are registered with the optimizer |

---

### 4. Complete example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)  # returns unnormalized logits

model = Model()
optimizer = torch.optim.Adam(model.parameters())

# Forward pass (gradient is retained)
inputs = torch.randn(32, 10)
pred_y = model(inputs)  # shape: [32, 5]

# Build a target with the same shape as pred_y
target = torch.randn(32, 5)

# MSE loss on the raw float output keeps the graph intact
loss = F.mse_loss(pred_y, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

---

### 5. Further notes

1. **Protecting the computation graph**
   If you need to backpropagate through the same graph more than once, set `retain_graph=True`:

```python
loss.backward(retain_graph=True)
```

2. **Gradient accumulation**
   With small batches, accumulating gradients simulates a larger effective batch size:

```python
for i, (inputs, targets) in enumerate(dataloader):
    loss = model(inputs, targets)  # assumes the model returns the loss
    loss.backward()
    if (i + 1) % accumulation_steps == 0:  # accumulation_steps defined elsewhere
        optimizer.step()
        optimizer.zero_grad()
```

---

### 6. Related questions

1. How can the gradient flow through a PyTorch computation graph be visualized?
2. What is the difference between `detach()` and `with torch.no_grad()`?
3. How should one choose between MSE loss and cross-entropy loss in classification tasks?

