/***************************************************************************
 * Copyright                                                               *
 *                                                                         *
 *     ESCRYPT GmbH - Embedded Security       ESCRYPT Inc.                 *
 *     Zentrum fuer IT-Sicherheit             315 E Eisenhower Parkway     *
 *     Lise-Meitner-Allee 4                   Suite 214                    *
 *     44801 Bochum                           Ann Arbor, MI 48108          *
 *     Germany                                USA                          *
 *                                                                         *
 *     http://www.escrypt.com                                              *
 *     info"at"escrypt.com                                                 *
 *                                                                         *
 * All Rights reserved                                                     *
 ***************************************************************************/

/***************************************************************************/
/*!
   \file        pkcs1_pss.c

   \brief       PKCS#1 v2.1 RSASSA-PSS signature creation and verification
   \see         ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1.pdf

   $Rev: 998 $
 */
/***************************************************************************/

/***************************************************************************
 * 1. INCLUDES                                                             *
 ***************************************************************************/

#include "../inc/pkcs1_pss.h"

#ifndef EscRsa_ENABLE_STACK_SAVING_INTERFACE

/***************************************************************************
 * 2. DEFINES                                                              *
 ***************************************************************************/

/** Number of zero octets that prefix M' (M' = 8 zero octets || mHash || salt) */
#    define EscPkcs1Pss_ZEROES 8U
/** Length of M' in bytes: zero prefix, message digest and salt */
#    define EscPkcs1Pss_M1_LEN ( EscPkcs1Pss_DIGEST_LEN + EscPkcs1Pss_SALT_LEN + EscPkcs1Pss_ZEROES )
/** Length of the encoded message EM in bytes (emLen); set to EscRsa_KEY_BYTES */
#    define EscPkcs1Pss_EMLEN ( EscRsa_KEY_BYTES )
/** Length of the MGF1 hash input in bytes: a seed of digest length followed by a 4-byte counter */
#    define EscPkcs1Pss_MSGLEN ( EscPkcs1Pss_DIGEST_LEN + 4U )
/** Length of DB / maskedDB in bytes: emLen - hLen - 1 */
#    define EscPkcs1Pss_DB_LEN ( ( EscPkcs1Pss_EMLEN - EscPkcs1Pss_DIGEST_LEN ) - 1U )
/** Trailer octet stored in the last byte of EM */
#    define EscPkcs1Pss_EMSA_PSS_TRAILER 0xbcU
/** Separator octet placed in DB directly in front of the salt */
#    define EscPkcs1Pss_EMSA_PSS_SEPARATOR 0x01U
/** Bits per octet; used when deriving the mask for the leftmost octet from emBits */
#    define EscPkcs1Pss_BITS_PER_BYTE 8U

/***************************************************************************
 * 3. DECLARATIONS                                                         *
 ***************************************************************************/

/**
 * Converts a 32-bit integer into an octet string of four bytes (big endian).
 *
 * \param[in]  integer  Value to convert.
 * \param[out] octet    Destination; octet[ 0 ] .. octet[ 3 ] are written.
 */
static void
EscPkcs1Pss_I2OSP(
    UINT32 integer,
    UINT8 octet[] );

/**
 * Wrapper function for the hash function selected by EscPkcs1Pss_HASH_TYPE.
 *
 * \param[in]  msg     Data to hash.
 * \param[in]  msgLen  Length of msg in bytes.
 * \param[out] digest  Receives the digest.
 */
static BOOL
EscPkcs1Pss_Hash(
    const UINT8 msg[],
    UINT32 msgLen,
    UINT8 digest[] );

/**
 * Function EMSA-PSS-ENCODE. Encodes a message according to EMSA-PSS.
 *
 * \param[in]  randCtx  Random context used to generate the salt.
 * \param[out] em       Receives the encoded message (EscPkcs1Pss_EMLEN bytes).
 * \param[in]  emBits   Bit length of the encoded message.
 */
static BOOL
EscPkcs1Pss_EmsaPssEncode(
    EscRandom_ContextT* randCtx,
    const UINT8 message[],  /** Message to encode */
    const UINT32 messageLen, /** Length of message in byte */
    UINT8 em[],
    const UINT32 emBits );

