#include <climits>
#include <cstdlib>
#include <cstring>
#include <fstream>

#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>

#include <openssl/pem.h>
#include <openssl/cms.h>
#include <openssl/err.h>
#include <openssl/rsa.h>

#include "util/swu_targetKey.h"
#include "util/swu_certificate.h"
#include "util/swu_constants.hpp"
#include "util/swu_filesystem.h"
#include "util/swu_sourceNor.hpp"
#include "util/swu_sourceSdc.hpp"

#include "util/swu_trace.h"
#ifdef VARIANT_S_FTR_ENABLE_TRC_GEN
#define ETG_DEFAULT_TRACE_CLASS TR_CLASS_SWUPDATE_UTIL
#include "trcGenProj/Header/swu_targetKey.cpp.trc.h"
#endif 

namespace swu {

// Short aliases for the nested constant classes used in this file:
// MMC supplies the NOR (MTD) key offsets, buffer size and erased/zeroed masks,
// SDC supplies the primary and alternative private-key file paths.
typedef swu::Constants::Mmc MMC;
typedef swu::Constants::SDC SDC;

// --------------------------------------------------------- class CTargetKeyIf
// Releases the RSA key parsed by readPemOrDer(), if any was ever loaded.
// NOTE(review): the key pointer is owned exclusively by this object, so a
// copied CTargetKeyIf would double-free it -- confirm the class is non-copyable.
CTargetKeyIf::~CTargetKeyIf()
{
   if (0 != _rsa) {
      RSA_free(_rsa);
   }
}

bool CTargetKeyIf::decrypt(const std::vector<SWU_BYTE>& encrypted, std::string &decrypted)
{
   if( (!load()) || (0 == _rsa) ) {
      ETG_TRACE_ERR(("RSA key is not set"));
      return false;
   }
   SWU_BYTE* decBuf = new SWU_BYTE[RSA_size(_rsa)];
   int len =  RSA_private_decrypt(static_cast<int> (encrypted.size() ), &encrypted[0], decBuf, _rsa, RSA_PKCS1_OAEP_PADDING) ;
   if(-1 == len) {
      delete[] decBuf;
      ETG_TRACE_ERR(("RSA_private_decrypt failure"));
      return false;
   }
   decrypted.assign((char *) decBuf, len);
   
   delete[] decBuf;

   return true;
}

bool CTargetKeyIf::decryptCMS(const std::string& encrypted, std::string &decrypted)
{
   if( (!load()) || (0 == _rsa) ) {
      ETG_TRACE_ERR(("RSA key is not set"));
      return false;
   }

   BIO* enc_bio = 0;
   BIO* dec_bio = 0;
   EVP_PKEY* pkey = 0;
   CMS_ContentInfo* cms = 0;
   char* enc_s = 0;
   char* dec_s = 0;
   long len = 0;

   bool result = true;
   do { // this loop emulates "goto cleanup;" in error handling
      pkey = EVP_PKEY_new();
      if (pkey == 0) {
         ETG_TRACE_ERR(("Can't create EVP key object"));
         result = false; break;
      }
      
      int res = EVP_PKEY_set1_RSA(pkey, _rsa);
      if ( res < 1 ) {
         ETG_TRACE_ERR(("Can't set decryption key."));
         result = false; break;
      }

      enc_s = strdup(encrypted.c_str());
      enc_bio = BIO_new_mem_buf(enc_s, -1);
      dec_bio = BIO_new(BIO_s_mem());
      if (0 == enc_bio || 0 == dec_bio) {
         ETG_TRACE_ERR(("Can't create memory BIO."));
         result = false; break;
      }
     
      cms = SMIME_read_CMS(enc_bio, NULL);
      if (0 == cms) {
         ETG_TRACE_ERR(("Can't parse CMS for decryption. %s", decrypted.c_str()));
         result = false; break;
      }

      if(1 > CMS_decrypt(cms, pkey, NULL, NULL, dec_bio, 0)) {
         ETG_TRACE_ERR(("Failed to decrypt CMS."));
         result = false; break;
      }
      len = BIO_get_mem_data(dec_bio, &dec_s);
      if (len < 1) {
         ETG_TRACE_ERR(("Could not get data from decrypted BIO."));
         result = false; break;
      }
      decrypted.assign(dec_s, (size_t)len);
   } while (false);

   if (!result) {
      ETG_TRACE_ERR(("OpenSSL error: %s", ERR_error_string(ERR_get_error(), NULL)));
   }

   // cleanup:
   if (cms) CMS_ContentInfo_free(cms);
   if (pkey) EVP_PKEY_free(pkey);
   if (dec_bio) BIO_free(dec_bio);
   if (enc_bio) BIO_free(enc_bio);
   if (enc_s) free(enc_s);

   return result;
}

bool CTargetKeyIf::readPemOrDer (const std::vector< SWU_BYTE >& keyData)
{
   if (isPEM(keyData)) {
      // we explicitly cast const away here, since BIO_new_mem_buf accepts 
      // no const void* as parameter, although it treats it as read-only
      // memory. This will change with OpenSSL 1.0.2g
      SWU_BYTE *data = const_cast<SWU_BYTE*>( &(keyData[0]) );
      BIO *keyBio = BIO_new_mem_buf(data, -1);
      if (! keyBio) {
         ETG_TRACE_ERR(( "Cannot create private key BIO." ));
         return false;
      }
      _rsa = PEM_read_bio_RSAPrivateKey(keyBio, 0, 0, 0);
      BIO_free(keyBio);
      if (! _rsa) {
         ETG_TRACE_ERR(("Error when parsing PEM RSA key."));
         return false;
      }
   }
   else {
      const SWU_BYTE *data = &(keyData[0]);
      _rsa = d2i_RSAPrivateKey(0, &data, keyData.size());
      if ( ! _rsa ) {
         ETG_TRACE_ERR(("Error when parsing DER RSA key."));
         return false;
      }
   }
   return true;
}
// --------------------------------------------------------- class NORTargetKey
// Nothing to set up here; the key is read from NOR on demand by load().
NORTargetKey::NORTargetKey()
{
}

bool NORTargetKey::load()
{
   if (_loaded) { return true; }

   SourceNOR nor;
   bool readSuccessfully = false;
   bool isValid = false;
   std::vector<SWU_BYTE> data;

   ETG_TRACE_USR3(( "Trying to read target key from first memory location." ));
   readSuccessfully = nor.read(
         data,
         MMC::MTD_TARGET_KEY1_OFFSET,
         MMC::MTD_KEY_BUFFER_SIZE);
   if (readSuccessfully and setTargetKey(data)) {
      return true;
   }

   ETG_TRACE_USR3(( "Trying to read target key from second memory location." ));
   readSuccessfully = nor.read(
         data,
         MMC::MTD_TARGET_KEY2_OFFSET,
         MMC::MTD_KEY_BUFFER_SIZE);
   if (readSuccessfully and setTargetKey(data)) {
      return true;
   }

   return false;
}

bool NORTargetKey::setTargetKey(const std::vector<SWU_BYTE>& data)
{
   std::vector<SWU_BYTE>::const_iterator it = data.begin();
   for ( ; it != data.end() and *it == MMC::MTD_ERASED_MASK; ++it );
   if(it == data.end()) {
      ETG_TRACE_USR3(("Target key is erased (all bytes == 0xFF)")); 
      ETG_TRACE_USR3(("Length of data: %i", data.size())); 
      return false;
   }

   it = data.begin();
   for ( ; it != data.end() and *it == MMC::MTD_ZEROED_MASK; ++it );
   if(it == data.end()) {
      ETG_TRACE_USR3(("Target key is zeroed (all bytes == 0x00)"));
      ETG_TRACE_USR3(("Length of data: %i", data.size())); 
      return false;
   }

   _loaded = readPemOrDer(data);
   return _loaded;
}

// --------------------------------------------------------- class SdcTargetKey
// Nothing to set up here; the key is read from the SDC file on demand by load().
SdcTargetKey::SdcTargetKey()
{
}

bool SdcTargetKey::load ()
{
   if (_loaded) { return true; }

   SourceSdc sdc;
   std::vector<SWU_BYTE> data;
   char const* privkey = SDC::PATH_PRIVKEY;

   // We have two possible paths. If a path exists, we take it.
   // If the first path exists, but the key can't be read from it, we don't try
   // the second path.
   if ( ! swu::exists(privkey) ) {
      ETG_TRACE_USR3(("No private key installed in path %s", privkey));
      privkey = SDC::PSA_PATH_PRIVKEY;
      ETG_TRACE_USR3(("Trying alternative path %s", privkey));
      if ( ! swu::exists(privkey) ) {
         ETG_TRACE_USR3(("No private key installed."));
         return false;
      }
   }

   if (! sdc.read(data, privkey) ) {
      ETG_TRACE_ERR(( "Could not extract Private Key from SDC encrypted file %s", privkey));
      return false;
   }

   _loaded = readPemOrDer(data);
   return _loaded;
}

// --------------------------------------------------------- class FileTargetKey
// Nothing to set up here; the key file is read on demand by load().
FileTargetKey::FileTargetKey()
{
}

/**
 * Load the target private key from a plain file.
 *
 * SDC::PATH_PRIVKEY is used when it exists; otherwise SDC::PSA_PATH_PRIVKEY is
 * tried. If the primary file exists but cannot be read, the alternative path
 * is deliberately not attempted.
 *
 * @return true once a key has been parsed into this object.
 */
bool FileTargetKey::load() 
{
   if (_loaded) {
      return true;
   }

   std::vector<SWU_BYTE> keyBytes;

   char const* keyPath = SDC::PATH_PRIVKEY;
   const bool primaryExists = swu::exists(keyPath);
   if (!primaryExists) {
      ETG_TRACE_USR3(("No private key installed in path %s", keyPath));
      keyPath = SDC::PSA_PATH_PRIVKEY;
      ETG_TRACE_USR3(("Trying alternative path %s", keyPath));
      const bool alternativeExists = swu::exists(keyPath);
      if (!alternativeExists) {
         ETG_TRACE_USR3(("No private key installed."));
         return false;
      }
   }

   const bool fileRead = swu::loadFile(keyPath, keyBytes);
   if (!fileRead) {
      ETG_TRACE_ERR(( "Could not read the private key from file %s", keyPath));
      return false;
   }

   _loaded = readPemOrDer(keyBytes);
   return _loaded;
}

}
