Skip to content

Commit

Permalink
Import/Export PK
Browse files Browse the repository at this point in the history
  • Loading branch information
WindowsNT committed Aug 14, 2023
1 parent 6f49ab0 commit a249734
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 19 deletions.
100 changes: 95 additions & 5 deletions dll/dll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
#include <memory>


#include "..\\kyber\\include\\kyber512_kem.hpp"
#include "..\\kyber\\include\\kyber768_kem.hpp"
#include "..\\kyber\\include\\kyber1024_kem.hpp"
#include "..\\kyber\\include\\pke.hpp"


#pragma comment (lib,"C:\\Windows Kits\\10\\Cryptographic Provider Development Kit\\Lib\\x64\\bcrypt_provider.lib")
#pragma comment (lib,"C:\\Windows Kits\\10\\Cryptographic Provider Development Kit\\Lib\\x64\\ncrypt_provider.lib")
Expand Down Expand Up @@ -359,8 +354,20 @@ namespace SHA3

}

#include "..\\kyber\\include\\kyber512_kem.hpp"
#include "..\\kyber\\include\\kyber768_kem.hpp"
#include "..\\kyber\\include\\kyber1024_kem.hpp"
#include "..\\kyber\\include\\pke.hpp"

//#include "..\\dilithium\\include\\dilithium5.hpp"
//#include "..\\dilithium\\include\\dilithium3.hpp"
//#include "..\\dilithium\\include\\dilithium2.hpp"
//#include "..\\dilithium\\include\\dilithium.hpp"


