Differentiable OT with PyTorch