活动介绍

def cut_img(img):#把图的H切割为4的倍数 h = img.shape[0] remain = h%4 h_new = h - remain img_new = img[:h_new, :] return img_new

时间: 2024-04-18 21:28:54 浏览: 76
def cut_img(img): # 获取图像的高度 h = img.shape[0] # 计算需要切割的行数 remain = h % 4 h_new = h - remain # 切割图像 img_new = img[:h_new, :] return img_new
相关问题

from torch.utils.data import Dataset, DataLoader class IMDBDataset(Dataset): def __init__(self, data, labels, remain_length=True): if remain_length: self.data = tokenizer.decode(data, remove_bos=False, remove_eos=False, remove_pad=False) else: # 缩减一下数据 self.data = tokenizer.decode(data) self.labels = labels def __getitem__(self, index): text = self.data[index] label = self.labels[index] return text, label def __len__(self): return len(self.data) def collate_fct(batch): text_list = [item[0].split() for item in batch] label_list = [item[1] for item in batch] # 这里使用 padding first text_list = tokenizer.encode(text_list, padding_first=True).to(dtype=torch.int) return text_list, torch.tensor(label_list).reshape(-1, 1).to(dtype=torch.float) # 用RNN,缩短序列长度 train_ds = IMDBDataset(train_data, train_labels, remain_length=False) test_ds = IMDBDataset(test_data, test_labels, remain_length=False) 这段代码有什么用

<think>好的,我现在需要帮用户分析他们提供的这段PyTorch代码的用途。首先,我应该仔细阅读代码,理解各个部分的功能。用户可能对Dataset和DataLoader不太熟悉,所以得先解释这些概念。 首先,用户定义了一个IMDBDataset类,继承自Dataset。这个类应该用于处理IMDB数据集的数据和标签。在__init__方法里,根据remain_length参数决定如何处理数据。如果remain_length为True,使用tokenizer.decode并保留bos、eos和pad标记;否则,可能去掉这些标记,缩短序列长度。这可能是为了适应不同的模型需求,比如RNN需要较短的序列。 接下来是__getitem__和__len__方法,这是Dataset必须实现的,用于获取单个样本和数据集长度。这很标准,没什么问题。 然后是collate_fct函数,这个函数用于DataLoader的collate_fn参数,处理批次数据。里面将文本分割成列表,再用tokenizer.encode进行编码,可能添加padding。这里需要注意padding_first参数的作用,可能是在序列前面填充,这对某些模型是必要的。标签被转换成张量,并调整形状和类型。 最后,用户创建了train_ds和test_ds,设置remain_length为False,可能是为了训练RNN模型时减少序列长度,提高效率。 可能需要考虑tokenizer的来源和具体实现,但用户没有提供这部分代码。另外,数据预处理的方式会影响模型训练,比如填充的位置和是否保留特殊标记会影响RNN的处理。用户可能想了解这段代码在数据加载和预处理中的作用,以及如何与模型配合。需要解释Dataset和DataLoader的作用,以及collate_fn的自定义处理如何影响数据批次的组织。同时,提到RNN使用缩短序列的原因,比如处理长序列的梯度问题或计算效率。</think>这段代码构建了一个基于PyTorch的文本分类数据处理管道,主要用于处理IMDB电影评论数据集。以下是逐层解析: 一、核心组件作用 1. `IMDBDataset`类 - 继承自`Dataset`:实现数据集的标准化接口 - `remain_length`参数控制序列处理: * $True$ 时保留BOS(开头)/EOS(结尾)/PAD(填充)标记 * $False$ 时去除特殊标记,缩短序列长度(适应RNN特性) 2. `collate_fct`函数 - 实现批处理逻辑: $$ \text{原始文本} \xrightarrow{split} \text{分词列表} \xrightarrow{encode} \text{数值编码} $$ - `padding_first=True`:在序列前端进行填充对齐 - 标签转换为二维浮点张量:适配损失函数计算 二、数据处理流程 1. 文本解码 ```python tokenizer.decode(data) # 将数值编码还原为文本 ``` 2. 动态批处理 ```python DataLoader(..., collate_fn=collate_fct) # 自定义批次组织逻辑 ``` 三、设计特点 1. 序列长度优化 - 训练时`remain_length=False`:通过去除填充标记降低RNN计算量 - 验证测试时保留完整长度(虽然示例中test_ds也设为False,但实际可根据需求调整) 2. 内存效率 - 延迟解码机制:仅在`__getitem__`调用时执行解码,避免一次性加载全部数据到内存 四、典型应用场景 该代码适用于: - 长短序列混合的文本分类任务 - RNN/LSTM等循环神经网络模型 - 需要动态padding的批处理场景 - 大规模文本数据的流式处理 实际使用中需确保`tokenizer`已正确定义,且`train_data`/`test_data`已预先加载。通过`DataLoader`可进一步配置批大小、采样策略等参数。

import sys import numpy as np import torch import torch.nn.functional as F from PIL import Image from io import BytesIO import base64 from typing import Dict, Callable, Union import matplotlib.pyplot as plt def dict_apply( x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] ) -> Dict[str, torch.Tensor]: result = dict() for key, value in x.items(): if isinstance(value, dict): result[key] = dict_apply(value, func) else: result[key] = func(value) return result def show_bbox(box: list[int], ax) -> None: """用法,在显示 plt.imshow(image) 之后,把对应的 box 传入,再把对应的轴传入,即可绘出边框。 Args: box (list[int]): 4 个元素 ax (plt.Axis): 通常使用 plt.gca() 传入 """ x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch( plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2) ) def show_mask(mask, ax, random_color=False, borders=True): if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) h, w = mask.shape[-2:] mask = mask.astype(np.uint8) mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) if borders: import cv2 contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) # Try to smooth contours contours = [ cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours ] mask_image = cv2.drawContours( mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2 ) ax.imshow(mask_image) def show_image_with_bbox(image: np.ndarray, box: list[int]) -> None: plt.imshow(image) show_bbox(box, plt.gca()) plt.show() def show_image_with_mask(image: np.ndarray, mask: np.ndarray) -> None: plt.imshow(image) show_mask(mask, plt.gca()) plt.show() def image_to_base64(image: Union[np.ndarray, Image.Image]) -> str: """将 numpy 数组转换为 base64 编码的 JPEG 格式 Args: np_image (np.ndarray): RGB 图像。 Returns: str: base64 编码的 JPEG 格式 """ if isinstance(image, Image.Image): pil_image = image elif isinstance(image, np.ndarray): # 确保数据类型为 uint8 if image.dtype != np.uint8: image = image.astype(np.uint8) # 转换维度顺序为 HWC 格式(如果必要),3 通道代表 RGB,4 代表 RGBA if image.shape[-1] not in [3, 4]: image = np.transpose(image, (1, 2, 0)) pil_image = Image.fromarray(image) else: raise ValueError(f"Invalid image type: {type(image)}") buffer = BytesIO() pil_image.save(buffer, format="JPEG") return base64.b64encode(buffer.getvalue()).decode("utf-8") def base64_to_image(base64_string: str) -> np.ndarray: padding = 4 - (len(base64_string) % 4) if padding != 4: base64_string += "=" * padding if "base64," in base64_string: base64_string = base64_string.split("base64,")[1] img_data = base64.b64decode(base64_string) bytes_io = BytesIO(img_data) image = Image.open(bytes_io) if image.mode != "RGB": image = image.convert("RGB") return np.array(image) def interpolate_image_batch( image: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """Convert image from (T, H, W, 3) into shape (T, 3, H, W), and interpolate them into (518, 518). If the last dimension of image shape is greater than 3, then the remain dimension will be truncated. Args: image (np.ndarray): of shape (T, H, W, 3) and dtype np.uint8 image_size (tuple[int, int], optional): Defaults to (518, 518). Returns: torch.Tensor: of shape (T, 3, 518, 518) and dtype torch.float32 """ rgb = torch.from_numpy(image[..., :3]).float() # [T, H, W, 3] rgb = rgb.permute(0, 3, 1, 2) # [T, 3, H, W] # 需要最后两个维度是 (H, W),才能使用广播的性质 rgb = F.interpolate( rgb / 255.0, size=image_size, mode="bilinear", align_corners=False ) return rgb def interpolate_mask_batch( mask: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """Rearrange mask into shape (T, 1, H, W), and interpolate them into (518, 518) Args: mask (np.ndarray): of shape (T, H, W) image_size (tuple[int, int], optional): Defaults to (518, 518). Returns: torch.Tensor: of shape (T, 1, 518, 518) """ # Process mask mask = torch.from_numpy(mask).float() # [T, H, W] mask = mask.unsqueeze(1) # [T, 1, H, W] mask = F.interpolate(mask, size=image_size, mode="nearest") mask = (mask > 0.5).float() return mask def interpolate_mask_image_batch( image_mask: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """Rearrange image and mask into shape (T, 4, H, W), and interpolate them into (518, 518) Args: image_mask (np.ndarray): of shape (T, H, W, 4) and dtype np.uint8 image_size (tuple[int, int]): Defaults to (518, 518). Return: image_mask (torch.Tensor): of shape (T, 4, 518, 518) and dtype torch.float32 """ rgb = torch.from_numpy(image_mask[..., :3]).float() # [T, H, W, 3] rgb = interpolate_image_batch(rgb, image_size) # Process mask mask = torch.from_numpy(image_mask[..., -1]).float() # [T, H, W] mask = interpolate_mask_batch(mask, image_size) # Combine combined = torch.cat([rgb, mask], dim=1) # [T, 4, H, W] return combined # making dataset def get_start_command(): # Clear input buffer sys.stdin.flush() while True: user_input = input("Press <Enter> to start, <q> to quit.") if user_input == "q": return False elif user_input == "": return True else: print("Invalid input. Please press <Enter> or <q>.") 全文注释