/**
 * Function EMSA-PSS-Verify. Verifies the encoding of a message according to EMSA-PSS.
 *
 * \param[in] em      Encoded message to check (EscPkcs1Pss_EMLEN bytes).
 * \param[in] emBits  Bit length of the encoded message.
 */
static BOOL
EscPkcs1Pss_EmsaPssVerify(
    const UINT8 message[],  /** Message to verify */
    const UINT32 messageLen, /** Length of message in byte */
    const UINT8 em[],
    const UINT32 emBits );

/**
 * Mask generation function according to PKCS#1 v2.1.
 *
 * \param[in]  mgfSeed  Seed; EscPkcs1Pss_DIGEST_LEN bytes are read.
 * \param[out] mask     Receives the generated mask.
 * \param[in]  maskLen  Number of mask bytes to generate.
 */
static BOOL
EscPkcs1Pss_MGF1(
    const UINT8 mgfSeed[],
    UINT8 mask[],
    UINT32 maskLen );

/***************************************************************************
 * 4. IMPLEMENTATION OF FUNCTIONS                                          *
 ***************************************************************************/

/**
 * I2OSP with a fixed output length of four octets: stores \p integer in
 * \p octet in big-endian order (most significant byte at index 0).
 *
 * \param[in]  integer  Value to serialize.
 * \param[out] octet    Destination; octet[ 0 ] .. octet[ 3 ] are written.
 */
static void
EscPkcs1Pss_I2OSP(
    UINT32 integer,
    UINT8 octet[] )
{
    UINT32 remaining = integer;
    UINT32 pos = 4U;

    /* Emit the least significant octet first and walk towards index 0 */
    while ( pos > 0U ) {
        pos--;
        octet[ pos ] = (UINT8)( remaining & 0xffU );
        remaining >>= 8;
    }
}

/**
 * Wrapper function for the hash function selected at compile time through
 * EscPkcs1Pss_HASH_TYPE (SHA-1, SHA-256 or SHA-512).
 *
 * \param[in]  msg     Data to hash.
 * \param[in]  msgLen  Length of \p msg in bytes.
 * \param[out] digest  Receives the digest; the callers in this file provide
 *                     EscPkcs1Pss_DIGEST_LEN bytes.
 *
 * \return The result of the selected Esc*_Calc function. The callers in this
 *         file treat FALSE as success.
 */
static BOOL
EscPkcs1Pss_Hash(
    const UINT8 msg[],
    UINT32 msgLen,
    UINT8 digest[] )
{
    BOOL hasFailed;

#   if ( EscPkcs1Pss_HASH_TYPE == EscPkcs1Pss_USE_SHA1 )
/* SHA-1 */
    hasFailed = EscSha1_Calc( msg, msgLen, digest);
#   elif ( EscPkcs1Pss_HASH_TYPE == EscPkcs1Pss_USE_SHA256 )
/* SHA-256 */
    hasFailed = EscSha256_Calc( msg, msgLen, digest);
#   elif ( EscPkcs1Pss_HASH_TYPE == EscPkcs1Pss_USE_SHA512 )
/* SHA-512 */
    hasFailed = EscSha512_Calc( msg, msgLen, digest);
#   else
/* Without this branch an unsupported hash selection would still compile and
   this function would return an uninitialized hasFailed. */
#       error "EscPkcs1Pss_HASH_TYPE does not select a supported hash function"
#   endif

    return hasFailed;
}

/**
 * Mask generation function MGF1 according to PKCS#1 v2.1.
 * The seed length is fixed to EscPkcs1Pss_DIGEST_LEN.
 *
 * For every counter value 0, 1, ... the digest of ( mgfSeed || counter ) is
 * computed, with the counter encoded as four big-endian octets, and appended
 * to the mask; the last digest is truncated so that exactly \p maskLen bytes
 * are produced. Digest bytes are only copied while no hash call has failed.
 *
 * \param[in]  mgfSeed  Seed from which the mask is generated.
 * \param[out] mask     Receives the mask.
 * \param[in]  maskLen  Number of mask bytes requested.
 *
 * \return FALSE if every hash call reported success, otherwise the
 *         accumulated failure flag.
 **/
