From 930bfabc17748ea3772e6a40b04e84fc4aafcf04 Mon Sep 17 00:00:00 2001
From: meejah <meejah@meejah.ca>
Date: Wed, 9 Nov 2022 23:42:33 -0700
Subject: [PATCH 1/2] use cryptography's HKDF implementation

---
 setup.py                       |  2 +-
 src/spake2/ed25519_basic.py    |  2 +-
 src/spake2/groups.py           | 21 ++++++++++++++-------
 src/spake2/test/test_compat.py |  9 +++++----
 4 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index 660f055..ba3cc28 100755
--- a/setup.py
+++ b/setup.py
@@ -79,5 +79,5 @@ def abbrev(t):
           "Programming Language :: Python :: 3.6",
           "Topic :: Security :: Cryptography",
           ],
-      install_requires=["hkdf"],
+      install_requires=["cryptography"],
       )
diff --git a/src/spake2/ed25519_basic.py b/src/spake2/ed25519_basic.py
index 1890be7..dbab56d 100644
--- a/src/spake2/ed25519_basic.py
+++ b/src/spake2/ed25519_basic.py
@@ -273,7 +273,7 @@ def arbitrary_element(seed): # unknown DL
     # oversized string (128 bits more than the field size), then reducing
     # down to Q. But it's comforting, and it's the same technique we use for
     # converting passwords/seeds to scalars (which *does* need uniformity).
-    hseed = expand_arbitrary_element_seed(seed, (256/8)+16)
+    hseed = expand_arbitrary_element_seed(seed, int((256/8)+16))
     y = int(binascii.hexlify(hseed), 16) % Q
 
     # we try successive Y values until we find a valid point
diff --git a/src/spake2/groups.py b/src/spake2/groups.py
index de4f75d..66b08e7 100644
--- a/src/spake2/groups.py
+++ b/src/spake2/groups.py
@@ -1,6 +1,7 @@
 from __future__ import division
 import hashlib
-from hkdf import Hkdf
+from cryptography.hazmat.primitives.kdf import hkdf
+from cryptography.hazmat.primitives import hashes
 from .six import integer_types
 from .util import (size_bits, size_bytes, unbiased_randrange,
                    bytes_to_number, number_to_bytes)
@@ -63,9 +64,12 @@
 
 
 def expand_password(data, num_bytes):
-    h = Hkdf(salt=b"", input_key_material=data, hash=hashlib.sha256)
-    info = b"SPAKE2 pw"
-    return h.expand(info, num_bytes)
+    return hkdf.HKDF(
+        algorithm=hashes.SHA256(),
+        length=num_bytes,
+        salt=b"",
+        info=b"SPAKE2 pw"
+    ).derive(data)
 
 def password_to_scalar(pw, scalar_size_bytes, q):
     assert isinstance(pw, bytes)
@@ -77,9 +81,12 @@ def password_to_scalar(pw, scalar_size_bytes, q):
     return i % q
 
 def expand_arbitrary_element_seed(data, num_bytes):
-    h = Hkdf(salt=b"", input_key_material=data, hash=hashlib.sha256)
-    info = b"SPAKE2 arbitrary element"
-    return h.expand(info, num_bytes)
+    return hkdf.HKDF(
+        algorithm=hashes.SHA256(),
+        length=num_bytes,
+        salt=b"",
+        info=b"SPAKE2 arbitrary element"
+    ).derive(data)
 
 class _Element:
     def __init__(self, group, e):
diff --git a/src/spake2/test/test_compat.py b/src/spake2/test/test_compat.py
index 3c636be..1c1340c 100644
--- a/src/spake2/test/test_compat.py
+++ b/src/spake2/test/test_compat.py
@@ -1,7 +1,8 @@
 import unittest
 from binascii import hexlify, unhexlify
 from hashlib import sha256
-from hkdf import Hkdf
+from cryptography.hazmat.primitives.kdf import hkdf
+from cryptography.hazmat.primitives import hashes
 from .myhkdf import HKDF as myHKDF
 from spake2 import groups, ed25519_group
 from spake2.spake2 import (SPAKE2_A, SPAKE2_B, SPAKE2_Symmetric,
@@ -213,14 +214,14 @@ def test_vectors(self):
     {"salt": "00", "IKM": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f", "info": "", "L": 4, "OKM": "37ad2910"},
     ]
 
-class HKDF(unittest.TestCase):
+class TestHKDF(unittest.TestCase):
     def test_vectors(self):
         for vector in HKDF_TEST_VECTORS:
             salt = unhexlify(vector["salt"].encode("ascii"))
             IKM = unhexlify(vector["IKM"].encode("ascii"))
             info = unhexlify(vector["info"].encode("ascii"))
-            h = Hkdf(salt=salt, input_key_material=IKM, hash=sha256)
-            digest = h.expand(info, vector["L"])
+            h = hkdf.HKDF(algorithm=hashes.SHA256(), length=vector["L"], salt=salt, info=info)
+            digest = h.derive(IKM)
             self.assertEqual(digest, myHKDF(IKM, vector["L"], salt, info))
             #print(hexlify(digest))
             expected = vector["OKM"].encode("ascii")
