PyTorch Containers: nn.Sequential, ModuleList, and ParameterList Source Code Walkthrough

This post walks through PyTorch's container modules, both their usage and their source code, covering nn.Sequential, nn.ModuleList, nn.ModuleDict and friends, and shows how to build neural networks with them.

nn.Sequential

Source location: ./torch/nn/modules/container.py

# Find where torch is installed

>>> import torch
>>> torch.__file__

# Classes defined in container.py
# (all of them inherit from nn.Module)

# class Sequential(Module)

# class ModuleList(Module)
# class ModuleDict(Module)

# class ParameterList(Module)
# class ParameterDict(Module)

1 nn.Sequential: Usage and Source Code

Its main methods are covered one by one below.

1.1 The initializer __init__

__init__ usage:

import torch.nn as nn
from collections import OrderedDict

# Usage 1: pass an OrderedDict
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())
]))

# Usage 2: pass modules directly
model = nn.Sequential(
    # 2-D convolution
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)
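Either way, the resulting model can be called on a dummy tensor to sanity-check the layer stack. A minimal sketch; the 28x28 input size is an arbitrary choice for illustration:

import torch

x = torch.randn(1, 1, 28, 28)   # hypothetical input: batch 1, 1 channel, 28x28
out = model(x)                  # runs through conv -> relu -> conv -> relu
print(out.shape)                # torch.Size([1, 64, 20, 20]): each 5x5 conv trims 4 pixels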

Source:

# An ordered container
class Sequential(Module):

    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        super(Sequential, self).__init__()
        # Case 1: a single OrderedDict argument -- iterate over it (keys included) and add each module
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        # Case 2: otherwise add the modules in order (keys default to '0', '1', '2', ...)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

1.2 __getitem__, __setitem__, __delitem__

__getitem__
__setitem__
__delitem__

Usage:

# __getitem__
# 1. Built from an OrderedDict: fetch submodules by key (exposed as attributes)
>>> model
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
>>> model.conv1
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
>>> model.relu1
ReLU()
>>> 

# 2. Built by passing modules directly: fetch by integer index
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model[0]
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
>>> model[1]
ReLU()
>>> 

Source:

# Integer indices are delegated to _get_item_by_idx
def __getitem__(self, idx) -> Union['Sequential', T]:
    # a slice returns a new Sequential built from the sliced OrderedDict
    if isinstance(idx, slice):
        return self.__class__(OrderedDict(list(self._modules.items())[idx]))
    # an integer returns the idx-th module
    else:
        return self._get_item_by_idx(self._modules.values(), idx)

# Helper used by the indexing methods
def _get_item_by_idx(self, iterator, idx) -> T:
    """Get the idx-th item of the iterator"""
    size = len(self)
    idx = operator.index(idx)
    if not -size <= idx < size:
        raise IndexError('index {} is out of range'.format(idx))
    idx %= size
    return next(islice(iterator, idx, None))
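Because _get_item_by_idx normalizes the index with operator.index and a modulo, negative indices work too, and the slice branch of __getitem__ returns a new Sequential that keeps the original keys. A quick check against the four-layer model shown above:

>>> model[-1]          # negative index: -1 % 4 == 3
ReLU()
>>> model[1:3]         # slice: a new Sequential, original keys preserved
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
)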
# __setitem__
# Replaces a module that already exists
# The following swaps module 0's kernel size from 5x5 to 3x3
>>> model                       
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model[0] = nn.Conv2d(1,20,3)
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)

# Source
def __setitem__(self, idx: int, module: Module) -> None:
    key: str = self._get_item_by_idx(self._modules.keys(), idx)
    return setattr(self, key, module)
# __delitem__ removes an existing module
# The following deletes module 0
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> del model[0]
>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)

# Source
def __delitem__(self, idx: Union[slice, int]) -> None:
    # a slice deletes several keys at once
    if isinstance(idx, slice):
        for key in list(self._modules.keys())[idx]:
            delattr(self, key)
    else:
        key = self._get_item_by_idx(self._modules.keys(), idx)
        delattr(self, key)
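Note that in the version shown here, __delitem__ does not renumber the remaining keys: they stay (1), (2), (3) above. Since the append method (section 1.5 below) always uses str(len(self)) as the new key, appending after a deletion can silently overwrite an existing module; newer PyTorch releases renumber the keys inside __delitem__ to avoid this. A sketch of the pitfall, continuing from the model above:

>>> len(model)                           # 3 modules left, but the keys are '1', '2', '3'
3
>>> model.append(nn.Conv2d(64, 128, 3))  # add_module('3', ...) replaces the existing '3'
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
)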

1.3 Introspection: __len__ and __dir__

Usage:

>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> len(model)
3
>>> dir(model)
['T_destination', '__annotations__',..., 'training', 'type', 'xpu', 'zero_grad']

Source:

def __len__(self) -> int:
    return len(self._modules)


# Filter the digit-string module keys out of the attribute listing
def __dir__(self):
    keys = super(Sequential, self).__dir__()
    keys = [key for key in keys if not key.isdigit()]
    return keys

1.4 Iteration: __iter__

Usage:

>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> for i in iter(model):
...     print(i)
... 
ReLU()
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
ReLU()
>>> 

Source:

def __iter__(self) -> Iterator[Module]:
    return iter(self._modules.values())

1.5 Adding modules: the append method

Usage:

>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model.append(nn.Conv2d(1,20,3))
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
  (4): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
)

Source:

def append(self, module: Module) -> 'Sequential':
    self.add_module(str(len(self)), module)
    return self
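Since append returns self, calls can be chained. A small sketch:

model.append(nn.Conv2d(64, 128, 3)).append(nn.ReLU())  # each call returns the Sequential itself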

1.6 Forward pass: the forward method

Usage:
forward is invoked automatically when the model is called;

# Forward pass
model(data)

# rather than calling forward directly:
# model.forward(data)

Source:

# input flows through each module in order
def forward(self, input):
    for module in self:
        input = module(input)
    return input
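The loop above is the whole contract: calling the container is equivalent to threading the input through each submodule by hand. A self-contained equivalence check, rebuilding the four-layer model from section 1.1 (the layers here are deterministic, so the two paths produce identical tensors):

model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU())
x = torch.randn(1, 1, 28, 28)   # hypothetical input
y1 = model(x)                   # Module.__call__ runs hooks, then forward
y2 = x
for m in model:                 # hand-rolled version of Sequential.forward
    y2 = m(y2)
assert torch.equal(y1, y2)      # identical results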

2 nn.ModuleList: Usage and Source Code

Usage:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # an nn.ModuleList holding ten nn.Linear layers
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x
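The point of nn.ModuleList over a plain Python list is registration: submodules stored in a ModuleList show up in parameters(), state_dict(), .to(device), and so on, while a plain list is invisible to the parent Module. A minimal sketch of the difference (the class names are illustrative):

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = [nn.Linear(10, 10) for _ in range(3)]   # NOT registered

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

print(len(list(PlainList().parameters())))   # 0 -- an optimizer would see nothing
print(len(list(Registered().parameters())))  # 6 -- weight and bias for each layer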

Only insert/append/extend are introduced here; the other methods work the same as in Sequential above.

# insert: add a module at an arbitrary position
>>> linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)]) 
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Linear(in_features=10, out_features=10, bias=True)
  (4): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.insert(3,nn.Conv2d(1,20,5))
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
)

# append: add a module at the end
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.append(nn.Conv2d(1,20,5))
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
  (6): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
)

# extend: add several modules at once
>>> linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)]) 
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.extend([nn.Conv2d(1,10,5),nn.ReLU()])
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (4): ReLU()
)

Source:

# Insert a module at an arbitrary position: shift every later key up by one
def insert(self, index: int, module: Module) -> None:
    for i in range(len(self._modules), index, -1):
        self._modules[str(i)] = self._modules[str(i - 1)]
    self._modules[str(index)] = module


# Append one module at the end
def append(self, module: Module) -> 'ModuleList':
    r"""Appends a given module to the end of the list.

    Args:
        module (nn.Module): module to append
    """
    self.add_module(str(len(self)), module)
    return self

# Append several modules at the end
def extend(self, modules: Iterable[Module]) -> 'ModuleList':
    r"""Appends modules from a Python iterable to the end of the list.

    Args:
        modules (iterable): iterable of modules to append
    """
    if not isinstance(modules, container_abcs.Iterable):
        raise TypeError("ModuleList.extend should be called with an "
                        "iterable, but got " + type(modules).__name__)
    offset = len(self)
    for i, module in enumerate(modules):
        self.add_module(str(offset + i), module)
    return self
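The isinstance check only requires an iterable, so generators work as well as lists; and like append, extend returns self, so the REPL echoes the updated container:

>>> linears = nn.ModuleList()
>>> linears.extend(nn.Linear(4, 4) for _ in range(2))
ModuleList(
  (0): Linear(in_features=4, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=4, bias=True)
)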

3 nn.ModuleDict: Usage and Source Code

It serves the same purpose as nn.ModuleList, just in dict form rather than list form;
swap the list methods for their dict counterparts.

Usage:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'conv2': nn.Conv2d(10,10,3),
            'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
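A ModuleDict lets the caller pick submodules by key at call time. A quick sketch of running the class above; the input size is chosen arbitrarily:

net = MyModule()
x = torch.randn(1, 10, 8, 8)   # hypothetical input: 10 channels, 8x8
y = net(x, 'conv', 'lrelu')    # route through the chosen conv and activation
print(y.shape)                 # torch.Size([1, 10, 6, 6]): the 3x3 conv trims 2 pixels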

Method usage:

# Build a ModuleDict
>>> choices = nn.ModuleDict({
...             'conv': nn.Conv2d(10, 10, 3),
...             'conv2': nn.Conv2d(10,10,3),
...             'pool': nn.MaxPool2d(3)
...         })
>>> 
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
# Inspect key/value pairs
>>> choices.items()
odict_items([('conv', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))), ('conv2', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))), ('pool', MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))])
# Keys only
>>> choices.keys() 
odict_keys(['conv', 'conv2', 'pool'])

# Values only
>>> choices.values()
odict_values([Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)])

# Update an existing entry
>>> choices.update({'pool':nn.MaxPool2d(5)})
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
# pop: remove one entry and return it
>>> choices.pop('pool')
MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
)
# clear: remove all entries
>>> choices.clear()
>>> choices
ModuleDict()

Source:

class ModuleDict(Module):

    def clear(self) -> None:
        """Remove all items from the ModuleDict.
        """
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (string): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys.
        """
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs.
        """
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values.
        """
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        # update from a mapping; existing keys are overwritten
        # (in the full PyTorch source, an iterable of (key, module) pairs is also handled)
        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
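update overwrites existing keys and inserts new ones, as the earlier REPL example showed. A compact recap under those semantics:

choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3)})
choices.update({'bn': nn.BatchNorm2d(10)})       # new key: inserted
choices.update({'conv': nn.Conv2d(10, 10, 5)})   # existing key: replaced
print(list(choices.keys()))                      # ['conv', 'bn']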

4 nn.ParameterList

Usage:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x
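Like ModuleList, a ParameterList registers its entries, here as learnable parameters, so they appear in parameters() and receive optimizer updates. A quick sketch of calling the class above; the input shape just has to be compatible with the 10x10 matmuls:

net = MyModule()
print(len(list(net.parameters())))   # 10 registered 10x10 parameters
x = torch.randn(10, 4)               # hypothetical input
print(net(x).shape)                  # torch.Size([10, 4]): each step is a (10x10) @ (10x4) matmul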

The extra_repr method:

def extra_repr(self) -> str:
    child_lines = []
    for k, p in self._parameters.items():
        size_str = 'x'.join(str(size) for size in p.size())    # e.g. '10x10'
        device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
        parastr = 'Parameter containing: [{} of size {}{}]'.format(
            torch.typename(p), size_str, device_str)
        child_lines.append('  (' + str(k) + '): ' + parastr)   # one line per parameter
    tmpstr = '\n'.join(child_lines)
    return tmpstr

# Usage:
>>> params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
>>> params.extra_repr()
'  (0): Parameter containing: [torch.FloatTensor of size 10x10]\n  (1): Parameter containing: [torch.FloatTensor of size 10x10]\n  (2): Parameter containing: [torch.FloatTensor of size 10x10]\n  (3): Parameter containing: [torch.FloatTensor of size 10x10]\n  (4): Parameter containing: [torch.FloatTensor of size 10x10]\n  (5): Parameter containing: [torch.FloatTensor of size 10x10]\n  (6): Parameter containing: [torch.FloatTensor of size 10x10]\n  (7): Parameter containing: [torch.FloatTensor of size 10x10]\n  (8): Parameter containing: [torch.FloatTensor of size 10x10]\n  (9): Parameter containing: [torch.FloatTensor of size 10x10]'

5 nn.ParameterDict

Usage:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterDict({
            'left': nn.Parameter(torch.randn(5, 10)),
            'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x
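A quick sketch of using the class above, selecting a parameter block by key at call time:

net = MyModule()
x = torch.randn(10, 2)               # hypothetical input
print(net(x, 'left').shape)          # torch.Size([5, 2]): (5x10) @ (10x2)
print(net(x, 'right').shape)         # torch.Size([5, 2])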