static BOOL
EscPkcs1Pss_MGF1(
    const UINT8 mgfSeed[],  /** seed from which the mask is generated */
    UINT8 mask[],
    UINT32 maskLen )
{
    /*lint --e{661} --e{662} It is confirmed that the array out of bounds warning is a false positive */
    /* Number of digests needed to cover maskLen bytes (rounded up) */
    const UINT32 blockCount = ( maskLen + ( EscPkcs1Pss_DIGEST_LEN - 1U ) ) / EscPkcs1Pss_DIGEST_LEN;
    UINT8 hashInput[ EscPkcs1Pss_MSGLEN ];
    UINT8 block[ EscPkcs1Pss_DIGEST_LEN ];
    UINT32 counter;
    UINT32 k;
    BOOL hasFailed = FALSE;

    /* The seed forms the constant prefix of every hash input */
    for ( k = 0U; k < EscPkcs1Pss_DIGEST_LEN; k++ ) {
        hashInput[ k ] = mgfSeed[ k ];
    }

    for ( counter = 0U; counter < blockCount; counter++ ) {
        const UINT32 blockStart = counter * EscPkcs1Pss_DIGEST_LEN;

        /* C = I2OSP( counter, 4 ) */
        EscPkcs1Pss_I2OSP( counter, &hashInput[ EscPkcs1Pss_DIGEST_LEN ] );

        /* T = T || Hash( mgfSeed || C ) */
        hasFailed |= EscPkcs1Pss_Hash( hashInput, EscPkcs1Pss_MSGLEN, block );

        if ( hasFailed == FALSE ) {
            UINT32 take = EscPkcs1Pss_DIGEST_LEN;

            /* Only the final block may have to be truncated */
            if ( ( blockStart + EscPkcs1Pss_DIGEST_LEN ) > maskLen ) {
                take = maskLen % EscPkcs1Pss_DIGEST_LEN;
            }

            for ( k = 0U; k < take; k++ ) {
                mask[ blockStart + k ] = block[ k ];
            }
        }
    }

    return hasFailed;
}

/**
 * Function EMSA-PSS-ENCODE. Encodes a message according to EMSA-PSS.
 * All quoted step numbers refer to section 9.1.1 of the PKCS#1 v2.1 standard.
 *
 * The message is hashed, a salt of EscPkcs1Pss_SALT_LEN bytes is drawn from
 * \p randCtx and EM = maskedDB || H || 0xbc is assembled directly in \p em.
 *
 * \param[in]  randCtx     Random context passed to EscRandom_GetRandom for the salt.
 * \param[in]  message     Message to encode.
 * \param[in]  messageLen  Length of \p message in bytes.
 * \param[out] em          Buffer for the encoded message; must hold EscPkcs1Pss_EMLEN bytes.
 * \param[in]  emBits      Bit length of the encoded message; determines how
 *                         many leading bits of em[ 0 ] are cleared.
 *
 * \return FALSE if both hash computations, the salt generation and MGF1
 *         reported success, otherwise the accumulated failure flag.
 */
