Skip to content
This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #100 from zama-ai/feat/rotl
Browse files Browse the repository at this point in the history
feat: add rotate left and rotate right
  • Loading branch information
immortal-tofu authored Apr 3, 2024
2 parents 5876042 + de4413a commit 1b3e576
Show file tree
Hide file tree
Showing 8 changed files with 816 additions and 0 deletions.
96 changes: 96 additions & 0 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,102 @@ func FheLibShr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
}
}

func FheLibRotl(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case tfhe.FheUint4:
lhs = 2
rhs = 1
case tfhe.FheUint8:
lhs = 2
rhs = 1
case tfhe.FheUint16:
lhs = 4283
rhs = 2
case tfhe.FheUint32:
lhs = 1333337
rhs = 3
case tfhe.FheUint64:
lhs = 13333377777777777
lhs = 34
}
expected := lhs << rhs
signature := "fheRotl(uint256,uint256,bytes1)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash()
}
input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.Decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

func FheLibRotr(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case tfhe.FheUint4:
lhs = 2
rhs = 1
case tfhe.FheUint8:
lhs = 2
rhs = 1
case tfhe.FheUint16:
lhs = 4283
rhs = 3
case tfhe.FheUint32:
lhs = 1333337
rhs = 3
case tfhe.FheUint64:
lhs = 13333377777777777
lhs = 34
}
expected := lhs >> rhs
signature := "fheRotr(uint256,uint256,bytes1)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash()
}
input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.Decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

func FheLibNe(t *testing.T, fheUintType tfhe.FheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
Expand Down
12 changes: 12 additions & 0 deletions fhevm/fhelib.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ var fhelibMethods = []*FheLibMethod{
requiredGasFunction: fheShrRequiredGas,
runFunction: fheShrRun,
},
{
name: "fheRotl",
argTypes: "(uint256,uint256,bytes1)",
requiredGasFunction: fheRotlRequiredGas,
runFunction: fheRotlRun,
},
{
name: "fheRotr",
argTypes: "(uint256,uint256,bytes1)",
requiredGasFunction: fheRotrRequiredGas,
runFunction: fheRotrRun,
},
{
name: "fheNe",
argTypes: "(uint256,uint256,bytes1)",
Expand Down
133 changes: 133 additions & 0 deletions fhevm/operators_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,139 @@ func fheShrRun(environment EVMEnvironment, caller common.Address, addr common.Ad
}
}


func fheRotlRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(65, len(input))]

logger := environment.GetLogger()

isScalar, err := isScalarOp(input)
if err != nil {
logger.Error("fheShl can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if !isScalar {
lhs, rhs, err := get2VerifiedOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs))
if err != nil {
logger.Error("fheShl inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}
if lhs.fheUintType() != rhs.fheUintType() {
msg := "fheShl operand type mismatch"
logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType())
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.Rotl(rhs.ciphertext)
if err != nil {
logger.Error("fheRotl failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotl success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil

} else {
lhs, rhs, err := getScalarOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs))
if err != nil {
logger.Error("fheRotl scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.ScalarRotl(rhs)
if err != nil {
logger.Error("fheRotl failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotl scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex())
return resultHash[:], nil
}
}

func fheRotrRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(65, len(input))]

logger := environment.GetLogger()

isScalar, err := isScalarOp(input)
if err != nil {
logger.Error("fheRotr can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if !isScalar {
lhs, rhs, err := get2VerifiedOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), encryptedOperand(*rhs))
if err != nil {
logger.Error("fheRotr inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}
if lhs.fheUintType() != rhs.fheUintType() {
msg := "fheRotr operand type mismatch"
logger.Error(msg, "lhs", lhs.fheUintType(), "rhs", rhs.fheUintType())
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.Rotr(rhs.ciphertext)
if err != nil {
logger.Error("fheRotr failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotr success", "lhs", lhs.hash().Hex(), "rhs", rhs.hash().Hex(), "result", resultHash.Hex())
return resultHash[:], nil

} else {
lhs, rhs, err := getScalarOperands(environment, input)
otelDescribeOperands(runSpan, encryptedOperand(*lhs), plainOperand(*rhs))
if err != nil {
logger.Error("fheRotr scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !environment.IsCommitting() && !environment.IsEthCall() {
return importRandomCiphertext(environment, lhs.fheUintType()), nil
}

result, err := lhs.ciphertext.ScalarRotr(rhs)
if err != nil {
logger.Error("fheRotr failed", "err", err)
return nil, err
}
importCiphertext(environment, result)

resultHash := result.GetHash()
logger.Info("fheRotr scalar success", "lhs", lhs.hash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex())
return resultHash[:], nil
}
}

func fheNegRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
input = input[:minInt(32, len(input))]

Expand Down
10 changes: 10 additions & 0 deletions fhevm/operators_bit_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ func fheShrRequiredGas(environment EVMEnvironment, input []byte) uint64 {
return fheShlRequiredGas(environment, input)
}

func fheRotrRequiredGas(environment EVMEnvironment, input []byte) uint64 {
// Implement in terms of shl, because comparison costs are currently the same.
return fheShlRequiredGas(environment, input)
}

func fheRotlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
// Implement in terms of shl, because comparison costs are currently the same.
return fheShlRequiredGas(environment, input)
}

func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 {
input = input[:minInt(32, len(input))]

Expand Down
86 changes: 86 additions & 0 deletions fhevm/tfhe/tfhe_ciphertext.go
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,92 @@ func (lhs *TfheCiphertext) ScalarShr(rhs *big.Int) (*TfheCiphertext, error) {
fheUint160BinaryScalarNotSupportedOp, false)
}


func (lhs *TfheCiphertext) Rotl(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotl_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryNotSupportedOp, false)
}

func (lhs *TfheCiphertext) ScalarRotl(rhs *big.Int) (*TfheCiphertext, error) {
return lhs.executeBinaryScalarOperation(rhs,
boolBinaryScalarNotSupportedOp,
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) {
return C.scalar_rotl_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryScalarNotSupportedOp, false)
}

func (lhs *TfheCiphertext) Rotr(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) {
return C.rotr_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryNotSupportedOp,
false)
}

func (lhs *TfheCiphertext) ScalarRotr(rhs *big.Int) (*TfheCiphertext, error) {
return lhs.executeBinaryScalarOperation(rhs,
boolBinaryScalarNotSupportedOp,
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint4(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint8(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint16(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint32(lhs, rhs, sks), nil
},
func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) {
return C.scalar_rotr_fhe_uint64(lhs, rhs, sks), nil
},
fheUint160BinaryScalarNotSupportedOp, false)
}

func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
boolBinaryNotSupportedOp,
Expand Down
Loading

0 comments on commit 1b3e576

Please sign in to comment.