/*
 * (c) 2015-2017 Marcos Del Sol Vives
 * (c) 2016      javiMaD
 *
 * SPDX-License-Identifier: MIT
 */

#include "drbg.h"
#include <assert.h>
#include <string.h>
#include <mbedtls/md.h>

/*
 * Initializes a DRBG context.
 *
 * The context buffer is laid out as a 2-byte big-endian counter followed by a
 * copy of the seed. Each nfc3d_drbg_step() HMACs that whole buffer with hmacKey.
 *
 * ctx:         context to initialize; release it with nfc3d_drbg_cleanup()
 * hmacKey:     HMAC-SHA256 key (must not be NULL)
 * hmacKeySize: length of hmacKey in bytes
 * seed:        seed bytes, copied into the context (must not be NULL)
 * seedSize:    length of seed; must not exceed NFC3D_DRBG_MAX_SEED_SIZE
 */
void nfc3d_drbg_init(nfc3d_drbg_ctx *ctx, const uint8_t *hmacKey, size_t hmacKeySize, const uint8_t *seed, size_t seedSize) {
    assert(ctx != NULL);
    assert(hmacKey != NULL);
    assert(seed != NULL);
    assert(seedSize <= NFC3D_DRBG_MAX_SEED_SIZE);

    // Initialize primitives
    ctx->used = false;
    ctx->iteration = 0;

    // The counter is serialized as exactly two bytes in front of the seed
    // (see nfc3d_drbg_step), so the prefix length is sized by that encoding,
    // using the same constant as the copy offset below.
    ctx->bufferSize = sizeof(uint16_t) + seedSize;
    memcpy(ctx->buffer + sizeof(uint16_t), seed, seedSize);

    // Initialize underlying HMAC context. Setup can fail, e.g. if SHA-256 is
    // not available in the mbed TLS build or if allocation fails. The return
    // codes are checked in the same assert-based style as the preconditions.
    mbedtls_md_init(&ctx->hmacCtx);
    int ret = mbedtls_md_setup(&ctx->hmacCtx, mbedtls_md_info_from_type(MBEDTLS_MD_SHA256), 1);
    assert(ret == 0);
    ret = mbedtls_md_hmac_starts(&ctx->hmacCtx, hmacKey, hmacKeySize);
    assert(ret == 0);
    (void) ret; // silence unused-value warnings when NDEBUG disables assert
}

/*
 * Produces the next DRBG output block.
 *
 * Writes HMAC(key, counter_be16 || seed) into output and then increments the
 * counter. The counter is the number of steps previously taken on this context.
 *
 * ctx:    context initialized with nfc3d_drbg_init()
 * output: receives one full HMAC-SHA256 digest. That is NFC3D_DRBG_OUTPUT_SIZE
 *         bytes, the stride nfc3d_drbg_generate_bytes() advances by.
 *
 * Only the low 16 bits of the counter are serialized. Outputs therefore repeat
 * after 65536 steps on the same context.
 */
void nfc3d_drbg_step(nfc3d_drbg_ctx *ctx, uint8_t *output) {
    int ret = 0;

    assert(ctx != NULL);
    assert(output != NULL);

    if (ctx->used) {
        // nfc3d_drbg_init() already started the HMAC for the first block. Later
        // blocks must rewind it to the keyed initial state first.
        ret = mbedtls_md_hmac_reset(&ctx->hmacCtx);
        assert(ret == 0);
    } else {
        ctx->used = true;
    }

    // Store counter in big endian, and increment it
    ctx->buffer[0] = (uint8_t) (ctx->iteration >> 8);
    ctx->buffer[1] = (uint8_t) (ctx->iteration & 0xFF);
    ctx->iteration++;

    // output = HMAC(key, counter || seed)
    ret = mbedtls_md_hmac_update(&ctx->hmacCtx, ctx->buffer, ctx->bufferSize);
    assert(ret == 0);
    ret = mbedtls_md_hmac_finish(&ctx->hmacCtx, output);
    assert(ret == 0);
    (void) ret; // silence unused-value warnings when NDEBUG disables assert
}

/*
 * Releases the HMAC context that nfc3d_drbg_init() set up inside ctx.
 * The context must not be stepped again afterwards unless it is re-initialized.
 */
void nfc3d_drbg_cleanup(nfc3d_drbg_ctx *ctx) {
    assert(ctx);

    mbedtls_md_free(&ctx->hmacCtx);
}

/*
 * Zeroes size bytes at buf through a volatile pointer. This keeps the stores
 * from being removed as dead writes, which can happen to a plain memset() on
 * an object that is about to go out of scope.
 */
static void nfc3d_drbg_wipe(void *buf, size_t size) {
    volatile uint8_t *p = (volatile uint8_t *) buf;

    while (size > 0) {
        *p++ = 0;
        size--;
    }
}

/*
 * One-shot helper: fills output with outputSize bytes of DRBG output for the
 * given key and seed.
 *
 * The result is the concatenation of successive nfc3d_drbg_step() blocks,
 * truncated to outputSize. Before returning, the function wipes the local
 * scratch block (including any unused tail of the last block) and the local
 * DRBG context (including its seed copy).
 *
 * hmacKey, hmacKeySize, seed, seedSize: as for nfc3d_drbg_init()
 * output:     destination buffer with room for outputSize bytes
 * outputSize: number of bytes to generate; 0 writes nothing
 */
void nfc3d_drbg_generate_bytes(const uint8_t *hmacKey, size_t hmacKeySize, const uint8_t *seed, size_t seedSize, uint8_t *output, size_t outputSize) {
    uint8_t temp[NFC3D_DRBG_OUTPUT_SIZE];

    nfc3d_drbg_ctx rngCtx;
    nfc3d_drbg_init(&rngCtx, hmacKey, hmacKeySize, seed, seedSize);

    while (outputSize > 0) {
        if (outputSize < NFC3D_DRBG_OUTPUT_SIZE) {
            // Final partial block: generate into scratch, copy only what's needed
            nfc3d_drbg_step(&rngCtx, temp);
            memcpy(output, temp, outputSize);
            break;
        }

        // Full block: write straight into the caller's buffer
        nfc3d_drbg_step(&rngCtx, output);
        output += NFC3D_DRBG_OUTPUT_SIZE;
        outputSize -= NFC3D_DRBG_OUTPUT_SIZE;
    }

    nfc3d_drbg_cleanup(&rngCtx);

    // Don't leave unused key stream or the seed copy behind on the stack
    nfc3d_drbg_wipe(temp, sizeof(temp));
    nfc3d_drbg_wipe(&rngCtx, sizeof(rngCtx));
}