static BOOL
EscPkcs1Pss_EmsaPssEncode(
    EscRandom_ContextT* randCtx,
    const UINT8 message[],  /** Message to encode */
    const UINT32 messageLen, /** Length of message in byte */
    UINT8 em[],
    const UINT32 emBits )
{
    /* plausibility check: emLen must provide at least hLen + sLen + 2 octets */
#    if ( ( EscPkcs1Pss_DIGEST_LEN + EscPkcs1Pss_SALT_LEN + 2U ) > EscPkcs1Pss_EMLEN )
#        error "PKCS_PSS: encoding error"
#    endif

    /* Position of the 0x01 separator inside DB; the salt follows directly behind it */
    const UINT32 separatorPos = ( EscPkcs1Pss_DB_LEN - EscPkcs1Pss_SALT_LEN ) - 1U;
    /* Number of leading bits of EM that are not covered by emBits */
    const UINT32 unusedBits = ( EscPkcs1Pss_BITS_PER_BYTE * EscPkcs1Pss_EMLEN ) - emBits;
    UINT8 mPrime[ EscPkcs1Pss_M1_LEN ];
    UINT8* const mHash = &mPrime[ EscPkcs1Pss_ZEROES ];
    UINT8* const salt = &mPrime[ EscPkcs1Pss_ZEROES + EscPkcs1Pss_DIGEST_LEN ];
    UINT8* const h = &em[ EscPkcs1Pss_DB_LEN ];
    UINT32 k;
    BOOL hasFailed;

    /* 2. mHash = Hash(M), stored at its final place inside M' */
    hasFailed = EscPkcs1Pss_Hash( message, messageLen, mHash );

    /* 4. random salt, stored at its final place inside M' */
    hasFailed |= EscRandom_GetRandom( randCtx, salt, EscPkcs1Pss_SALT_LEN );

    /* 5. M' = 00 00 00 00 00 00 00 00 || mHash || salt */
    for ( k = 0U; k < EscPkcs1Pss_ZEROES; k++ ) {
        mPrime[ k ] = 0x00U;
    }

    /* 6. H = Hash(M'), written to its final position in EM */
    hasFailed |= EscPkcs1Pss_Hash( mPrime, EscPkcs1Pss_M1_LEN, h );

    /* 9. dbMask = MGF( H, emLen - hLen - 1 ), generated in place of maskedDB */
    hasFailed |= EscPkcs1Pss_MGF1( h, em, EscPkcs1Pss_DB_LEN );

    /* 7. PS consists of zero octets, so dbMask stays unchanged in that range */
    /* 8./10. maskedDB = dbMask XOR ( PS || 0x01 || salt ) */
    em[ separatorPos ] ^= EscPkcs1Pss_EMSA_PSS_SEPARATOR;

    for ( k = 0U; k < EscPkcs1Pss_SALT_LEN; k++ ) {
        em[ ( separatorPos + 1U ) + k ] ^= salt[ k ];
    }

    /* 11. set the leftmost 8*emLen - emBits bits of the leftmost octet in maskedDB to zero */
    em[ 0U ] &= (UINT8)( ( ( (UINT32)1U << ( EscPkcs1Pss_BITS_PER_BYTE - unusedBits ) ) - 1U ) & 0xffU );

    /* 12. EM = maskedDB || H || 0xbc */
    em[ EscPkcs1Pss_EMLEN - 1U ] = EscPkcs1Pss_EMSA_PSS_TRAILER;

    return hasFailed;
}

/** Function EMSA-PSS-Verify. Verifies the encoding of a message according to EMSA-PSS.
    All quoted references refer to the section 9.1.2 from the PKCS#1 v2.1 standard. */
