# Move the model and every training batch onto the first CUDA device.
# The original f'cuda:{0}' always evaluated to 'cuda:0'; build the device
# object once instead of re-creating it for every call / loop iteration.
device = torch.device('cuda:0')
net = net.to(device=device)
for X, y in train_iter:
    # Tensor.to does not move in place; it returns a tensor on the target
    # device, so the loop variables must be rebound.
    X = X.to(device)
    y = y.to(device)
    # NOTE(review): X and y are not used after the transfer in the visible
    # code; presumably the forward/backward pass follows here — confirm.
# Move the model and every training batch onto the first CUDA device.
# NOTE(review): these statements repeat the preceding model/batch transfer
# verbatim. If that earlier transfer already ran, `net` is already on the
# device and the `net.to` call below is a no-op; this looks like a duplicated
# paste — confirm and remove if so.
device = torch.device('cuda:0')  # f'cuda:{0}' was always just 'cuda:0'
net = net.to(device=device)
for X, y in train_iter:
    # Tensor.to returns a tensor on the target device rather than moving
    # the original in place, so rebind the loop variables.
    X = X.to(device)
    y = y.to(device)
    # NOTE(review): X and y are unused after the transfer in the visible
    # code; presumably the training step follows here — confirm.