Skip to content

Commit

Permalink
test parallel scan directly
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Sep 9, 2024
1 parent 1444b35 commit af31644
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
9 changes: 5 additions & 4 deletions tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA


def get_random_biquads(cmplx=False):
Expand Down Expand Up @@ -131,12 +132,12 @@ def test_cuda_parallel_scan(
batch_size = 2
samples = 123
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device="cuda")

A.requires_grad = a_requires_grad
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
assert gradgradcheck(LPC.apply, (x, A, zi))
assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True)
assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi))
7 changes: 4 additions & 3 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.func import hessian, jacfwd
import pytest
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA


from .test_grad import create_test_inputs
Expand Down Expand Up @@ -52,8 +53,8 @@ def test_cuda_parallel_scan_vmap():
batch_size = 3
samples = 255
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")

A.requires_grad = True
Expand All @@ -63,7 +64,7 @@ def test_cuda_parallel_scan_vmap():
args = (x, A, zi)

def func(x, A, zi):
return F.mse_loss(LPC.apply(x, A, zi), y)
return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y)

jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

Expand Down

0 comments on commit af31644

Please sign in to comment.