static BOOL
EscPkcs1Pss_EmsaPssVerify(
    const UINT8 message[],  /** Message to verify */
    const UINT32 messageLen, /** Length of message in byte */
    const UINT8 em[],
    const UINT32 emBits )
{
    BOOL hasFailed;
    UINT8 M1[ EscPkcs1Pss_M1_LEN ];
    UINT8 H1[ EscPkcs1Pss_DIGEST_LEN ];
    UINT8 DB[ EscPkcs1Pss_DB_LEN ];
    UINT32 i;

    /* 1. (hasFailed will turn true if the message exceeds the input of Sha1),
       2. (mHash is stored in M' = EscPkcs1Pss_ZEROES || -> mHash <- || salt) */
    hasFailed = EscPkcs1Pss_Hash( message, messageLen, &M1[ EscPkcs1Pss_ZEROES ] );

    /* 3. If emLen < hLen + sLen + 2, output inconsistent and stop */
#    if ( ( EscPkcs1Pss_DIGEST_LEN + EscPkcs1Pss_SALT_LEN + 2U ) > EscPkcs1Pss_EMLEN )
#        error "EmsaPssVerify: inconsistent (length of encoded message < hash length + seed length + 2)"
#    endif

    /* 4. Test if the rightmost octet of EM has hexadecimal value 0xbc */
    if ( hasFailed == FALSE ) {
        if ( em[ EscPkcs1Pss_EMLEN - 1U ] != EscPkcs1Pss_EMSA_PSS_TRAILER ) {
            hasFailed = TRUE;
        }
    }

    /* 6. Test if the leftmost 8*emLen  emBits bits
       of the leftmost octet in maskedDB are all equal to zero*/
    if ( hasFailed == FALSE ) {
        UINT8 mask = (UINT8)( ( ( (UINT32)1U << ( EscPkcs1Pss_BITS_PER_BYTE - ( ( EscPkcs1Pss_BITS_PER_BYTE * EscPkcs1Pss_EMLEN ) - emBits ) ) ) - 1U ) & 0xffU );
        if ( ( ( em[ 0x00 ] ) | mask ) != mask ) {
            hasFailed = TRUE;
        }
    }

    /* 7. Let dbMask = MGF (H, emLen  hLen  1). */
    hasFailed |= EscPkcs1Pss_MGF1( &em[ EscPkcs1Pss_DB_LEN ], DB, EscPkcs1Pss_DB_LEN );

    /* 8. Let DB = maskedDB xor dbMask. */
    for ( i = 0U; i < EscPkcs1Pss_DB_LEN; i++ ) {
        DB[ i ] ^= em[ i ];
    }

    /* 9.Set the leftmost 8emLen  emBits bits of the leftmost octet in DB to zero. */
    DB[ 0x00 ] &= (UINT8)( ( ( (UINT32)1U << ( 8U - ( ( 8U * EscPkcs1Pss_EMLEN ) - emBits ) ) ) - 1U ) & 0xffU );

    /* 10.a Test if the emLen  hLen  sLen  2 leftmost octets of DB are zero */
    for ( i = 0U; i < ( ( ( EscPkcs1Pss_EMLEN - EscPkcs1Pss_DIGEST_LEN ) - EscPkcs1Pss_SALT_LEN ) - 2U ); i++ ) {
        if ( DB[ i ] != 0U ) {
            hasFailed |= TRUE;
        }
    }

    /* 10.b Test if the octet at position emLen  hLen  sLen  1
       (the leftmost position is position 1) has hexadecimal value 0x01 */
    if ( DB[ ( ( EscPkcs1Pss_EMLEN - EscPkcs1Pss_DIGEST_LEN ) - EscPkcs1Pss_SALT_LEN ) - 2U ] != EscPkcs1Pss_EMSA_PSS_SEPARATOR ) {
        hasFailed = TRUE;
    }

    /* 12. Build M'= (0x)00 00 00 00 00 00 00 00 || mHash (Step 1) || salt*/
    /* leading zeros */
    for ( i = 0U; i < EscPkcs1Pss_ZEROES; i++ ) {
        M1[ i ] = 0x00U;
    }

    /* salt */
    for ( i = 0U; i < ( ( EscPkcs1Pss_M1_LEN - EscPkcs1Pss_ZEROES ) - EscPkcs1Pss_DIGEST_LEN ); i++ ) {
        M1[ i + EscPkcs1Pss_ZEROES + EscPkcs1Pss_DIGEST_LEN ] = DB[ ( EscPkcs1Pss_DB_LEN - EscPkcs1Pss_SALT_LEN ) + i ];
    }

    /* 13. Let H = Hash (M), an octet string of length hLen. */
    hasFailed |= EscPkcs1Pss_Hash( M1, EscPkcs1Pss_M1_LEN, H1 );

    /* 14. Test if H = H */
    for ( i = 0U; ( i < EscPkcs1Pss_DIGEST_LEN ); i++ ) {
        if ( H1[ i ] != em[ EscPkcs1Pss_DB_LEN + i ] ) {
            hasFailed |= TRUE;
        }
    }

    return hasFailed;
}

/**
 * RSASSA-PSS signature generation.
 * All quoted step numbers refer to section 8.1.1 of the PKCS#1 v2.1 standard.
 *
 * The EMSA-PSS encoding of signData->message is written into
 * signData->signature and then transformed in place with the RSA private-key
 * operation (the CRT variant if EscRsa_ENABLE_CRT is defined).
 *
 * \param[in] signData  Signing parameters. The pointer itself and its members
 *                      randCtx, message, privKey and signature must not be null.
 *
 * \return TRUE if a required pointer is null; otherwise the failure flag of
 *         the encoding or, if the encoding succeeded, of the RSA operation.
 */
