diff --git a/tests/test_grad.py b/tests/test_grad.py index c77b0b9..b771170 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -2,6 +2,7 @@ import torch from torch.autograd.gradcheck import gradcheck, gradgradcheck from torchlpc.core import LPC +from torchlpc.recurrence import RecurrenceCUDA def get_random_biquads(cmplx=False): @@ -131,12 +132,12 @@ def test_cuda_parallel_scan( batch_size = 2 samples = 123 x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") - A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1 - zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda") + A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1 + zi = torch.randn(batch_size, dtype=torch.double, device="cuda") A.requires_grad = a_requires_grad x.requires_grad = x_requires_grad zi.requires_grad = zi_requires_grad - assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True) - assert gradgradcheck(LPC.apply, (x, A, zi)) + assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True) + assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi)) diff --git a/tests/test_vmap.py b/tests/test_vmap.py index 75c4496..1c20250 100644 --- a/tests/test_vmap.py +++ b/tests/test_vmap.py @@ -3,6 +3,7 @@ from torch.func import hessian, jacfwd import pytest from torchlpc.core import LPC +from torchlpc.recurrence import RecurrenceCUDA from .test_grad import create_test_inputs @@ -52,8 +53,8 @@ def test_cuda_parallel_scan_vmap(): batch_size = 3 samples = 255 x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") - A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1 - zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda") + A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1 + zi = torch.randn(batch_size, dtype=torch.double, device="cuda") y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") A.requires_grad = True @@ -63,7 +64,7 @@ def test_cuda_parallel_scan_vmap(): args = (x, A, zi) def func(x, A, zi): - return F.mse_loss(LPC.apply(x, A, zi), y) + return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y) jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)