<think>我们有一个Python文件,包含多个函数和工具,用于图像处理、显示和转换等。我们需要为每个函数和类提供详细的中文注释。 注意:注释需要解释函数的作用、参数、返回值以及关键步骤。 同时,按照要求,在注释完成后,提出3-5个相关问题。 注意:问题必须放在最后,以“§§相关问题§§”开头,然后列出问题。 由于代码较长,我们将逐函数进行注释。</think>以下是对提供代码的完整中文注释: ```python import sys import numpy as np import torch import torch.nn.functional as F # PyTorch神经网络功能模块 from PIL import Image # Python图像处理库 from io import BytesIO # 内存字节流处理 import base64 # Base64编码解码 from typing import Dict, Callable, Union # 类型提示 import matplotlib.pyplot as plt # 数据可视化库 # 递归应用函数到字典中的张量 def dict_apply( x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] ) -> Dict[str, torch.Tensor]: """递归地将函数应用于字典中的所有张量 Args: x: 包含张量或嵌套字典的字典 func: 应用于每个张量的函数 Returns: 应用函数后的新字典 """ result = dict() for key, value in x.items(): if isinstance(value, dict): # 如果是嵌套字典 result[key] = dict_apply(value, func) # 递归处理 else: result[key] = func(value) # 应用函数到张量 return result # 在图像上绘制边界框 def show_bbox(box: list[int], ax) -> None: """在matplotlib轴上绘制边界框 Args: box: 边界框坐标 [x0, y0, x1, y1] ax: matplotlib轴对象 """ x0, y0 = box[0], box[1] # 左上角坐标 w, h = box[2] - box[0], box[3] - box[1] # 宽度和高度 # 创建并添加矩形补丁 ax.add_patch( plt.Rectangle( (x0, y0), w, h, edgecolor="green", # 绿色边框 facecolor=(0, 0, 0, 0), # 透明填充 lw=2 # 线宽 ) ) # 在图像上显示掩码 def show_mask(mask, ax, random_color=False, borders=True): """在matplotlib轴上可视化掩码 Args: mask: 二维掩码数组 ax: matplotlib轴对象 random_color: 是否使用随机颜色 borders: 是否绘制边界 """ # 设置颜色(默认蓝色或随机色) if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30/255, 144/255, 255/255, 0.6]) # 半透明蓝色 h, w = mask.shape[-2:] # 获取掩码尺寸 mask = mask.astype(np.uint8) # 转换为uint8类型 # 创建彩色掩码图像 mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) if borders: import cv2 # 仅在需要时导入OpenCV # 查找轮廓 contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) # 平滑轮廓 contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] # 在掩码图像上绘制轮廓 mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) ax.imshow(mask_image) # 显示掩码 # 显示带边界框的图像 def show_image_with_bbox(image: np.ndarray, box: list[int]) -> None: """显示图像并绘制边界框 Args: image: 输入图像数组 box: 边界框坐标 [x0, y0, x1, y1] """ plt.imshow(image) # 显示图像 show_bbox(box, plt.gca()) # 在当前轴上绘制边界框 plt.show() # 显示图像 # 显示带掩码的图像 def show_image_with_mask(image: np.ndarray, mask: np.ndarray) -> None: """显示图像并叠加掩码 Args: image: 输入图像数组 mask: 二维掩码数组 """ plt.imshow(image) # 显示图像 show_mask(mask, plt.gca()) # 在当前轴上显示掩码 plt.show() # 显示图像 # 将图像转换为Base64编码字符串 def image_to_base64(image: Union[np.ndarray, Image.Image]) -> str: """将图像转换为Base64编码的JPEG格式字符串 Args: image: 输入图像(numpy数组或PIL图像) Returns: Base64编码的JPEG字符串 """ if isinstance(image, Image.Image): pil_image = image # 已经是PIL图像 elif isinstance(image, np.ndarray): # 确保数据类型为uint8 if image.dtype != np.uint8: image = image.astype(np.uint8) # 调整维度顺序为HWC(高度, 宽度, 通道) if image.shape[-1] not in [3, 4]: image = np.transpose(image, (1, 2, 0)) pil_image = Image.fromarray(image) # 转换为PIL图像 else: raise ValueError(f"Invalid image type: {type(image)}") buffer = BytesIO() # 创建内存缓冲区 pil_image.save(buffer, format="JPEG") # 保存为JPEG格式 return base64.b64encode(buffer.getvalue()).decode("utf-8") # 编码为Base64字符串 # 将Base64字符串转换为图像 def base64_to_image(base64_string: str) -> np.ndarray: """将Base64字符串转换为numpy图像数组 Args: base64_string: Base64编码的图像字符串 Returns: RGB格式的numpy图像数组 """ # 处理Base64填充 padding = 4 - (len(base64_string) % 4) if padding != 4: base64_string += "=" * padding # 处理可能的数据URI前缀 if "base64," in base64_string: base64_string = base64_string.split("base64,")[1] img_data = base64.b64decode(base64_string) # 解码Base64 bytes_io = BytesIO(img_data) # 创建字节流 image = Image.open(bytes_io) # 打开图像 # 转换为RGB格式 if image.mode != "RGB": image = image.convert("RGB") return np.array(image) # 转换为numpy数组 # 批量插值图像 def interpolate_image_batch( image: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """批量调整图像尺寸并进行归一化 Args: image: 输入图像数组 (T, H, W, 3) image_size: 目标尺寸 (高度, 宽度) Returns: 处理后的图像张量 (T, 3, H, W) """ # 提取RGB通道并转换为张量 rgb = torch.from_numpy(image[..., :3]).float() # [T, H, W, 3] rgb = rgb.permute(0, 3, 1, 2) # 转换为 [T, 3, H, W] # 双线性插值调整尺寸并归一化到[0,1] rgb = F.interpolate( rgb / 255.0, size=image_size, mode="bilinear", align_corners=False ) return rgb # 批量插值掩码 def interpolate_mask_batch( mask: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """批量调整掩码尺寸并进行二值化 Args: mask: 输入掩码数组 (T, H, W) image_size: 目标尺寸 (高度, 宽度) Returns: 处理后的掩码张量 (T, 1, H, W) """ mask = torch.from_numpy(mask).float() # 转换为张量 mask = mask.unsqueeze(1) # 添加通道维度 [T, 1, H, W] # 最近邻插值调整尺寸 mask = F.interpolate(mask, size=image_size, mode="nearest") # 二值化处理(阈值0.5) mask = (mask > 0.5).float() return mask # 批量插值图像和掩码 def interpolate_mask_image_batch( image_mask: np.ndarray, image_size: tuple[int, int] = (518, 518) ) -> torch.Tensor: """批量处理同时包含图像和掩码的数据 Args: image_mask: 输入数组 (T, H, W, 4) 最后通道为掩码 image_size: 目标尺寸 (高度, 宽度) Returns: 组合后的张量 (T, 4, H, W) """ # 处理图像部分 rgb = interpolate_image_batch(image_mask[..., :3], image_size) # 处理掩码部分 mask = interpolate_mask_batch(image_mask[..., -1], image_size) # 沿通道维度拼接 combined = torch.cat([rgb, mask], dim=1) # [T, 4, H, W] return combined # 获取用户启动命令 def get_start_command(): """等待用户输入启动命令 Returns: True: 用户按下Enter启动 False: 用户输入q退出 """ sys.stdin.flush() # 清空输入缓冲区 while True: user_input = input("Press <Enter> to start, <q> to quit.") if user_input == "q": return False elif user_input == "": return True else: print("Invalid input. Please press <Enter> or <q>.") ```
阅读全文

相关推荐

while not rospy.is_shutdown(): try: trans, rot = location_listener.lookupTransform('map', 'body_foot', rospy.Time(0)) current_x = trans[0] current_y = trans[1] #返回的是机器人在 map 坐标系下的位置信息和姿态信息(四元数形式的旋转) # print('position',trans,rot) except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException): pass # print("logic no TF get !") seq_1 = seq_1 + 1 if seq_1 >= 5: seq_1 = 0 if robot_HP < 200 : require_add_HP = 1 if robot_HP > 350: require_add_HP = 0 wait_attack_cnt = wait_attack_cnt - 1 # print(" wait_attack_cnt :",wait_attack_cnt) #等待攻击时间递减 if wait_attack_cnt <= 0: wait_attack_cnt = 500 attack_cnt = 100 #攻击时间递减 if attack_cnt > 0: attack_cnt = attack_cnt - 1 #锁定计时器递减 if lock_yaw_cnt > 0: lock_yaw_cnt = lock_yaw_cnt - 1 # print(" attack_cnt :",attack_cnt) # target_yaw = target_yaw + 0.475 #0.475 target_pose = PoseStamped() target_pose.header.frame_id = "map" # 加血 if (game_status == 4) and (require_add_HP == 1):#补血区 target_pose.pose.position.x = 7.6 target_pose.pose.position.y = -4.5 target_pose.pose.position.z = 0 #进攻掩护步兵占领中心增益点 elif (game_status == 4) and (robot_HP > 300) and (attack_cnt > 0):#中心增益点靠前部分 target_pose.pose.position.x = 2 target_pose.pose.position.y = -2.9 target_pose.pose.position.z = 0 elif (game_status == 4) and (robot_HP > 300) and (attack_cnt == 0): #没到进攻时间在中心增益点附近 if remain_time > 250: target_pose.pose.position.x = 3.9 target_pose.pose.position.y = -2.9 #退让 elif remain_time > 200: target_pose.pose.position.x = 7.6 target_pose.pose.position.y = -4.5 #比赛最后时段 elif remain_time > 150: target_pose.pose.position.x = 3.9 target_pose.pose.position.y = -2.9 else: target_pose.pose.position.x = 3.9 target_pose.pose.position.y = -2.9 target_pose.pose.position.z = 0 # 否则待在原地 else:#原地 target_pose.pose.position.x = 1 target_pose.pose.position.y = -1 target_pose.pose.position.z = 0 target_quad = quaternion_from_euler(0, 0, target_yaw) target_pose.pose.orientation.x = target_quad[0] target_pose.pose.orientation.y = target_quad[1] target_pose.pose.orientation.z = target_quad[2] target_pose.pose.orientation.w = target_quad[3] distance = math.sqrt((target_pose.pose.position.x - current_x)**2 + (target_pose.pose.position.y - current_y)**2) if distance <= 0.2: print("have got goal! ") cancel_msg = GoalID() cancel_msg.stamp = rospy.Time.now() cancel_pub.publish(cancel_msg) else: location_target_publisher.publish(target_pose) print("distance",distance)请问我这样写可以实现当距离目标点小于0.1m时取消上一个目标点的发布么

import heapq import bisect def main(): import sys input = sys.stdin.read().split() idx = 0 N = int(input[idx]) idx +=1 C = int(input[idx]) idx +=1 schools = [] for _ in range(N): name = input[idx] num = int(input[idx+1]) idx +=2 schools.append( (name, num, num, 0) ) # (缩写, 初始人数, 剩余人数, 监考官数目) # 初始化最大堆,存储(-剩余人数, 学校索引) heap = [] for i in range(N): if schools[i][2] >0: heapq.heappush(heap, (-schools[i][2], i)) total_rooms = 0 # remain字典: key是剩余容量r,value是赛场的编号列表,按编号递增排序 remain = {} while heap: current_neg, idx = heapq.heappop(heap) current_n = -current_neg name, _, remain_n, count = schools[idx] if remain_n != current_n: continue # 过时的记录 if remain_n <=0: continue if current_n >= C: # 处理n >= C的情况 total_rooms +=1 new_remain = current_n - C schools[idx] = (name, schools[idx][1], new_remain, count +1) if new_remain >0: heapq.heappush(heap, (-new_remain, idx)) else: # 处理n < C的情况 found = False for r in range(C, current_n -1, -1): if r in remain and len(remain[r]) >0: # 取出编号最小的赛场 s_id = remain[r].pop(0) # 如果取出后列表为空,可以删除该键(可选) if len(remain[r]) ==0: del remain[r] # 计算新的剩余容量 new_r = r - current_n if new_r >0: if new_r not in remain: remain[new_r] = [] bisect.insort(remain[new_r], s_id) # 更新该学校的监考官数目 schools[idx] = (name, schools[idx][1], 0, count +1) found = True break if not found: # 新开一个赛场 total_rooms +=1 new_r = C - current_n s_id = total_rooms # 新赛场的编号 schools[idx] = (name, schools[idx][1], 0, count +1) if new_r >0: if new_r not in remain: remain[new_r] = [] bisect.insort(remain[new_r], s_id) # 输出结果 for sch in schools: print(f"{sch[0]} {sch[3]}") print(total_rooms) if __name__ == "__main__": main()转c++

