forked from eden-emu/eden
		
	Extract mbedtls to cpp file
This commit is contained in:
		
							parent
							
								
									83c3ae8be8
								
							
						
					
					
						commit
						22342487e8
					
				
					 5 changed files with 126 additions and 86 deletions
				
			
		|  | @ -739,7 +739,7 @@ const std::string& GetUserPath(UserPath path, const std::string& new_path) { | |||
| std::string GetHactoolConfigurationPath() { | ||||
| #ifdef _WIN32 | ||||
|     char path[MAX_PATH]; | ||||
|     if (SHGetFolderPathA(NULL, CSIDL_PROFILE, NULL, 0, path) != S_OK) | ||||
|     if (SHGetFolderPathA(nullptr, CSIDL_PROFILE, nullptr, 0, path) != S_OK) | ||||
|         return ""; | ||||
|     std::string local_path = Common::StringFromFixedZeroTerminatedBuffer(path, MAX_PATH); | ||||
|     return local_path + "\\.switch"; | ||||
|  |  | |||
|  | @ -12,6 +12,7 @@ add_library(core STATIC | |||
|     core_timing.h | ||||
|     core_timing_util.cpp | ||||
|     core_timing_util.h | ||||
|     crypto/aes_util.cpp | ||||
|     crypto/aes_util.h | ||||
|     crypto/encryption_layer.cpp | ||||
|     crypto/encryption_layer.h | ||||
|  |  | |||
|  | @ -2,5 +2,103 @@ | |||
| // Licensed under GPLv2 or any later version
 | ||||
| // Refer to the license.txt file included.
 | ||||
| 
 | ||||
| namespace Crypto { | ||||
| } // namespace Crypto
 | ||||
| #include "core/crypto/aes_util.h" | ||||
| #include "mbedtls/cipher.h" | ||||
| 
 | ||||
| namespace Core::Crypto { | ||||
| static_assert(static_cast<size_t>(Mode::CTR) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_CTR), "CTR mode is incorrect."); | ||||
| static_assert(static_cast<size_t>(Mode::ECB) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_ECB), "ECB mode is incorrect."); | ||||
| static_assert(static_cast<size_t>(Mode::XTS) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_XTS), "XTS mode is incorrect."); | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| Crypto::AESCipher<Key, KeySize>::AESCipher(Key key, Mode mode) { | ||||
|     mbedtls_cipher_init(encryption_context.get()); | ||||
|     mbedtls_cipher_init(decryption_context.get()); | ||||
| 
 | ||||
|     ASSERT_MSG((mbedtls_cipher_setup( | ||||
|             encryption_context.get(), | ||||
|             mbedtls_cipher_info_from_type(static_cast<mbedtls_cipher_type_t>(mode))) || | ||||
|                 mbedtls_cipher_setup(decryption_context.get(), | ||||
|                                      mbedtls_cipher_info_from_type( | ||||
|                                              static_cast<mbedtls_cipher_type_t>(mode)))) == 0, | ||||
|                "Failed to initialize mbedtls ciphers."); | ||||
| 
 | ||||
|     ASSERT( | ||||
|             !mbedtls_cipher_setkey(encryption_context.get(), key.data(), KeySize * 8, MBEDTLS_ENCRYPT)); | ||||
|     ASSERT( | ||||
|             !mbedtls_cipher_setkey(decryption_context.get(), key.data(), KeySize * 8, MBEDTLS_DECRYPT)); | ||||
|     //"Failed to set key on mbedtls ciphers.");
 | ||||
| } | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| AESCipher<Key, KeySize>::~AESCipher() { | ||||
|     mbedtls_cipher_free(encryption_context.get()); | ||||
|     mbedtls_cipher_free(decryption_context.get()); | ||||
| } | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| void AESCipher<Key, KeySize>::SetIV(std::vector<u8> iv) { | ||||
|     ASSERT_MSG((mbedtls_cipher_set_iv(encryption_context.get(), iv.data(), iv.size()) || | ||||
|                 mbedtls_cipher_set_iv(decryption_context.get(), iv.data(), iv.size())) == 0, | ||||
|                "Failed to set IV on mbedtls ciphers."); | ||||
| } | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| void AESCipher<Key, KeySize>::Transcode(const u8* src, size_t size, u8* dest, Op op)  { | ||||
|     size_t written = 0; | ||||
| 
 | ||||
|     const auto context = op == Op::Encrypt ? encryption_context.get() : decryption_context.get(); | ||||
| 
 | ||||
|     mbedtls_cipher_reset(context); | ||||
| 
 | ||||
|     if (mbedtls_cipher_get_cipher_mode(context) == MBEDTLS_MODE_XTS) { | ||||
|         mbedtls_cipher_update(context, src, size, | ||||
|                               dest, &written); | ||||
|         if (written != size) | ||||
|             LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.", | ||||
|                         size, written); | ||||
|     } else { | ||||
|         const auto block_size = mbedtls_cipher_get_block_size(context); | ||||
| 
 | ||||
|         for (size_t offset = 0; offset < size; offset += block_size) { | ||||
|             auto length = std::min<size_t>(block_size, size - offset); | ||||
|             mbedtls_cipher_update(context, src + offset, length, | ||||
|                                   dest + offset, &written); | ||||
|             if (written != length) | ||||
|                 LOG_WARNING(Crypto, | ||||
|                             "Not all data was decrypted requested={:016X}, actual={:016X}.", | ||||
|                             length, written); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     mbedtls_cipher_finish(context, nullptr, nullptr); | ||||
| } | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| void AESCipher<Key, KeySize>::XTSTranscode(const u8* src, size_t size, u8* dest, size_t sector_id, size_t sector_size, | ||||
|                                            Op op) { | ||||
|     if (size % sector_size > 0) { | ||||
|         LOG_CRITICAL(Crypto, "Data size must be a multiple of sector size."); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     for (size_t i = 0; i < size; i += sector_size) { | ||||
|         SetIV(CalculateNintendoTweak(sector_id++)); | ||||
|         Transcode<u8, u8>(src + i, sector_size, | ||||
|                           dest + i, op); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template<typename Key, size_t KeySize> | ||||
| std::vector<u8> AESCipher<Key, KeySize>::CalculateNintendoTweak(size_t sector_id) { | ||||
|     std::vector<u8> out(0x10); | ||||
|     for (size_t i = 0xF; i <= 0xF; --i) { | ||||
|         out[i] = sector_id & 0xFF; | ||||
|         sector_id >>= 8; | ||||
|     } | ||||
|     return out; | ||||
| } | ||||
| 
 | ||||
| template class AESCipher<Key128>; | ||||
| template class AESCipher<Key256>; | ||||
| } | ||||
|  | @ -6,113 +6,53 @@ | |||
| 
 | ||||
| #include "common/assert.h" | ||||
| #include "core/file_sys/vfs.h" | ||||
| #include "mbedtls/cipher.h" | ||||
| 
 | ||||
| namespace Crypto { | ||||
| namespace Core::Crypto { | ||||
| 
 | ||||
| enum class Mode { | ||||
|     CTR = MBEDTLS_CIPHER_AES_128_CTR, | ||||
|     ECB = MBEDTLS_CIPHER_AES_128_ECB, | ||||
|     XTS = MBEDTLS_CIPHER_AES_128_XTS, | ||||
|     CTR = 11, | ||||
|     ECB = 2, | ||||
|     XTS = 70, | ||||
| }; | ||||
| 
 | ||||
| enum class Op { | ||||
|     ENCRYPT, | ||||
|     DECRYPT, | ||||
|     Encrypt, | ||||
|     Decrypt, | ||||
| }; | ||||
| 
 | ||||
| struct mbedtls_cipher_context_t; | ||||
| 
 | ||||
| template <typename Key, size_t KeySize = sizeof(Key)> | ||||
| struct AESCipher { | ||||
| class AESCipher { | ||||
|     static_assert(std::is_same_v<Key, std::array<u8, KeySize>>, "Key must be std::array of u8."); | ||||
|     static_assert(KeySize == 0x10 || KeySize == 0x20, "KeySize must be 128 or 256."); | ||||
| 
 | ||||
|     AESCipher(Key key, Mode mode) { | ||||
|         mbedtls_cipher_init(&encryption_context); | ||||
|         mbedtls_cipher_init(&decryption_context); | ||||
| public: | ||||
|     AESCipher(Key key, Mode mode); | ||||
| 
 | ||||
|         ASSERT_MSG((mbedtls_cipher_setup( | ||||
|                         &encryption_context, | ||||
|                         mbedtls_cipher_info_from_type(static_cast<mbedtls_cipher_type_t>(mode))) || | ||||
|                     mbedtls_cipher_setup(&decryption_context, | ||||
|                                          mbedtls_cipher_info_from_type( | ||||
|                                              static_cast<mbedtls_cipher_type_t>(mode)))) == 0, | ||||
|                    "Failed to initialize mbedtls ciphers."); | ||||
|     ~AESCipher(); | ||||
| 
 | ||||
|         ASSERT( | ||||
|             !mbedtls_cipher_setkey(&encryption_context, key.data(), KeySize * 8, MBEDTLS_ENCRYPT)); | ||||
|         ASSERT( | ||||
|             !mbedtls_cipher_setkey(&decryption_context, key.data(), KeySize * 8, MBEDTLS_DECRYPT)); | ||||
|         //"Failed to set key on mbedtls ciphers.");
 | ||||
|     } | ||||
| 
 | ||||
|     ~AESCipher() { | ||||
|         mbedtls_cipher_free(&encryption_context); | ||||
|         mbedtls_cipher_free(&decryption_context); | ||||
|     } | ||||
| 
 | ||||
|     void SetIV(std::vector<u8> iv) { | ||||
|         ASSERT_MSG((mbedtls_cipher_set_iv(&encryption_context, iv.data(), iv.size()) || | ||||
|                     mbedtls_cipher_set_iv(&decryption_context, iv.data(), iv.size())) == 0, | ||||
|                    "Failed to set IV on mbedtls ciphers."); | ||||
|     } | ||||
|     void SetIV(std::vector<u8> iv); | ||||
| 
 | ||||
|     template <typename Source, typename Dest> | ||||
|     void Transcode(const Source* src, size_t size, Dest* dest, Op op) { | ||||
|         size_t written = 0; | ||||
| 
 | ||||
|         const auto context = op == Op::ENCRYPT ? &encryption_context : &decryption_context; | ||||
| 
 | ||||
|         mbedtls_cipher_reset(context); | ||||
| 
 | ||||
|         if (mbedtls_cipher_get_cipher_mode(context) == MBEDTLS_MODE_XTS) { | ||||
|             mbedtls_cipher_update(context, reinterpret_cast<const u8*>(src), size, | ||||
|                                   reinterpret_cast<u8*>(dest), &written); | ||||
|             if (written != size) | ||||
|                 LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.", | ||||
|                             size, written); | ||||
|         } else { | ||||
|             const auto block_size = mbedtls_cipher_get_block_size(context); | ||||
| 
 | ||||
|             for (size_t offset = 0; offset < size; offset += block_size) { | ||||
|                 auto length = std::min<size_t>(block_size, size - offset); | ||||
|                 mbedtls_cipher_update(context, reinterpret_cast<const u8*>(src) + offset, length, | ||||
|                                       reinterpret_cast<u8*>(dest) + offset, &written); | ||||
|                 if (written != length) | ||||
|                     LOG_WARNING(Crypto, | ||||
|                                 "Not all data was decrypted requested={:016X}, actual={:016X}.", | ||||
|                                 length, written); | ||||
|             } | ||||
|         Transcode(reinterpret_cast<const u8*>(src), size, reinterpret_cast<u8*>(dest), op); | ||||
|     } | ||||
| 
 | ||||
|         mbedtls_cipher_finish(context, nullptr, nullptr); | ||||
|     } | ||||
|     void Transcode(const u8* src, size_t size, u8* dest, Op op); | ||||
| 
 | ||||
|     template <typename Source, typename Dest> | ||||
|     void XTSTranscode(const Source* src, size_t size, Dest* dest, size_t sector_id, | ||||
|                       size_t sector_size, Op op) { | ||||
|         if (size % sector_size > 0) { | ||||
|             LOG_CRITICAL(Crypto, "Data size must be a multiple of sector size."); | ||||
|             return; | ||||
|         XTSTranscode(reinterpret_cast<const u8*>(src), size, reinterpret_cast<u8*>(dest), sector_id, sector_size, op); | ||||
|     } | ||||
| 
 | ||||
|         for (size_t i = 0; i < size; i += sector_size) { | ||||
|             SetIV(CalculateNintendoTweak(sector_id++)); | ||||
|             Transcode<u8, u8>(reinterpret_cast<const u8*>(src) + i, sector_size, | ||||
|                               reinterpret_cast<u8*>(dest) + i, op); | ||||
|         } | ||||
|     } | ||||
|     void XTSTranscode(const u8* src, size_t size, u8* dest, size_t sector_id, size_t sector_size, Op op); | ||||
| 
 | ||||
| private: | ||||
|     mbedtls_cipher_context_t encryption_context; | ||||
|     mbedtls_cipher_context_t decryption_context; | ||||
|     std::unique_ptr<mbedtls_cipher_context_t> encryption_context; | ||||
|     std::unique_ptr<mbedtls_cipher_context_t> decryption_context; | ||||
| 
 | ||||
|     static std::vector<u8> CalculateNintendoTweak(size_t sector_id) { | ||||
|         std::vector<u8> out(0x10); | ||||
|         for (size_t i = 0xF; i <= 0xF; --i) { | ||||
|             out[i] = sector_id & 0xFF; | ||||
|             sector_id >>= 8; | ||||
|         } | ||||
|         return out; | ||||
|     } | ||||
|     static std::vector<u8> CalculateNintendoTweak(size_t sector_id); | ||||
| }; | ||||
| } // namespace Crypto
 | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ | |||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "core/loader/loader.h" | ||||
| #include "common/common_funcs.h" | ||||
| #include "common/common_types.h" | ||||
| #include "common/swap.h" | ||||
|  | @ -108,7 +109,7 @@ private: | |||
| 
 | ||||
|     Crypto::Key128 GetKeyAreaKey(NCASectionCryptoType type); | ||||
| 
 | ||||
|     VirtualFile Decrypt(NCASectionHeader header, VirtualFile in, size_t starting_offset); | ||||
|     VirtualFile Decrypt(NCASectionHeader header, VirtualFile in, u64 starting_offset); | ||||
| }; | ||||
| 
 | ||||
| } // namespace FileSys
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Zach Hilman
						Zach Hilman