BOOL
EscPkcs1Pss_Sign(
    const EscPkcs1Pss_SignDataT* signData )
{
#   ifdef EscRsa_ENABLE_CRT
    EscRsa_KeyPairT keyPair;
#   endif
    BOOL hasFailed;

    if ( ( signData == 0 ) ||
         ( signData->randCtx == 0 ) ||
         ( signData->message == 0 ) ||
         ( signData->privKey == 0 ) ||
         ( signData->signature == 0 ) )
    {
        /* Missing input: report failure without touching any buffer */
        hasFailed = TRUE;
    }
    else
    {
        /* Since we support only moduli where the highest bit is set, we don't need to search for the first bit */

        /* 1. EMSA-PSS encoding, with emBits = modBits - 1 */
        hasFailed = EscPkcs1Pss_EmsaPssEncode( signData->randCtx, signData->message, signData->messageLen, signData->signature, ( EscRsa_KEY_BITS - 1U ) );

        if ( hasFailed == FALSE ) {
#   ifdef EscRsa_ENABLE_CRT
            /* Load the CRT key components; each one is half the key length */
            EscRsaFe_FromBytesBE( &keyPair.p, signData->privKey->p, EscRsa_KEY_BYTES / 2U );
            EscRsaFe_FromBytesBE( &keyPair.q, signData->privKey->q, EscRsa_KEY_BYTES / 2U );
            EscRsaFe_FromBytesBE( &keyPair.dmp1, signData->privKey->dmp1, EscRsa_KEY_BYTES / 2U );
            EscRsaFe_FromBytesBE( &keyPair.dmq1, signData->privKey->dmq1, EscRsa_KEY_BYTES / 2U );
            EscRsaFe_FromBytesBE( &keyPair.iqmp, signData->privKey->iqmp, EscRsa_KEY_BYTES / 2U );

            /* 2. RSA exponentiation (CRT), in place on the signature buffer */
            hasFailed = EscRsa_ModExpCrt( signData->signature, signData->signature, &keyPair );
#   else
            /* 2. RSA exponentiation, in place on the signature buffer */
            hasFailed = EscRsa_ModExpLong( signData->signature, signData->privKey->modulus, signData->privKey->privExp, signData->signature );
#   endif
        }
    }

    return hasFailed;
}

/**
 * RSASSA-PSS signature verification.
 * All quoted step numbers refer to section 8.1.2 of the PKCS#1 v2.1 standard.
 *
 * The signature is transformed with the RSA public-key operation into a local
 * buffer, which is then checked against verifyData->message with EMSA-PSS-Verify.
 *
 * \param[in] verifyData  Verification parameters. The pointer itself and its
 *                        members message, pubKey and signature must not be null.
 *
 * \return TRUE if a required pointer is null; otherwise the failure flag of
 *         the RSA operation or, if that succeeded, of the EMSA-PSS verification.
 */
BOOL
EscPkcs1Pss_Verify(
    const EscPkcs1Pss_VerifyDataT* verifyData )
{
    UINT8 recoveredEm[ EscPkcs1Pss_EMLEN ];
    BOOL hasFailed;

    if ( ( verifyData == 0 ) ||
         ( verifyData->message == 0 ) ||
         ( verifyData->pubKey == 0 ) ||
         ( verifyData->signature == 0 ) )
    {
        /* Missing input: nothing can be verified */
        hasFailed = TRUE;
    }
    else
    {
        /* Since we support only moduli where the highest bit is set, we don't need to search for the first bit */

        /* 2. RSA exponentiation: recover the encoded message from the signature */
        hasFailed = EscRsa_ModExp( verifyData->signature, verifyData->pubKey->modulus, verifyData->pubKey->pubExp, recoveredEm );

        /* 3. EMSA-PSS verification, with emBits = modBits - 1 */
        if ( hasFailed == FALSE ) {
            hasFailed = EscPkcs1Pss_EmsaPssVerify( verifyData->message, verifyData->messageLen, recoveredEm, ( EscRsa_KEY_BITS - 1U ) );
        }
    }

    return hasFailed;
}

#endif
/***************************************************************************
 * 6. END                                                                  *
 ***************************************************************************/