使用gaussian噪声,基于参考分子生成 1000 个分子,初步过滤后有效分子: 1000 进一步过滤后,含苯环分子数量: 939 含苯环分子比例: 93.90% [15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. 进一步过滤后,单一片段分子数量: 0 单一片段分子比例: 0.00% 没有分子可保存 开始基于参考分子生成新分子(均匀噪声,强制含苯环)... [15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. [15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. [15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. 使用uniform噪声,基于参考分子生成 1000 个分子,初步过滤后有效分子: 1000 [15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. 进一步过滤后,含苯环分子数量: 957 含苯环分子比例: 95.70% [15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. [15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm. 进一步过滤后,单一片段分子数量: 1 单一片段分子比例: 0.10%这个是下面的代码运行的结果import torch import torch.nn as nn import torch.optim as optim import pandas as pd from rdkit import Chem from rdkit.Chem import Draw, rdMolDescriptors from torch_geometric.data import Data, Batch from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GraphConv from torch_geometric.loader import DataLoader import matplotlib.pyplot as plt from sklearn.preprocessing import StandardScaler import os import numpy as np import random from functools import lru_cache from torch.cuda import amp # 设置随机种子以确保结果可复现 def set_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False set_seed(42) # ------------------------- 工具函数 ------------------------- def validate_atomic_nums(atomic_nums): valid_atoms = {1, 6, 7, 8, 16, 9, 17} # H, C, N, O, S, F, Cl if isinstance(atomic_nums, torch.Tensor): atomic_nums = atomic_nums.cpu().numpy() if atomic_nums.is_cuda else atomic_nums.numpy() if isinstance(atomic_nums, list): atomic_nums = np.array(atomic_nums) if atomic_nums.ndim > 1: atomic_nums = atomic_nums.flatten() atomic_nums = np.round(atomic_nums).astype(int) return [num if num in valid_atoms else 6 for num in atomic_nums] def smiles_to_graph(smiles): mol = Chem.MolFromSmiles(smiles) if mol is None: return None Chem.SanitizeMol(mol) atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()] edge_index = [] for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_index += [(i, j), (j, i)] x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1) edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return Data(x=x, edge_index=edge_index) def augment_mol(smiles): mol = Chem.MolFromSmiles(smiles) if mol: try: mol = Chem.AddHs(mol) except: pass return Chem.MolToSmiles(mol) return smiles # ------------------------- 增强版边生成算法(改进版)------------------------- def generate_realistic_edges_improved(node_features, avg_edge_count=20, atomic_nums=None, force_min_edges=True): if node_features is None or node_features.numel() == 0: return torch.empty((2, 0), dtype=torch.long), torch.tensor([]) num_nodes = node_features.shape[0] if num_nodes < 2: return torch.empty((2, 0), dtype=torch.long), torch.tensor([]) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] max_valence = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6, 9: 1, 17: 1} # 计算节点相似度并添加额外的连接倾向 similarity = torch.matmul(node_features, node_features.transpose(0, 1)).squeeze() # 添加额外的连接偏向,确保分子整体性(避免孤立节点) connectivity_bias = torch.ones_like(similarity) - torch.eye(num_nodes, dtype=torch.long) similarity = similarity + 0.5 * connectivity_bias.float() norm_similarity = similarity / (torch.max(similarity) + 1e-8) edge_probs = norm_similarity.view(-1) num_possible_edges = num_nodes * (num_nodes - 1) num_edges = min(avg_edge_count, num_possible_edges) _, indices = torch.topk(edge_probs, k=num_edges) edge_set = set() edge_index = [] used_valence = {i: 0 for i in range(num_nodes)} bond_types = [] tried_edges = set() # 确保至少有一个连接,并优先连接苯环部分 if num_nodes >= 2 and force_min_edges: if num_nodes >= 6 and all(atomic_nums[i] == 6 for i in range(6)): # 先连接苯环形成基础结构 for i in range(6): j = (i + 1) % 6 if (i, j) not in tried_edges: edge_set.add((i, j)) edge_set.add((j, i)) edge_index.append([i, j]) edge_index.append([j, i]) used_valence[i] += 1 used_valence[j] += 1 bond_types.append(2 if i % 2 == 0 else 1) bond_types.append(2 if i % 2 == 0 else 1) tried_edges.add((i, j)) tried_edges.add((j, i)) # 连接苯环到第一个非苯环原子(确保分子整体性) if num_nodes > 6: first_non_benzene = 6 # 寻找与苯环连接最紧密的非苯环原子 max_sim = -1 best_benzene = 0 for i in range(6): sim = norm_similarity[i, first_non_benzene] if sim > max_sim: max_sim = sim best_benzene = i if (best_benzene, first_non_benzene) not in tried_edges: edge_set.add((best_benzene, first_non_benzene)) edge_set.add((first_non_benzene, best_benzene)) edge_index.append([best_benzene, first_non_benzene]) edge_index.append([first_non_benzene, best_benzene]) used_valence[best_benzene] += 1 used_valence[first_non_benzene] += 1 bond_types.append(1) bond_types.append(1) tried_edges.add((best_benzene, first_non_benzene)) tried_edges.add((first_non_benzene, best_benzene)) else: # 常规分子的最小连接(确保连通性) max_sim = -1 best_u, best_v = 0, 1 for i in range(num_nodes): for j in range(i + 1, num_nodes): if norm_similarity[i, j] > max_sim and (i, j) not in tried_edges: max_sim = norm_similarity[i, j] best_u, best_v = i, j edge_set.add((best_u, best_v)) edge_set.add((best_v, best_u)) edge_index.append([best_u, best_v]) edge_index.append([best_v, best_u]) used_valence[best_u] += 1 used_valence[best_v] += 1 bond_types.append(1) bond_types.append(1) tried_edges.add((best_u, best_v)) tried_edges.add((best_v, best_u)) # 改进边生成逻辑,确保分子连接性(优先连接未连通部分) for idx in indices: u, v = (idx // num_nodes).item(), (idx % num_nodes).item() if u == v or (u, v) in tried_edges or (v, u) in tried_edges: continue tried_edges.add((u, v)) tried_edges.add((v, u)) atom_u, atom_v = atomic_nums[u], atomic_nums[v] max_u, max_v = max_valence[atom_u], max_valence[atom_v] avail_u, avail_v = max_u - used_valence[u], max_v - used_valence[v] # 不允许两个氢原子相连 if atom_u == 1 and atom_v == 1: continue # 确保连接不会导致分子分裂(至少需要n-1条边连接n个节点) if len(edge_set) < num_nodes - 1: edge_set.add((u, v)) edge_set.add((v, u)) edge_index.append([u, v]) edge_index.append([v, u]) used_valence[u] += 1 used_valence[v] += 1 bond_types.append(1) bond_types.append(1) continue # 常规连接逻辑 if avail_u >= 1 and avail_v >= 1: edge_set.add((u, v)) edge_set.add((v, u)) edge_index.append([u, v]) edge_index.append([v, u]) used_valence[u] += 1 used_valence[v] += 1 bond_types.append(1) bond_types.append(1) continue # 尝试形成双键(增加分子稳定性) if atom_u != 1 and atom_v != 1 and avail_u >= 2 and avail_v >= 2 and random.random() < 0.3: edge_set.add((u, v)) edge_set.add((v, u)) edge_index.append([u, v]) edge_index.append([v, u]) used_valence[u] += 2 used_valence[v] += 2 bond_types.append(2) bond_types.append(2) # 确保分子连接性的最后检查 if len(edge_index) == 0 and num_nodes >= 2: edge_index = [[0, 1], [1, 0]] bond_types = [1, 1] used_valence[0], used_valence[1] = 1, 1 return torch.tensor(edge_index).t().contiguous(), torch.tensor(bond_types) @lru_cache(maxsize=128) def cached_calculate_tpsa(mol): try: return rdMolDescriptors.CalcTPSA(mol) if mol else 0 except: return 0 @lru_cache(maxsize=128) def cached_calculate_ring_penalty(mol): try: if mol: ssr = Chem.GetSymmSSSR(mol) return len(ssr) * 0.2 return 0 except: return 0 # ------------------------- Gumbel-Softmax 实现 ------------------------- def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): gumbels = -(torch.empty_like(logits).exponential_() + eps).log() gumbels = (logits + gumbels) / tau y_soft = gumbels.softmax(dim) if hard: index = y_soft.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: ret = y_soft return ret # ------------------------- 核心优化:强化化学约束(改进版)------------------------- def build_valid_mol_improved(atomic_nums, edge_index, bond_types=None): atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] bond_type_map = {1: Chem.BondType.SINGLE, 2: Chem.BondType.DOUBLE} mol = Chem.RWMol() atom_indices = [] for num in atomic_nums: atom = Chem.Atom(num) idx = mol.AddAtom(atom) atom_indices.append(idx) added_edges = set() if bond_types is not None: if isinstance(bond_types, torch.Tensor): bond_types = bond_types.cpu().numpy().tolist() if bond_types.is_cuda else bond_types.numpy().tolist() else: bond_types = [] max_valence = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6, 9: 1, 17: 1} current_valence = {i: 0 for i in range(len(atomic_nums))} if edge_index.numel() > 0: # 按重要性排序边,优先连接形成单一片段(苯环连接优先) edge_importance = [] for j in range(edge_index.shape[1]): u, v = edge_index[0, j].item(), edge_index[1, j].item() # 计算边的重要性:苯环与非苯环连接 > 苯环内部连接 > 其他连接 importance = 2.0 if (u < 6 and v >= 6) or (v < 6 and u >= 6) else 1.5 if (u < 6 and v < 6) else 1.0 edge_importance.append((importance, j)) # 按重要性降序排序 edge_order = [j for _, j in sorted(edge_importance, key=lambda x: x[0], reverse=True)] for j in edge_order: u, v = edge_index[0, j].item(), edge_index[1, j].item() if u >= len(atomic_nums) or v >= len(atomic_nums) or u == v or (u, v) in added_edges: continue atom_u = mol.GetAtomWithIdx(atom_indices[u]) atom_v = mol.GetAtomWithIdx(atom_indices[v]) max_u = max_valence[atomic_nums[u]] max_v = max_valence[atomic_nums[v]] used_u = current_valence[u] used_v = current_valence[v] remain_u = max_u - used_u remain_v = max_v - used_v if remain_u >= 1 and remain_v >= 1: bond_type = bond_type_map.get(bond_types[j] if j < len(bond_types) else 1, Chem.BondType.SINGLE) bond_order = 1 if bond_type == Chem.BondType.SINGLE else 2 if bond_order > min(remain_u, remain_v): bond_order = 1 bond_type = Chem.BondType.SINGLE mol.AddBond(atom_indices[u], atom_indices[v], bond_type) added_edges.add((u, v)) added_edges.add((v, u)) current_valence[u] += bond_order current_valence[v] += bond_order # 更稳健的价态修复逻辑(优先保持分子连接性) try: Chem.SanitizeMol(mol) return mol except: # 先尝试添加氢原子修复(保持分子完整性) try: mol = Chem.AddHs(mol) Chem.SanitizeMol(mol) return mol except: pass # 智能移除键(仅移除导致价态问题的非关键键) try: problematic_bonds = [] for bond in mol.GetBonds(): begin = bond.GetBeginAtom() end = bond.GetEndAtom() if begin.GetExplicitValence() > max_valence[begin.GetAtomicNum()] or \ end.GetExplicitValence() > max_valence[end.GetAtomicNum()]: problematic_bonds.append(bond.GetIdx()) # 按键类型和连接重要性排序(先移除双键,再移除非苯环连接) problematic_bonds.sort(key=lambda idx: ( mol.GetBondWithIdx(idx).GetBondTypeAsDouble(), -1 if (mol.GetBondWithIdx(idx).GetBeginAtomIdx() < 6 and mol.GetBondWithIdx(idx).GetEndAtomIdx() >= 6) else 1 ), reverse=True) for bond_idx in problematic_bonds: mol.RemoveBond(mol.GetBondWithIdx(bond_idx).GetBeginAtomIdx(), mol.GetBondWithIdx(bond_idx).GetEndAtomIdx()) try: Chem.SanitizeMol(mol) break except: continue Chem.SanitizeMol(mol) return mol except: # 最后尝试移除氢原子(作为最后的修复手段) try: mol = Chem.RemoveHs(mol) Chem.SanitizeMol(mol) return mol except: return None def mol_to_graph(mol): if not mol: return None try: Chem.SanitizeMol(mol) atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()] edge_index = [] for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_index += [(i, j), (j, i)] x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1) edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return Data(x=x, edge_index=edge_index) except: return None def filter_valid_mols(mols): valid_mols = [] for mol in mols: if mol is None: continue try: Chem.SanitizeMol(mol) rdMolDescriptors.CalcExactMolWt(mol) valid_mols.append(mol) except: continue return valid_mols # ------------------------- 新增:分子片段过滤工具(改进版)------------------------- def filter_single_fragment_mols_improved(mol_list): valid_mols = [] for mol in mol_list: if mol is None: continue try: # 先尝试添加氢原子以确保分子完整性 mol = Chem.AddHs(mol) Chem.SanitizeMol(mol) # 使用更鲁棒的方法检测片段数 frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) if len(frags) == 1: # 进一步检查分子重量和原子数(过滤过小分子) mol_weight = rdMolDescriptors.CalcExactMolWt(mol) if mol_weight > 50 and mol.GetNumAtoms() > 5: valid_mols.append(mol) except: continue return valid_mols # ------------------------- 新增:苯环检测工具 ------------------------- def has_benzene_ring(mol): if mol is None: return False try: benzene_smarts = Chem.MolFromSmarts("c1ccccc1") return mol.HasSubstructMatch(benzene_smarts) except Exception as e: print(f"检测苯环失败: {e}") return False def filter_benzene_mols(mols): return [mol for mol in mols if has_benzene_ring(mol)] # ------------------------- 数据集 ------------------------- class MolecularGraphDataset(torch.utils.data.Dataset): def __init__(self, path): df = pd.read_excel(path, usecols=[0, 1, 2]) df.columns = ['SMILES', 'Concentration', 'Efficiency'] df.dropna(inplace=True) df['SMILES'] = df['SMILES'].apply(augment_mol) self.graphs, self.conditions = [], [] scaler = StandardScaler() cond_scaled = scaler.fit_transform(df[['Concentration', 'Efficiency']]) self.edge_stats = {'edge_counts': []} valid_count = 0 for i, row in df.iterrows(): graph = smiles_to_graph(row['SMILES']) if graph: atomic_nums = graph.x.squeeze().tolist() atomic_nums = validate_atomic_nums(atomic_nums) graph.x = torch.tensor(atomic_nums, dtype=torch.long).view(-1, 1) self.graphs.append(graph) self.conditions.append(torch.tensor(cond_scaled[i], dtype=torch.float32)) self.edge_stats['edge_counts'].append(graph.edge_index.shape[1] // 2) valid_count += 1 print(f"成功加载 {valid_count} 个分子") self.avg_edge_count = int(np.mean(self.edge_stats['edge_counts'])) if self.edge_stats['edge_counts'] else 20 print(f"真实图平均边数: {self.avg_edge_count}") def __len__(self): return len(self.graphs) def __getitem__(self, idx): return self.graphs[idx], self.conditions[idx] # ------------------------- 增强版生成器(含苯环条件编码) ------------------------- class EnhancedGraphGenerator(nn.Module): def __init__(self, noise_dim=16, condition_dim=2, benzene_condition_dim=1, hidden_dim=128, num_atoms=15): super().__init__() self.num_atoms = num_atoms self.benzene_embedding = nn.Embedding(2, hidden_dim) # 苯环条件嵌入 # 噪声和条件处理 self.initial_transform = nn.Sequential( nn.Linear(noise_dim + condition_dim + hidden_dim, hidden_dim), nn.LeakyReLU(0.2), nn.BatchNorm1d(hidden_dim), nn.Linear(hidden_dim, hidden_dim) ) # 苯环专用生成路径 self.benzene_generator = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Linear(hidden_dim, 6 * 7) # 6个碳原子,7种可能的原子类型 ) # 分子其余部分生成路径 self.rest_generator = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Linear(hidden_dim, (num_atoms - 6) * 7) # 剩余原子 ) # 原子间连接预测 self.edge_predictor = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Linear(hidden_dim, num_atoms * num_atoms) # 原子间连接概率 ) self.tau = 0.8 self.valid_atoms = torch.tensor([1, 6, 7, 8, 16, 9, 17], dtype=torch.long) def forward(self, z, cond, benzene_condition): # 嵌入苯环条件 benzene_feat = self.benzene_embedding(benzene_condition) # 合并噪声和条件 x = torch.cat([z, cond, benzene_feat], dim=1) shared_repr = self.initial_transform(x) # 生成苯环部分(前6个原子) benzene_logits = self.benzene_generator(shared_repr).view(-1, 6, 7) benzene_probs = gumbel_softmax(benzene_logits, tau=self.tau, hard=False, dim=-1) # 强制前6个原子为碳原子(苯环) benzene_indices = torch.ones_like(benzene_probs.argmax(dim=-1)) * 1 # 碳的索引是1 benzene_nodes = self.valid_atoms[benzene_indices].float().view(-1, 6, 1) / 17.0 # 生成分子其余部分 if self.num_atoms > 6: rest_logits = self.rest_generator(shared_repr).view(-1, self.num_atoms - 6, 7) rest_probs = gumbel_softmax(rest_logits, tau=self.tau, hard=False, dim=-1) rest_indices = rest_probs.argmax(dim=-1) rest_nodes = self.valid_atoms[rest_indices].float().view(-1, self.num_atoms - 6, 1) / 17.0 # 合并苯环和其余部分 node_feats = torch.cat([benzene_nodes, rest_nodes], dim=1) else: node_feats = benzene_nodes return node_feats # ------------------------- 增强版判别器 ------------------------- class EnhancedGraphDiscriminator(nn.Module): def __init__(self, node_feat_dim=1, condition_dim=2): super().__init__() # 图卷积层 self.conv1 = GCNConv(node_feat_dim + 1, 64) self.bn1 = nn.BatchNorm1d(64) self.conv2 = GATConv(64, 32, heads=3, concat=False) self.bn2 = nn.BatchNorm1d(32) self.conv3 = GraphConv(32, 16) self.bn3 = nn.BatchNorm1d(16) # 注意力机制 self.attention = nn.Sequential( nn.Linear(16, 8), nn.Tanh(), nn.Linear(8, 1) ) # 条件处理 self.condition_processor = nn.Sequential( nn.Linear(condition_dim, 16), nn.LeakyReLU(0.2) ) # 分类器 self.classifier = nn.Sequential( nn.Linear(16 + 16, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1) ) def forward(self, data, condition): x, edge_index, batch = data.x.float(), data.edge_index, data.batch # 提取键类型特征 bond_features = torch.zeros(x.size(0), 1).to(x.device) if hasattr(data, 'bond_types') and data.bond_types is not None: bond_types = data.bond_types for i in range(edge_index.size(1)): u, v = edge_index[0, i].item(), edge_index[1, i].item() bond_type = bond_types[i] if i < len(bond_types) else 1 bond_features[u] = max(bond_features[u], bond_type) bond_features[v] = max(bond_features[v], bond_type) # 合并原子特征和键特征 x = torch.cat([x, bond_features], dim=1) # 图卷积 x = self.conv1(x, edge_index) x = self.bn1(x).relu() x = self.conv2(x, edge_index) x = self.bn2(x).relu() x = self.conv3(x, edge_index) x = self.bn3(x).relu() # 注意力池化 attn_weights = self.attention(x).squeeze() attn_weights = torch.softmax(attn_weights, dim=0) x = global_mean_pool(x * attn_weights.unsqueeze(-1), batch) # 处理条件 cond_repr = self.condition_processor(condition) # 合并图表示和条件 x = torch.cat([x, cond_repr], dim=1) # 分类 return torch.sigmoid(self.classifier(x)) # ------------------------- 训练流程(含苯环条件) ------------------------- def train_gan(dataset_path, epochs=2000, batch_size=32, lr=1e-5, device=None, generator_class=EnhancedGraphGenerator): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") dataset = MolecularGraphDataset(dataset_path) if len(dataset) == 0: print("错误:数据集为空") return None, None def collate_fn(data_list): graphs, conditions = zip(*data_list) batch_graphs = Batch.from_data_list(graphs) batch_conditions = torch.stack(conditions) return batch_graphs, batch_conditions dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) # 创建生成器和判别器 generator = generator_class(noise_dim=16, condition_dim=2).to(device) discriminator = EnhancedGraphDiscriminator().to(device) # 调整学习率比例,让生成器学习更快 g_opt = optim.Adam(generator.parameters(), lr=lr * 3, betas=(0.5, 0.999)) d_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(g_opt, 'min', patience=50, factor=0.5) d_scheduler = optim.lr_scheduler.ReduceLROnPlateau(d_opt, 'min', patience=50, factor=0.5) loss_fn = nn.BCELoss() use_amp = device.type == 'cuda' scaler = amp.GradScaler(enabled=use_amp) d_loss_history, g_loss_history = [], [] valid_mol_counts = [] fragment_counts = [] benzene_ratios = [] # 新增:记录含苯环分子比例 os.makedirs("results", exist_ok=True) os.chdir("results") for epoch in range(1, epochs + 1): generator.train() discriminator.train() total_d_loss, total_g_loss = 0.0, 0.0 for real_batch, conds in dataloader: real_batch = real_batch.to(device) conds = conds.to(device) batch_size = real_batch.num_graphs real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) # 判别器训练 with amp.autocast(enabled=use_amp): d_real = discriminator(real_batch, conds) loss_real = loss_fn(d_real, real_labels) noise = torch.randn(batch_size, 16).to(device) # 生成苯环条件(50%概率要求含苯环) benzene_condition = torch.randint(0, 2, (batch_size,), device=device) fake_nodes = generator(noise, conds, benzene_condition).detach() fake_mols = [] for i in range(fake_nodes.shape[0]): node_feats = fake_nodes[i] atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] edge_index, bond_types = generate_realistic_edges_improved( node_feats, dataset.avg_edge_count, atomic_nums ) mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types) if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0: fake_mols.append(mol_to_graph(mol)) if fake_mols: fake_batch = Batch.from_data_list(fake_mols).to(device) conds_subset = conds[:len(fake_mols)] d_fake = discriminator(fake_batch, conds_subset) loss_fake = loss_fn(d_fake, fake_labels[:len(fake_mols)]) d_loss = loss_real + loss_fake else: # 如果没有生成有效的分子,创建一个简单的分子作为占位符 mol = Chem.MolFromSmiles("CCO") fake_graph = mol_to_graph(mol) fake_batch = Batch.from_data_list([fake_graph]).to(device) d_fake = discriminator(fake_batch, conds[:1]) loss_fake = loss_fn(d_fake, fake_labels[:1]) d_loss = loss_real + loss_fake d_opt.zero_grad() scaler.scale(d_loss).backward() torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0) scaler.step(d_opt) scaler.update() total_d_loss += d_loss.item() # 生成器训练 with amp.autocast(enabled=use_amp): noise = torch.randn(batch_size, 16).to(device) benzene_condition = torch.randint(0, 2, (batch_size,), device=device) fake_nodes = generator(noise, conds, benzene_condition) fake_graphs = [] for i in range(fake_nodes.shape[0]): node_feats = fake_nodes[i] atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] edge_index, bond_types = generate_realistic_edges_improved( node_feats, dataset.avg_edge_count, atomic_nums ) fake_graphs.append(Data(x=node_feats, edge_index=edge_index, bond_types=bond_types)) valid_fake_graphs = [] for graph in fake_graphs: if graph.edge_index.numel() == 0: graph.edge_index = torch.tensor([[0, 0]], dtype=torch.long).t().to(device) graph.bond_types = torch.tensor([1], dtype=torch.long).to(device) valid_fake_graphs.append(graph) fake_batch = Batch.from_data_list(valid_fake_graphs).to(device) g_loss = loss_fn(discriminator(fake_batch, conds), real_labels) g_opt.zero_grad() scaler.scale(g_loss).backward() torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0) scaler.step(g_opt) scaler.update() total_g_loss += g_loss.item() avg_d_loss = total_d_loss / len(dataloader) avg_g_loss = total_g_loss / len(dataloader) d_loss_history.append(avg_d_loss) g_loss_history.append(avg_g_loss) g_scheduler.step(avg_g_loss) d_scheduler.step(avg_d_loss) if epoch % 10 == 0: generator.eval() discriminator.eval() with torch.no_grad(): num_samples = 50 noise = torch.randn(num_samples, 16).to(device) conds = torch.randn(num_samples, 2).to(device) benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device) # 强制生成含苯环 fake_nodes = generator(noise, conds, benzene_condition) d_real_scores, d_fake_scores = [], [] for i in range(min(num_samples, 10)): real_idx = np.random.randint(0, len(dataset)) real_graph, real_cond = dataset[real_idx] real_batch = Batch.from_data_list([real_graph]).to(device) real_cond = real_cond.unsqueeze(0).to(device) d_real = discriminator(real_batch, real_cond) d_real_scores.append(d_real.item()) node_feats = fake_nodes[i] atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] edge_index, bond_types = generate_realistic_edges_improved( node_feats, dataset.avg_edge_count, atomic_nums ) mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types) if mol: fake_graph = mol_to_graph(mol) if fake_graph: fake_batch = Batch.from_data_list([fake_graph]).to(device) fake_cond = conds[i].unsqueeze(0) d_fake = discriminator(fake_batch, fake_cond) d_fake_scores.append(d_fake.item()) if d_real_scores and d_fake_scores: print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}") print(f"D_real评分: {np.mean(d_real_scores):.4f} ± {np.std(d_real_scores):.4f}") print(f"D_fake评分: {np.mean(d_fake_scores):.4f} ± {np.std(d_fake_scores):.4f}") print(f"学习率: G={g_opt.param_groups[0]['lr']:.8f}, D={d_opt.param_groups[0]['lr']:.8f}") else: print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}") generator.train() if epoch % 100 == 0: torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pt") torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch}.pt") generator.eval() with torch.no_grad(): num_samples = 25 noise = torch.randn(num_samples, 16).to(device) conds = torch.randn(num_samples, 2).to(device) benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device) # 强制含苯环 fake_nodes = generator(noise, conds, benzene_condition) generated_mols = [] for i in range(num_samples): node_feats = fake_nodes[i] atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] edge_index, bond_types = generate_realistic_edges_improved( node_feats, dataset.avg_edge_count, atomic_nums ) mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types) if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0: try: mol_weight = rdMolDescriptors.CalcExactMolWt(mol) logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0] tpsa = cached_calculate_tpsa(mol) generated_mols.append((mol, mol_weight, logp, tpsa)) except Exception as e: print(f"计算分子描述符时出错: {e}") # 过滤无效分子 valid_mols = filter_valid_mols([mol for mol, _, _, _ in generated_mols]) # 统计含苯环分子比例 benzene_mols = filter_benzene_mols(valid_mols) benzene_ratio = len(benzene_mols) / len(valid_mols) if valid_mols else 0 benzene_ratios.append(benzene_ratio) # 统计分子片段情况(使用改进的过滤函数) single_fragment_mols = filter_single_fragment_mols_improved(valid_mols) fragment_ratio = len(single_fragment_mols) / len(valid_mols) if valid_mols else 0 fragment_counts.append(fragment_ratio) valid_mol_counts.append(len(single_fragment_mols)) print(f"Epoch {epoch}: 生成{len(generated_mols)}个分子,初步过滤后保留{len(valid_mols)}个合法分子") print(f"Epoch {epoch}: 含苯环分子比例: {len(benzene_mols)}/{len(valid_mols)} = {benzene_ratio:.2%}") print( f"Epoch {epoch}: 单一片段分子比例: {len(single_fragment_mols)}/{len(valid_mols)} = {fragment_ratio:.2%}") if single_fragment_mols: filtered_mols = [] for mol in single_fragment_mols: try: mol_weight = rdMolDescriptors.CalcExactMolWt(mol) logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0] tpsa = cached_calculate_tpsa(mol) filtered_mols.append((mol, mol_weight, logp, tpsa)) except: continue if filtered_mols: filtered_mols.sort(key=lambda x: x[1]) mols = [mol for mol, _, _, _ in filtered_mols] legends = [ f"Mol {i + 1}\nMW: {mw:.2f}\nLogP: {logp:.2f}\nTPSA: {tpsa:.2f}" for i, (_, mw, logp, tpsa) in enumerate(filtered_mols) ] img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200), legends=legends) img.save(f"generated_molecules_epoch_{epoch}.png") print(f"Epoch {epoch}: 保存{len(filtered_mols)}个单一片段分子的可视化结果") else: print(f"Epoch {epoch}: 过滤后无合法单一片段分子可显示") else: print(f"Epoch {epoch}: 未生成任何合法单一片段分子") generator.train() # 绘制损失曲线 plt.figure(figsize=(10, 5)) plt.plot(d_loss_history, label='Discriminator Loss') plt.plot(g_loss_history, label='Generator Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.title('GAN Loss Curve') plt.legend() plt.savefig('gan_loss_curve.png') plt.close() # 绘制合法分子数量曲线 plt.figure(figsize=(10, 5)) plt.plot([i * 100 for i in range(1, len(valid_mol_counts) + 1)], valid_mol_counts, label='Valid Single-Fragment Molecules') plt.xlabel('Epoch') plt.ylabel('Number of Valid Molecules') plt.title('Number of Valid Single-Fragment Molecules Generated per Epoch') plt.legend() plt.savefig('valid_molecules_curve.png') plt.close() # 绘制分子片段比例曲线 plt.figure(figsize=(10, 5)) plt.plot([i * 100 for i in range(1, len(fragment_counts) + 1)], fragment_counts, label='Single-Fragment Ratio') plt.xlabel('Epoch') plt.ylabel('Single-Fragment Molecules Ratio') plt.title('Ratio of Single-Fragment Molecules in Generated Molecules') plt.legend() plt.savefig('fragment_ratio_curve.png') plt.close() # 绘制苯环比例曲线 plt.figure(figsize=(10, 5)) plt.plot([i * 100 for i in range(1, len(benzene_ratios) + 1)], benzene_ratios, label='Benzene Ring Ratio') plt.xlabel('Epoch') plt.ylabel('Benzene Ring Molecules Ratio') plt.title('Ratio of Molecules Containing Benzene Rings') plt.legend() plt.savefig('benzene_ratio_curve.png') plt.close() return generator, discriminator # ------------------------- 参考分子处理与批量生成(改进版)------------------------- def process_reference_smiles(reference_smiles): ref_graph = smiles_to_graph(reference_smiles) if ref_graph is None: raise ValueError("参考分子SMILES转换图结构失败,请检查SMILES合法性") ref_concentration = 1.0 ref_efficiency = 0.8 ref_condition = torch.tensor([ref_concentration, ref_efficiency], dtype=torch.float32) return ref_graph, ref_condition def generate_molecules_with_reference_improved(generator, reference_smiles, num_samples=1000, noise_type="gaussian", force_benzene=True, device=None): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ref_graph, ref_condition = process_reference_smiles(reference_smiles) ref_condition_batch = ref_condition.unsqueeze(0).repeat(num_samples, 1).to(device) if noise_type.lower() == "gaussian": noise = torch.randn(num_samples, 16).to(device) elif noise_type.lower() == "uniform": noise = 2 * torch.rand(num_samples, 16).to(device) - 1 else: raise ValueError("噪声类型必须是 'gaussian' 或 'uniform'") generator.eval() generated_mols = [] # 强制生成含苯环分子 benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device) if force_benzene else \ torch.randint(0, 2, (num_samples,), device=device) with torch.no_grad(): fake_nodes = generator(noise, ref_condition_batch, benzene_condition) for i in range(num_samples): node_feats = fake_nodes[i] atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int) atomic_nums = validate_atomic_nums(atomic_nums) atomic_nums = [int(num) for num in atomic_nums] # 使用改进的边生成函数 edge_index, bond_types = generate_realistic_edges_improved( node_feats, 20, atomic_nums ) # 使用改进的分子构建函数 mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types) if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0: generated_mols.append(mol) # 过滤无效分子 valid_mols = filter_valid_mols(generated_mols) print(f"使用{noise_type}噪声,基于参考分子生成 {num_samples} 个分子,初步过滤后有效分子: {len(valid_mols)}") # 过滤含苯环的分子(如果强制要求) if force_benzene: benzene_mols = filter_benzene_mols(valid_mols) print(f"进一步过滤后,含苯环分子数量: {len(benzene_mols)}") print(f"含苯环分子比例: {len(benzene_mols) / len(valid_mols):.2%}") valid_mols = benzene_mols # 使用改进的单一片段过滤函数 single_fragment_mols = filter_single_fragment_mols_improved(valid_mols) print(f"进一步过滤后,单一片段分子数量: {len(single_fragment_mols)}") # 计算片段比例 fragment_ratio = len(single_fragment_mols) / len(valid_mols) if valid_mols else 0 print(f"单一片段分子比例: {fragment_ratio:.2%}") return single_fragment_mols # ------------------------- 分子属性计算与保存 ------------------------- def calculate_molecular_properties(mols): properties = [] for mol in mols: try: mol_weight = rdMolDescriptors.CalcExactMolWt(mol) logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0] tpsa = rdMolDescriptors.CalcTPSA(mol) hba = rdMolDescriptors.CalcNumHBA(mol) hbd = rdMolDescriptors.CalcNumHBD(mol) rot_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol) n_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 7) o_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 8) s_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 16) p_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 15) frags = Chem.GetMolFrags(mol, asMols=True) fragment_count = len(frags) # 检测苯环 has_benzene = has_benzene_ring(mol) properties.append({ 'SMILES': Chem.MolToSmiles(mol), 'MW': mol_weight, 'LogP': logp, 'TPSA': tpsa, 'HBA': hba, 'HBD': hbd, 'RotBonds': rot_bonds, 'N_count': n_count, 'O_count': o_count, 'S_count': s_count, 'P_count': p_count, 'FragmentCount': fragment_count, 'HasBenzene': has_benzene }) except Exception as e: print(f"计算分子属性时出错: {e}") continue return properties def save_molecules(mols, prefix="generated", noise_type="gaussian"): if not mols: print("没有分子可保存") return subdir = f"{prefix}_{noise_type}" os.makedirs(subdir, exist_ok=True) if len(mols) <= 100: legends = [f"Mol {i + 1}" for i in range(len(mols))] img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(300, 300), legends=legends) img.save(f"{subdir}/{prefix}_{noise_type}_molecules.png") properties = calculate_molecular_properties(mols) df = pd.DataFrame(properties) df.to_csv(f"{subdir}/{prefix}_{noise_type}_properties.csv", index=False) with open(f"{subdir}/{prefix}_{noise_type}_smiles.smi", "w") as f: for props in properties: f.write(f"{props['SMILES']}\n") print(f"已保存 {len(mols)} 个单一片段分子到目录: {subdir}") # ------------------------- 新增:生成结果分析工具 ------------------------- def analyze_generated_molecules(mols_gaussian, mols_uniform): print("\n===== 生成分子分析报告 =====") count_gaussian = len(mols_gaussian) count_uniform = len(mols_uniform) print(f"高斯噪声生成单一片段分子: {count_gaussian}") print(f"均匀噪声生成单一片段分子: {count_uniform}") def calculate_avg_properties(mols): if not mols: return {} props = calculate_molecular_properties(mols) avg_props = {} for key in props[0].keys(): if key != 'SMILES': avg_props[key] = sum(p[key] for p in props) / len(props) return avg_props avg_gaussian = calculate_avg_properties(mols_gaussian) avg_uniform = calculate_avg_properties(mols_uniform) if avg_gaussian and avg_uniform: print("\n高斯噪声生成分子的平均属性:") for key, value in avg_gaussian.items(): print(f" {key}: {value:.2f}") print("\n均匀噪声生成分子的平均属性:") for key, value in avg_uniform.items(): print(f" {key}: {value:.2f}") print("\n属性差异 (均匀 - 高斯):") for key in avg_gaussian.keys(): if key != 'SMILES': diff = avg_uniform[key] - avg_gaussian[key] print(f" {key}: {diff:+.2f}") if mols_gaussian and mols_uniform: properties = ['MW', 'LogP', 'TPSA', 'HBA', 'HBD', 'RotBonds'] plt.figure(figsize=(15, 10)) for i, prop in enumerate(properties, 1): plt.subplot(2, 3, i) gaussian_vals = [p[prop] for p in calculate_molecular_properties(mols_gaussian)] uniform_vals = [p[prop] for p in calculate_molecular_properties(mols_uniform)] plt.hist(gaussian_vals, bins=20, alpha=0.5, label='Gaussian') plt.hist(uniform_vals, bins=20, alpha=0.5, label='Uniform') plt.title(f'{prop} Distribution') plt.xlabel(prop) plt.ylabel('Frequency') plt.legend() plt.tight_layout() plt.savefig('molecular_property_distribution.png') plt.close() print("\n分子属性分布图已保存为 'molecular_property_distribution.png'") # ------------------------- 新增:分子可视化工具 ------------------------- def visualize_molecules_grid(molecules, num_per_row=5, filename="molecules_grid.png", legends=None): """创建高质量的分子网格可视化图""" if not molecules: print("没有分子可可视化") return if legends is None: legends = [f"Molecule {i + 1}" for i in range(len(molecules))] try: img = Draw.MolsToGridImage( molecules, molsPerRow=num_per_row, subImgSize=(300, 300), legends=legends, useSVG=False, highlightAtomLists=None, highlightBondLists=None ) img.save(filename) print(f"分子网格图已保存至: {filename}") return img except Exception as e: print(f"生成分子网格图时出错: {e}") return None # ------------------------- 主函数 ------------------------- def main(): print("=" * 80) print("基于合法SMILES的分子生成GAN系统(改进版)") print("=" * 80) dataset_path = "D:\python\pythonProject1\DATA\Inhibitor1368_data.xlsx" if not os.path.exists(dataset_path): print(f"错误:数据集文件 '{dataset_path}' 不存在!") exit(1) print(f"开始加载数据集: {dataset_path}") print("=" * 80) print("开始训练分子生成GAN...") generator, discriminator = train_gan( dataset_path=dataset_path, epochs=200, batch_size=32, lr=1e-5, ) print("=" * 80) # 设置参考缓蚀剂分子(确保是单一片段含苯环的有效分子) reference_smiles = "NCCNCc1ccc(O)c2ncccc12" if generator: print("训练完成!模型和生成结果已保存") print("生成的分子可视化结果在'results'目录") print("损失曲线已保存为'gan_loss_curve.png'") print("合法分子数量曲线已保存为'valid_molecules_curve.png'") print("分子片段比例曲线已保存为'fragment_ratio_curve.png'") print("苯环比例曲线已保存为'benzene_ratio_curve.png'") # 基于参考分子生成1000个新分子(高斯噪声,强制含苯环) print("\n开始基于参考分子生成新分子(高斯噪声,强制含苯环)...") gaussian_mols = generate_molecules_with_reference_improved( generator, reference_smiles, num_samples=1000, noise_type="gaussian", force_benzene=True # 强制生成含苯环分子 ) save_molecules(gaussian_mols, prefix="ref_based", noise_type="gaussian_benzene") # 可视化最佳分子 if gaussian_mols: # 计算每个分子的QED分数 qed_scores = [] for mol in gaussian_mols: try: qed_scores.append(rdMolDescriptors.CalcQED(mol)) except: qed_scores.append(0) # 按QED分数排序 sorted_indices = sorted(range(len(qed_scores)), key=lambda i: qed_scores[i], reverse=True) top_molecules = [gaussian_mols[i] for i in sorted_indices[:25]] top_legends = [f"QED: {qed_scores[i]:.3f}" for i in sorted_indices[:25]] # 可视化 visualize_molecules_grid( top_molecules, num_per_row=5, filename="top_molecules_by_qed.png", legends=top_legends ) print("已生成并保存最佳分子的可视化结果") # 基于参考分子生成1000个新分子(均匀噪声,强制含苯环) print("\n开始基于参考分子生成新分子(均匀噪声,强制含苯环)...") uniform_mols = generate_molecules_with_reference_improved( generator, reference_smiles, num_samples=1000, noise_type="uniform", force_benzene=True # 强制生成含苯环分子 ) save_molecules(uniform_mols, prefix="ref_based", noise_type="uniform_benzene") # 分析两种噪声生成的分子差异 if gaussian_mols and uniform_mols: analyze_generated_molecules(gaussian_mols, uniform_mols) else: print("训练过程中未生成合法分子,请检查数据") print("=" * 80) if __name__ == "__main__": main()

location_target_publisher = rospy.Publisher('/move_base', PoseStamped, queue_size=1)#目的地话题发布 target_pose = PoseStamped() target_pose.header.frame_id = "map" #加血 if (game_status == 4) and (require_add_HP == 1):#补血区 target_pose.pose.position.x = -2.50 target_pose.pose.position.y = 3.85 target_pose.pose.position.z = 0 #进攻掩护步兵占领中心增益点 elif (game_status == 4) and (robot_HP > 300) and (attack_cnt > 0):#中心增益点 target_pose.pose.position.x = 6.23 target_pose.pose.position.y = -2.66 target_pose.pose.position.z = 0 elif (game_status == 4) and (robot_HP > 300) and (attack_cnt == 0): #没到进攻时间在中心增益点附近 if remain_time > 220: target_pose.pose.position.x = 6.6 target_pose.pose.position.y = -2.45 #退让 elif remain_time > 150: target_pose.pose.position.x = 3.0 target_pose.pose.position.y = 0.3 #比赛最后时段 else: target_pose.pose.position.x = 6.23 target_pose.pose.position.y = -2.66 target_pose.pose.position.z = 0 #否则待在原地 else:#原地 target_pose.pose.position.x = 0 target_pose.pose.position.y = 0 target_pose.pose.position.z = 0 target_quad = quaternion_from_euler(0, 0, target_yaw) target_pose.pose.orientation.x = target_quad[0] target_pose.pose.orientation.y = target_quad[1] target_pose.pose.orientation.z = target_quad[2] target_pose.pose.orientation.w = target_quad[3] location_target_publisher.publish(target_pose) if target_yaw > 31.415: target_yaw = -31.415 # print('OK!') 通过上述代码我怎么获得是否到达目标点的状态

帮我注释下面代码class Ball: """ 针头 """ def __init__(self, angle): self.x = x0 self.y = y0 + length self.center = (self.x, self.y) self.radius = 12 self.angle = angle ball_group.append(self) def draw(self, surface): pygame.draw.line(surface, WHITE, shaft, self.center, 2) pygame.draw.circle(surface, WHITE, self.center, 12) def move(self, speed): """ 围绕转轴做圆周运动 :param speed: 转动的角速度 :return: """ if self.angle < 2 * math.pi: self.angle += speed else: self.angle = self.angle - 2 * math.pi self.x = x0 - length * math.sin(self.angle) self.y = y0 + length * math.cos(self.angle) self.center = (self.x, self.y) def check_collide(new_ball): for ball in ball_group: distance = math.sqrt(abs(ball.x - new_ball.x) ** 2 + abs(ball.y - new_ball.y) ** 2) if ball is not new_ball and distance <= new_ball.radius * 2: return True return False def game_init(): global stage, gaming, remain_ball, stage_pass,score if stage == len(all_stage): stage = 1 elif stage_pass: stage += 1 ball_group.clear() for a in all_stage[stage - 1]: b = Ball(a) remain_ball = ball_num[stage - 1] score=0 pygame.time.delay(200) gaming = True stage_pass = False button = Button('重新开始', color=(220, 0, 0)) button.rect.center = shaft button.click_connect(game_init) def restart(): for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit() button.get_click(event) def game_stage(): global remain_ball, stage, gaming, stage_pass,score if remain_ball == 0 and gaming: stage_pass = True gaming = False

import data LOT = data.LOT Daily_monitor_lot = data.Daily_monitor_lot Operation = data.Operation Machine_of_operation = data.Machine_of_operation Last_Operation_Machine_inhibit_due_time = data.Last_Operation_Machine_inhibit_due_time Operation_Machine_K = data.Operation_Machine_K Lot_wafer_QTY = data.Lot_wafer_QTY Operation_Machine_Capability = data.Operation_Machine_Capability Lot_release_time = data.Lot_release_time Machine_release_time = data.Machine_release_time Operation_Machine_Lot_transfer_time = data.Operation_Machine_Lot_transfer_time Lot_release_Remain_QT = data.Lot_release_Remain_QT Lot_release_RQT_w = data.Lot_release_RQT_w Loop_QTime = data.Loop_QTime Loop_qt = data.Loop_qt Last_Operation_Machine_SD_PM_Start_time = data.Last_Operation_Machine_SD_PM_Start_time Last_Operation_Machine_SD_PM_End_time = data.Last_Operation_Machine_SD_PM_End_time N = 1e6 # 无限大正数 # lot的最后一道工序映射表 last_Operation = {l: Operation[l][-1] for l in LOT} print('last_Operation=', last_Operation) # 最后一道工序对应机台映射表(最后一步工序包含的machineList) last_Operation_Machine = {op: list(Machine_of_operation[op].keys())for op in set(last_Operation.values())} print('last_Operation_Machine=', last_Operation_Machine) # ================== 模型构建 ================== solver = pywraplp.Solver.CreateSolver('SCIP') # 决策变量(1): Lot l 的第o道工序是否在machine m上加工,0-1变量 from collections import defaultdict x_lom = defaultdict(dict) # x_lom: 1如果lot l的工序o在机器m上加工 for l in LOT: if l not in Operation: continue for o in Operation[l]: machines = Machine_of_operation.get(o, []) for m in machines: if Operation_Machine_K.get((o, m), 0) == 1 : key = (l, o, m) if key not in x_lom: x_lom[key] = solver.BoolVar(f'x_{l}_{o}_{m}') #print(f'x_{l}_{o}_{m}=', x_lom.values()) # 决策变量(2): Lot l 在第o step的start time s_lom = defaultdict(dict) # x_lom: 1如果lot l的工序o在机器m上加工 for l in LOT: if l not in Operation: continue for o in Operation[l]: machines = Machine_of_operation.get(o, []) for m in machines: key = (l, o, m) if key not in s_lom: s_lom[key] = solver.NumVar(0, solver.infinity(), f's_{l}_{o}_{m}') 建立一个简单的ortools求解器求解的柔性车间调度问题

大家在看

recommend-type

《极品家丁(七改版)》(珍藏七改加料无雷精校全本)(1).zip

《极品家丁(七改版)》(珍藏七改加料无雷精校全本)(1).zip
recommend-type

密码::unlocked::sparkles::locked:创新,方便,安全的加密应用程序

隐身者 创新,方便,安全的加密应用程序。 加密无限位。 只记得一点。 Crypter是一款跨平台的加密应用程序,它使加密和解密变得很方便,同时仍然保持强大的安全性。 它解决了当今大多数安全系统中最弱的链接之一-弱密码。 它简化了安全密码的生成和管理,并且只需要记住一个位-MasterPass。 是一个加密应用程序,可以解密和加密包括文件和文件夹在内的任意数据。 该版本已发布,并针对macOS(OSX),Linux(适用于所有通过发行的发行版)和Windows(32和64位)进行了全面测试。 所有核心模块(提供核心功能的模块)都经过了全面测试。 会将MasterPass保存在操作系统的钥匙串中,因此您不必在每次打开应用程序时都输入它。 为了帮助加快开发速度,请发送PR剩下的内容做 如果您有任何建议,请打开一个问题,并通过PR进行改进! 还要签出 ( )一个分散的端到端加密消息传递应用程序。 链接到此自述文件: : 内容 安装 适用于所有主要平台的所有预构建二进制文件都可以在。 Crypter也适用于macOS的 。 因此,要安装它,只需在终端中运行以下命令:
recommend-type

HkAndroidSDK.zip

助于Android开发视频监控功能,根据ip地址可以远程操控,控制向左,向右,向下,向上以及转动摄像头,也可以放大和缩小
recommend-type

matlab的欧拉方法代码-BEM_flow_simulation:计算流体力学:使用边界元方法模拟障碍物周围/附近的流动

matlab的欧拉方法代码BEM_flow_simulation MATLAB上的计算流体力学: 目的是使用边界元素方法模拟任何障碍物附近或周围的任何形式的流动 使用BEM绕圆柱障碍物和接近均匀战争的潜在流动 非粘性势流的假设适用于导航斯托克斯方程(Euler方程),使用边界元方法,该代码模拟了在均匀垂直壁附近的尺寸稳定的圆柱障碍物周围的流动。 该系统不受其他方向的限制。 该代码是流体力学硕士1实习的主题,并且作为大型项目的第一块砖,该项目用于模拟复杂非均匀障碍物周围的粘性流动,因此可以自由继续。 类“ pot_flow_class”模拟垂直于垂直壁(两个障碍物之间的距离为H)附近圆柱2D障碍物(无量纲半径r = 1)附近的该势流。 流速为U = 1(无量纲)。 使用边界元素方法的第二层。 这样的流动的精确解决方案的代码允许验证无垂直壁模拟。
recommend-type

基于YOLO网络的行驶车辆目标检测matlab仿真+操作视频

1.领域:matlab,YOLO网络的行驶车辆目标检测算法 2.内容:基于YOLO网络的行驶车辆目标检测matlab仿真+操作视频 3.用处:用于YOLO网络的行驶车辆目标检测算法编程学习 4.指向人群:本硕博等教研学习使用 5.运行注意事项: 使用matlab2021a或者更高版本测试,运行里面的Runme_.m文件,不要直接运行子函数文件。运行时注意matlab左侧的当前文件夹窗口必须是当前工程所在路径。 具体可观看提供的操作录像视频跟着操作。

最新推荐

recommend-type

C#类库封装:简化SDK调用实现多功能集成,构建地磅无人值守系统

内容概要:本文介绍了利用C#类库封装多个硬件设备的SDK接口,实现一系列复杂功能的一键式调用。具体功能包括身份证信息读取、人证识别、车牌识别(支持臻识和海康摄像头)、LED显示屏文字输出、称重数据读取、二维码扫描以及语音播报。所有功能均被封装为简单的API,极大降低了开发者的工作量和技术门槛。文中详细展示了各个功能的具体实现方式及其应用场景,如身份证读取、人证核验、车牌识别等,并最终将这些功能整合到一起,形成了一套完整的地磅称重无人值守系统解决方案。 适合人群:具有一定C#编程经验的技术人员,尤其是需要快速集成多种硬件设备SDK的应用开发者。 使用场景及目标:适用于需要高效集成多种硬件设备SDK的项目,特别是那些涉及身份验证、车辆管理、物流仓储等领域的企业级应用。通过使用这些封装好的API,可以大大缩短开发周期,降低维护成本,提高系统的稳定性和易用性。 其他说明:虽然封装后的API极大地简化了开发流程,但对于一些特殊的业务需求,仍然可能需要深入研究底层SDK。此外,在实际部署过程中,还需考虑网络环境、硬件兼容性等因素的影响。
recommend-type

基于STM32F1的BLDC无刷直流电机与PMSM永磁同步电机源码解析:传感器与无传感器驱动详解

基于STM32F1的BLDC无刷直流电机和PMSM永磁同步电机的驱动实现方法,涵盖了有传感器和无传感两种驱动方式。对于BLDC电机,有传感器部分采用霍尔传感器进行六步换相,无传感部分则利用反电动势过零点检测实现换相。对于PMSM电机,有传感器部分包括霍尔传感器和编码器的方式,无传感部分则采用了滑模观测器进行矢量控制(FOC)。文中不仅提供了详细的代码片段,还分享了许多调试经验和技巧。 适合人群:具有一定嵌入式系统和电机控制基础知识的研发人员和技术爱好者。 使用场景及目标:适用于需要深入了解和实现BLDC和PMSM电机驱动的开发者,帮助他们掌握不同传感器条件下的电机控制技术和优化方法。 其他说明:文章强调了实际调试过程中可能遇到的问题及其解决方案,如霍尔传感器的中断触发换相、反电动势过零点检测的采样时机、滑模观测器的参数调整以及编码器的ABZ解码等。
recommend-type

基于Java的跨平台图像处理软件ImageJ:多功能图像编辑与分析工具

内容概要:本文介绍了基于Java的图像处理软件ImageJ,详细阐述了它的跨平台特性、多线程处理能力及其丰富的图像处理功能。ImageJ由美国国立卫生研究院开发,能够在多种操作系统上运行,包括Windows、Mac OS、Linux等。它支持多种图像格式,如TIFF、PNG、GIF、JPEG、BMP、DICOM、FITS等,并提供图像栈功能,允许多个图像在同一窗口中进行并行处理。此外,ImageJ还提供了诸如缩放、旋转、扭曲、平滑处理等基本操作,以及区域和像素统计、间距、角度计算等高级功能。这些特性使ImageJ成为科研、医学、生物等多个领域的理想选择。 适合人群:需要进行图像处理的专业人士,如科研人员、医生、生物学家,以及对图像处理感兴趣的普通用户。 使用场景及目标:适用于需要高效处理大量图像数据的场合,特别是在科研、医学、生物学等领域。用户可以通过ImageJ进行图像的编辑、分析、处理和保存,提高工作效率。 其他说明:ImageJ不仅功能强大,而且操作简单,用户无需安装额外的运行环境即可直接使用。其基于Java的开发方式确保了不同操作系统之间的兼容性和一致性。
recommend-type

MATLAB语音识别系统:基于GUI的数字0-9识别及深度学习模型应用 · GUI v1.2

内容概要:本文介绍了一款基于MATLAB的语音识别系统,主要功能是识别数字0到9。该系统采用图形用户界面(GUI),方便用户操作,并配有详尽的代码注释和开发报告。文中详细描述了系统的各个组成部分,包括音频采集、信号处理、特征提取、模型训练和预测等关键环节。此外,还讨论了MATLAB在此项目中的优势及其面临的挑战,如提高识别率和处理背景噪音等问题。最后,通过对各模块的工作原理和技术细节的总结,为未来的研究和发展提供了宝贵的参考资料。 适合人群:对语音识别技术和MATLAB感兴趣的初学者、学生或研究人员。 使用场景及目标:适用于希望深入了解语音识别技术原理的人群,特别是希望通过实际案例掌握MATLAB编程技巧的学习者。目标是在实践中学习如何构建简单的语音识别应用程序。 其他说明:该程序需要MATLAB 2019b及以上版本才能正常运行,建议使用者确保软件环境符合要求。
recommend-type

Teleport Pro教程:轻松复制网站内容

标题中提到的“复制别人网站的软件”指向的是一种能够下载整个网站或者网站的特定部分,然后在本地或者另一个服务器上重建该网站的技术或工具。这类软件通常被称作网站克隆工具或者网站镜像工具。 描述中提到了一个具体的教程网址,并提到了“天天给力信誉店”,这可能意味着有相关的教程或资源可以在这个网店中获取。但是这里并没有提供实际的教程内容,仅给出了网店的链接。需要注意的是,根据互联网法律法规,复制他人网站内容并用于自己的商业目的可能构成侵权,因此在此类工具的使用中需要谨慎,并确保遵守相关法律法规。 标签“复制 别人 网站 软件”明确指出了这个工具的主要功能,即复制他人网站的软件。 文件名称列表中列出了“Teleport Pro”,这是一款具体的网站下载工具。Teleport Pro是由Tennyson Maxwell公司开发的网站镜像工具,允许用户下载一个网站的本地副本,包括HTML页面、图片和其他资源文件。用户可以通过指定开始的URL,并设置各种选项来决定下载网站的哪些部分。该工具能够帮助开发者、设计师或内容分析人员在没有互联网连接的情况下对网站进行离线浏览和分析。 从知识点的角度来看,Teleport Pro作为一个网站克隆工具,具备以下功能和知识点: 1. 网站下载:Teleport Pro可以下载整个网站或特定网页。用户可以设定下载的深度,例如仅下载首页及其链接的页面,或者下载所有可访问的页面。 2. 断点续传:如果在下载过程中发生中断,Teleport Pro可以从中断的地方继续下载,无需重新开始。 3. 过滤器设置:用户可以根据特定的规则过滤下载内容,如排除某些文件类型或域名。 4. 网站结构分析:Teleport Pro可以分析网站的链接结构,并允许用户查看网站的结构图。 5. 自定义下载:用户可以自定义下载任务,例如仅下载图片、视频或其他特定类型的文件。 6. 多任务处理:Teleport Pro支持多线程下载,用户可以同时启动多个下载任务来提高效率。 7. 编辑和管理下载内容:Teleport Pro具备编辑网站镜像的能力,并可以查看、修改下载的文件。 8. 离线浏览:下载的网站可以在离线状态下浏览,这对于需要测试网站在不同环境下的表现的情况十分有用。 9. 备份功能:Teleport Pro可以用来备份网站,确保重要数据的安全。 在实际使用此类工具时,需要注意以下几点: - 著作权法:复制网站内容可能侵犯原作者的版权,因此在使用此类工具时,必须确保有合法权利去下载和使用目标网站的内容。 - 服务条款:许多网站的服务条款明确禁止未经授权的网站克隆。因此,在使用此类软件之前,应当仔细阅读并遵守目标网站的服务条款。 - 数据隐私:下载含有个人数据的网站可能触及隐私保护法律,特别是在欧洲通用数据保护条例(GDPR)等法规的环境下。 - 网络安全:随意下载网站可能导致恶意软件和病毒的传播,用户应当使用最新的反病毒软件,并在安全的环境中进行操作。 综上所述,虽然“复制别人网站的软件”在技术上是可行的,且可以用于多种正当目的,但在使用此类工具时,必须遵守法律法规和道德规范,不可侵犯他人的合法权益。
recommend-type

【跨平台开发者的必读】:解决Qt5Widgetsd.lib目标计算机类型冲突终极指南

# 1. 跨平台开发与Qt框架简介 跨平台开发已经成为现代软件开发中不可或缺的一部分,它使得开发者能够构建出能够在多种操作系统上运行的应用程序,从而触及更广泛的用户群体。在这其中,Qt框架作为一套成熟的跨平台应用程序和用户界面框架,扮演着关键角色。Qt不仅仅是一个跨平台的工具集,它的设计哲学还包含了模块化和可扩展性,使其能够适应各种规模和复杂度的项目。本章节我们将简要介绍
recommend-type

普通RNN结构和特点

<think>首先,用户的问题是:“我想了解普通RNN的结构及其特点请问RNN循环神经网络结构特点工作原理”。这涉及到RNN的结构、特点和工作原理。根据系统级指令:-所有行内数学表达式必须使用$...$格式。-独立公式必须使用$$...$$格式并单独成段。-LaTeX语法正确。-使用中文回答。-生成相关问题。-回答中引用的段落末尾自然地添加引用标识。用户可见层指令:-回答结构清晰,帮助用户逐步解决问题。-保证回答真实可靠。参考站内引用:-引用[1]:关于RNN的基本介绍,为什么需要RNN。-引用[2]:关于RNN的工作原理、结构图,以及与其他网络的比较。用户上一次的问题和我的回答:用户是第一次
recommend-type

探讨通用数据连接池的核心机制与应用

根据给定的信息,我们能够推断出讨论的主题是“通用数据连接池”,这是一个在软件开发和数据库管理中经常用到的重要概念。在这个主题下,我们可以详细阐述以下几个知识点: 1. **连接池的定义**: 连接池是一种用于管理数据库连接的技术,通过维护一定数量的数据库连接,使得连接的创建和销毁操作更加高效。开发者可以在应用程序启动时预先创建一定数量的连接,并将它们保存在一个池中,当需要数据库连接时,可以直接从池中获取,从而降低数据库连接的开销。 2. **通用数据连接池的概念**: 当提到“通用数据连接池”时,它意味着这种连接池不仅支持单一类型的数据库(如MySQL、Oracle等),而且能够适应多种不同数据库系统。设计一个通用的数据连接池通常需要抽象出一套通用的接口和协议,使得连接池可以兼容不同的数据库驱动和连接方式。 3. **连接池的优点**: - **提升性能**:由于数据库连接创建是一个耗时的操作,连接池能够减少应用程序建立新连接的时间,从而提高性能。 - **资源复用**:数据库连接是昂贵的资源,通过连接池,可以最大化现有连接的使用,避免了连接频繁创建和销毁导致的资源浪费。 - **控制并发连接数**:连接池可以限制对数据库的并发访问,防止过载,确保数据库系统的稳定运行。 4. **连接池的关键参数**: - **最大连接数**:池中能够创建的最大连接数。 - **最小空闲连接数**:池中保持的最小空闲连接数,以应对突发的连接请求。 - **连接超时时间**:连接在池中保持空闲的最大时间。 - **事务处理**:连接池需要能够管理不同事务的上下文,保证事务的正确执行。 5. **实现通用数据连接池的挑战**: 实现一个通用的连接池需要考虑到不同数据库的连接协议和操作差异。例如,不同的数据库可能有不同的SQL方言、认证机制、连接属性设置等。因此,通用连接池需要能够提供足够的灵活性,允许用户配置特定数据库的参数。 6. **数据连接池的应用场景**: - **Web应用**:在Web应用中,为了处理大量的用户请求,数据库连接池可以保证数据库连接的快速复用。 - **批处理应用**:在需要大量读写数据库的批处理作业中,连接池有助于提高整体作业的效率。 - **微服务架构**:在微服务架构中,每个服务可能都需要与数据库进行交互,通用连接池能够帮助简化服务的数据库连接管理。 7. **常见的通用数据连接池技术**: - **Apache DBCP**:Apache的一个Java数据库连接池库。 - **C3P0**:一个提供数据库连接池和控制工具的开源Java框架。 - **HikariCP**:目前性能最好的开源Java数据库连接池之一。 - **BoneCP**:一个高性能的开源Java数据库连接池。 - **Druid**:阿里巴巴开源的一个数据库连接池,提供了对性能监控的高级特性。 8. **连接池的管理与监控**: 为了保证连接池的稳定运行,开发者需要对连接池的状态进行监控,并对其进行适当的管理。监控指标可能包括当前活动的连接数、空闲的连接数、等待获取连接的请求队列长度等。一些连接池提供了监控工具或与监控系统集成的能力。 9. **连接池的配置和优化**: 连接池的性能与连接池的配置密切相关。需要根据实际的应用负载和数据库性能来调整连接池的参数。例如,在高并发的场景下,可能需要增加连接池中连接的数量。另外,适当的线程池策略也可以帮助连接池更好地服务于多线程环境。 10. **连接池的应用案例**: 一个典型的案例是电商平台在大型促销活动期间,用户访问量激增,此时通用数据连接池能够保证数据库操作的快速响应,减少因数据库连接问题导致的系统瓶颈。 总结来说,通用数据连接池是现代软件架构中的重要组件,它通过提供高效的数据库连接管理,增强了软件系统的性能和稳定性。了解和掌握连接池的原理及实践,对于任何涉及数据库交互的应用开发都至关重要。在实现和应用连接池时,需要关注其设计的通用性、配置的合理性以及管理的有效性,确保在不同的应用场景下都能发挥出最大的效能。
recommend-type

【LabVIEW网络通讯终极指南】:7个技巧提升UDP性能和安全性

# 摘要 本文系统介绍了LabVIEW在网络通讯中的应用,尤其是针对UDP协议的研究与优化。首先,阐述了UDP的原理、特点及其在LabVIEW中的基础应用。随后,本文深入探讨了通过调整数据包大小、实现并发通信及优化缓冲区管理等技巧来优化UDP性能的LabVIEW方法。接着,文章聚焦于提升UDP通信安全性,介绍了加密技术和认证授权机制在LabVIEW中的实现,以及防御网络攻击的策略。最后,通过具体案例展示了LabVIEW在实时数据采集和远程控制系统中的高级应用,并展望了LabVIEW与UDP通讯技术的未来发展趋势及新兴技术的影响。 # 关键字 LabVIEW;UDP网络通讯;性能优化;安全性;
recommend-type

简要介绍cnn卷积神经网络

<think>首先,用户的问题是:“我想了解卷积神经网络的基本概念和原理 请问CNN卷积神经网络简介及其工作原理”。所以,我需要基于提供的引用材料,构建一个清晰的回答。 引用材料总结: - [^1]: CNN的核心思想是局部感受野、权值共享和时间或空间亚采样,提供位移、尺度、形变不变性。三大特色:局部感知、权重共享和多卷积核。 - [^2]: CNN是一种前馈神经网络,由卷积层和池化层组成,特别在图像处理方面出色。与传统多层神经网络相比,CNN加入了卷积层和池化层,使特征学习更有效。 - [^3]: CNN与全连接神经网络的区别:至少有一个卷积层提取特征;神经元局部连接和权值共享,减少参数数