namespace PK_ALGORITHMS
{

class PKA
{
public:
Expand Down Expand Up @@ -696,6 +703,71 @@ namespace PK_ALGORITHMS
_In_ ULONG dwFlags
)
{
BOP* k = (BOP*)hAlgorithm;
if (dynamic_cast<KYBER512*>(k))
{
auto k = new KYBER(512);
k->d.resize(32);
k->z.resize(32);
prng::prng_t prng;
prng.read(k->d.data(), k->d.size());
prng.read(k->z.data(), k->z.size());
k->p.resize(kyber512_kem::PKEY_LEN);
k->s.resize(kyber512_kem::SKEY_LEN);
auto req = k->p.size() + k->s.size();
if (cbInput != req)
{
delete k;
return STATUS_NOT_SUPPORTED;
}
memcpy(k->p.data(), pbInput, k->p.size());
memcpy(k->s.data(), pbInput + k->p.size(), k->s.size());
*phKey = k;
return STATUS_SUCCESS;
}
if (dynamic_cast<KYBER768*>(k))
{
auto k = new KYBER(768);
k->d.resize(32);
k->z.resize(32);
prng::prng_t prng;
prng.read(k->d.data(), k->d.size());
prng.read(k->z.data(), k->z.size());
k->p.resize(kyber768_kem::PKEY_LEN);
k->s.resize(kyber768_kem::SKEY_LEN);
auto req = k->p.size() + k->s.size();
if (cbInput != req)
{
delete k;
return STATUS_NOT_SUPPORTED;
}
memcpy(k->p.data(), pbInput, k->p.size());
memcpy(k->s.data(), pbInput + k->p.size(), k->s.size());
*phKey = k;
return STATUS_SUCCESS;
}
if (dynamic_cast<KYBER1024*>(k))
{
auto k = new KYBER(1024);
k->d.resize(32);
k->z.resize(32);
prng::prng_t prng;
prng.read(k->d.data(), k->d.size());
prng.read(k->z.data(), k->z.size());
k->p.resize(kyber1024_kem::PKEY_LEN);
k->s.resize(kyber1024_kem::SKEY_LEN);
auto req = k->p.size() + k->s.size();
if (cbInput != req)
{
delete k;
return STATUS_NOT_SUPPORTED;
}
memcpy(k->p.data(), pbInput, k->p.size());
memcpy(k->s.data(), pbInput + k->p.size(), k->s.size());
*phKey = k;
return STATUS_SUCCESS;
}

return STATUS_NOT_SUPPORTED;
};

Expand Down Expand Up @@ -821,6 +893,24 @@ namespace PK_ALGORITHMS
_Out_ ULONG* pcbResult,
_In_ ULONG dwFlags)
{
if (wcscmp(pszBlobType, L"") != 0)
return STATUS_NOT_SUPPORTED;
PKA* a = (PKA*)hKey;
if (auto k = dynamic_cast<KYBER*>(a))
{
std::vector<uint8_t> m(32);
std::vector<uint8_t> cipher;
std::vector<uint8_t> shrd_key0(32);
auto needs = k->p.size() + k->s.size();
*pcbResult = (ULONG)needs;
if (!pbOutput || !cbOutput)
return STATUS_SUCCESS;
if (cbOutput < needs)
return STATUS_NOT_SUPPORTED;
memcpy(pbOutput, k->p.data(), k->p.size());
memcpy(pbOutput + k->p.size(), k->s.data(), k->s.size());
return ERROR_SUCCESS;
}
return STATUS_NOT_SUPPORTED;
};
m.DestroyKey = [](_Inout_ BCRYPT_KEY_HANDLE hKey
Expand Down
75 changes: 61 additions & 14 deletions test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <windows.h>
#include <bcrypt.h>
#include <vector>
#include <string>
#include "..\\common.h"

#pragma comment(lib,"bcrypt.lib")
Expand Down Expand Up @@ -125,23 +126,25 @@ class PK
{
BCRYPT_ALG_HANDLE h;
BCRYPT_HASH_HANDLE ha;
std::wstring talg;
public:

PK(LPCWSTR alg = KYBER_512_ALGORITHM,int bits = 1024)
PK(LPCWSTR alg = KYBER_512_ALGORITHM)
{
talg = alg;
BCryptOpenAlgorithmProvider(&h, alg, 0, 0);
}

void gen(int bits = 512)
{
if (ha)
return;
if (h)
{
auto st = BCryptGenerateKeyPair(h, &ha, bits, 0);
st = BCryptFinalizeKeyPair(ha, 0);
/* // HASH h2(BCRYPT_SHA256_ALGORITHM);
HASH h2(SHA3_256_ALGORITHM);
h2.hash((BYTE*)pwd, (DWORD)pwdlen);
std::vector<unsigned char> rx;
h2.get(rx);
BCryptGenerateSymmetricKey(h, &ha, 0, 0, (PUCHAR)rx.data(), (ULONG)rx.size(), 0);
*/
}

}

bool e(const BYTE* d, DWORD sz, std::vector<unsigned char>& out,HASH* h = 0)
Expand All @@ -161,6 +164,39 @@ class PK
return (nt == 0) ? true : false;
}

bool imp(std::vector<unsigned char>& key)
{
if (ha)
return false;

auto str = BCRYPT_RSAFULLPRIVATE_BLOB;
if (talg != BCRYPT_RSA_ALGORITHM)
str = L"";
auto r = BCryptImportKeyPair(h, 0, str, &ha, key.data(), (DWORD)key.size(), 0);
if (ha)
return true;
return false;
}

bool exp(std::vector<unsigned char>& out)
{
if (!ha)
return false;
auto str = BCRYPT_RSAFULLPRIVATE_BLOB;
if (talg != BCRYPT_RSA_ALGORITHM)
str = L"";
ULONG cb = 0;
auto st = BCryptExportKey(ha, 0, str, 0, 0, &cb, 0);
if (st != 0)
return false;
out.resize(cb);
st = BCryptExportKey(ha, 0, str, (PUCHAR)out.data(), (DWORD)out.size(), &cb, 0);
if (st != 0)
return false;
out.resize(cb);
return true;
}

bool d(const BYTE* d, DWORD sz, std::vector<unsigned char>& out)
{
if (!ha)
Expand Down Expand Up @@ -219,12 +255,23 @@ int __stdcall WinMain(HINSTANCE, HINSTANCE, LPSTR, int)
if (1)
{
// Can be used with existing algorithms
// PK e(BCRYPT_RSA_ALGORITHM);
// auto algo = BCRYPT_RSA_ALGORITHM;

// New algorithms
// PK e(KYBER_512_ALGORITHM);
// PK e(KYBER_768_ALGORITHM);
PK e(KYBER_1024_ALGORITHM);
// auto algo = KYBER_512_ALGORITHM;
// auto algo = KYBER_768_ALGORITHM;
auto algo = KYBER_1024_ALGORITHM;


PK e1(algo);
e1.gen(1024);

std::vector<unsigned char> key;
e1.exp(key);

PK e2(algo);
e2.imp(key);

std::vector<unsigned char> out1;
std::vector<unsigned char> out2;

Expand All @@ -233,8 +280,8 @@ int __stdcall WinMain(HINSTANCE, HINSTANCE, LPSTR, int)
std::vector<unsigned char> outx(32);
hash.hash((BYTE*)"Hello", 5);
hash.get(outx);
e.e((const BYTE*)outx.data(), (DWORD)outx.size(), out1);
e.d(out1.data(), (ULONG)out1.size(), out2);
e1.e((const BYTE*)outx.data(), (DWORD)outx.size(), out1);
e2.d(out1.data(), (ULONG)out1.size(), out2);
assert(memcmp(outx.data(), out2.data(), out2.size()) == 0);
}

Expand Down

0 comments on commit a249734

Please sign in to comment.