/* warz, rob */

#include <windows.h>
#include <stdio.h>
#include "CheckRevisionEx.h"

/* Replace the bits of the lvalue x selected by mask with the corresponding
 * bits of bits. Every argument is parenthesised so compound expressions
 * (e.g. a mask built with |) expand correctly. Note: x and mask are
 * evaluated twice, so do not pass expressions with side effects. */
#define SET_MASKED_BITS(x, mask, bits) ((x) = (((x) & ~(mask)) | ((bits) & (mask))))

int patch_dword(unsigned long AddressToPatch, unsigned long Value) {
    unsigned long OldProtect = 0;
    if(!VirtualProtect((LPVOID)AddressToPatch, 4, PAGE_EXECUTE_READWRITE, &OldProtect))
		return 1;

    *(unsigned long*)AddressToPatch = Value;
    if(!VirtualProtect((LPVOID)AddressToPatch, 4, OldProtect, &OldProtect))
		return 1;
    return 0;
}

int patch_word(unsigned long AddressToPatch, WORD Value) {
    unsigned long OldProtect = 0;
    if(!VirtualProtect((LPVOID)AddressToPatch, 2, PAGE_EXECUTE_READWRITE, &OldProtect))
		return 1;

    *(WORD*)AddressToPatch = Value;
    if(!VirtualProtect((LPVOID)AddressToPatch, 2, OldProtect, &OldProtect))
		return 1;
    return 0;
}

/*
 * Writes the seed set seeds[file_lock] into fixed offsets of the loaded
 * backend module hLockdown. Index 1 is left unpatched.
 * NOTE(review): presumably backend.dll already carries seed set 1 — confirm.
 * file_lock must already be range-checked by the caller.
 * Returns the number of patch writes that failed (0 on success).
 */
int prepare_backend(HMODULE hLockdown, int file_lock) {
	if(file_lock == 1)
		return 0;

	seed_table seed = seeds[file_lock];
	DWORD base = (DWORD)hLockdown;

	// seed1 goes in as a dword at two sites and as its low byte
	// (zero-extended to a word) at two more; seed2 at one site.
	int failed = patch_dword(base + 0x2067, seed.seed1);
	failed += patch_dword(base + 0x206D, seed.seed2);
	failed += patch_dword(base + 0x2086, seed.seed1);
	failed += patch_word(base + 0x209F, seed.seed1 & 0xFF);
	failed += patch_word(base + 0x20B4, seed.seed1 & 0xFF);
	return failed;
}

/*
 * Returns the product version of file_path packed into one unsigned long,
 * one byte per component (most significant first):
 *   HIWORD(MS) << 24 | LOWORD(MS) << 16 | HIWORD(LS) << 8 | LOWORD(LS)
 * with each component truncated to its low 8 bits.
 * Returns 0 if the version resource cannot be read.
 */
unsigned long get_fileversion(const char *file_path) {
	unsigned long dwHandle = 0;
	unsigned long dwSize = GetFileVersionInfoSize(file_path, &dwHandle);
	if(!dwSize)
		return 0;

	unsigned char *lpbBuffer = (unsigned char *)VirtualAlloc(NULL, dwSize, MEM_COMMIT, PAGE_READWRITE);
	if(!lpbBuffer)
		return 0;

	unsigned long dwVersion = 0;
	VS_FIXEDFILEINFO *ffi = NULL;
	unsigned int ffiSize = 0;

	if(GetFileVersionInfo(file_path, 0, dwSize, lpbBuffer) &&
	   VerQueryValue(lpbBuffer, "\\", (LPVOID*)&ffi, &ffiSize) &&
	   ffi && ffiSize >= sizeof(VS_FIXEDFILEINFO)) {
		// Widen before shifting: (x & 0xFF) << 24 on an int overflows
		// the sign bit for components >= 128.
		dwVersion = ((unsigned long)(HIWORD(ffi->dwProductVersionMS) & 0xFF) << 24) |
					((unsigned long)(LOWORD(ffi->dwProductVersionMS) & 0xFF) << 16) |
					((unsigned long)(HIWORD(ffi->dwProductVersionLS) & 0xFF) << 8) |
					(unsigned long)(LOWORD(ffi->dwProductVersionLS) & 0xFF);
	}

	// Released on every path, including failed queries.
	VirtualFree(lpbBuffer, 0, MEM_RELEASE);
	return dwVersion;
}

/*
 * Starts a keyed SHA-1 computation using the HMAC pad pattern, with
 * `shuffled` (len bytes) as the key.
 *
 * Two 0x40-byte pads are built inside context->buffer:
 *   buffer + 0x44: 0x36 ^ key  (inner pad, hashed immediately below)
 *   buffer + 0x84: 0x5C ^ key  (outer pad, read later by double_hash)
 *
 * NOTE(review): assumes SHA1_CTX::buffer is a byte array of at least 0xC4
 * bytes, and that SHA1Init does not clear it. Both depend on the project's
 * SHA1 implementation; confirm.
 * NOTE(review): len is assumed to be <= 0x40 (callers pass 16-17). A longer
 * key would run past the inner pad into the outer pad.
 */
void init_context(SHA1_CTX *context, char *shuffled, int len) {
	memset(context->buffer + 0x44, 0x36, 0x40);
	memset(context->buffer + 0x84, 0x5C, 0x40);
	
	SHA1Init(context);
	// XOR the key into both pads; bytes beyond len keep the plain pad value.
	for (int x = 0; x < len; x++){
		context->buffer[0x44 + x] ^= shuffled[x];
		context->buffer[0x84 + x] ^= shuffled[x];
	}	
	
	// Hash the inner pad as the first block of the inner digest.
	SHA1Update(context, (const unsigned char *)context->buffer + 0x44, 0x40);	
}

/*
 * Feeds the video dump stored at path videobuf into context.
 *
 * The first 30720 bytes of the file are read. Then, for each of 48 rows
 * spaced 640 bytes apart, the first 0xD0 bytes of the row are hashed, in
 * row order.
 * Returns false, without touching context, if the file cannot be opened or
 * holds fewer than 30720 bytes. Returns true otherwise.
 */
bool hash_videodump(SHA1_CTX *context, const char *videobuf) {
	const int dump_size  = 30720;	// bytes read from the dump (48 * 640)
	const int row_count  = 48;
	const int row_stride = 640;
	const int row_bytes  = 0x0D0;	// bytes hashed from the start of each row

	unsigned char *video_image = new unsigned char[dump_size];

	FILE *dump = NULL;
	if(fopen_s(&dump, videobuf, "rb") || !dump) {
		delete [] video_image;
		return false;
	}

	size_t bytes_read = fread(video_image, 1, (size_t)dump_size, dump);
	fclose(dump);

	// A short dump would leave part of the buffer uninitialised; refuse to
	// hash indeterminate memory.
	if(bytes_read != (size_t)dump_size) {
		delete [] video_image;
		return false;
	}

	for(int row = 0; row < row_count; row++)
		SHA1Update(context, video_image + row * row_stride, row_bytes);

	delete [] video_image;
	return true;
}

/*
 * Completes the keyed SHA-1 started by init_context and writes the 20-byte
 * result:
 *   inner  = SHA1Final(context)                   (everything hashed so far)
 *   result = SHA1(buffer[0x84..0xC3] || inner)    (outer pad, then inner)
 *
 * NOTE(review): this relies on SHA1Final and SHA1Init leaving
 * context->buffer + 0x84 .. + 0xC3 intact. Many SHA-1 implementations wipe
 * the whole context in SHA1Final; confirm against the project's SHA1 code.
 * NOTE(review): it also updates the context from its own buffer, which is
 * assumed not to overlap the region SHA1Update copies into.
 */
void double_hash(SHA1_CTX *context, unsigned char result[20]) {
	unsigned char output[20];
	SHA1Final(context, output);
	SHA1Init(context);
	SHA1Update(context, (const unsigned char *)context->buffer + 0x84, 0x40);
	SHA1Update(context, output, 0x14);
	SHA1Final(context, result);
}

/*
 * One divide-by-255 step used by finish().
 *
 * Takes the low 16 bits of *arg_one as the dividend and stores
 *   *arg_one = dividend / 255   (quotient, 0..257)
 *   *arg_two = dividend % 255   (remainder, 0..254; incoming value ignored)
 * then returns the value left in *arg_one.
 *
 * This is the closed form of a register-by-register translation (edi/eax/
 * ecx with carry handling). That translation summed the two bytes with
 * end-around carry to obtain the remainder, subtracted it to leave a
 * multiple of 255, and recovered the quotient from that. The results are
 * identical for every 16-bit dividend.
 */
unsigned long finish_sub(unsigned long *arg_one, unsigned long *arg_two) {
	const unsigned long dividend = *arg_one & 0xffffUL;

	*arg_one = dividend / 255UL;
	*arg_two = dividend % 255UL;
	return *arg_one;
}

/*
 * Encodes a little-endian unsigned integer as base-255 digits.
 *
 * arg_heap holds arg_10h bytes of the integer, least significant byte
 * first. The integer is divided by 255 repeatedly in place, so all of its
 * bytes are zero on return. Each remainder r is emitted as the byte r + 1,
 * least significant digit first. Digits are therefore 1..255 and the
 * output never contains a 0 byte. No terminator is written.
 *
 * arg_output_length: on entry, the capacity of arg_output. On return, the
 * number of digits the value requires, which may exceed the capacity.
 *
 * Returns 1 if every digit fit in arg_output and 0 if the output was
 * truncated. A zero value produces no digits and returns 1.
 */
int finish(unsigned char *arg_output, unsigned int *arg_output_length, unsigned char *arg_heap, int arg_10h) {
	unsigned int capacity = *arg_output_length;
	unsigned int produced = 0;
	int all_fit = 1;
	int heap_length = arg_10h;

	for(;;) {
		// Drop zero high-order bytes. Once none remain, the value is zero
		// and every digit has been emitted. Never index below arg_heap[0].
		while(heap_length > 0 && arg_heap[heap_length - 1] == 0)
			heap_length--;
		if(heap_length <= 0)
			break;

		// Schoolbook division of the whole number by 255, most significant
		// byte first (the same step finish_sub performs). Each quotient
		// byte replaces its dividend byte; value <= 254*256+255, so the
		// quotient fits in a byte.
		unsigned long remainder = 0;
		for(int i = heap_length - 1; i >= 0; i--) {
			unsigned long value = (remainder << 8) + arg_heap[i];
			arg_heap[i] = (unsigned char)(value / 255);
			remainder = value % 255;
		}

		// remainder <= 254, so the emitted digit is 1..255.
		if(produced < capacity)
			arg_output[produced] = (unsigned char)(remainder + 1);
		else
			all_fit = 0;
		produced++;
	}

	// Report the digit count through the caller's pointer.
	*arg_output_length = produced;
	return all_fit;
}

/*
 * Computes the lockdown revision-check response.
 *
 * Steps:
 *  1. Load backend.dll and patch it with the seed set chosen by the digits
 *     just before the extension of file_lock.
 *  2. Feed the following into the keyed SHA-1 context, in order: the
 *     shuffled server hash (as key), the modules file_lock, file_game,
 *     file_strm and file_bttl, the video dump file_vdmp, and two fixed
 *     4-byte values.
 *  3. Finish with double_hash().
 *
 * Outputs (only fully valid when 0 is returned):
 *   out_version  - packed product version of file_game (get_fileversion)
 *   out_checksum - receives the first 4 bytes of the double hash
 *   out_digest   - base-255 encoding of the remaining 16 bytes, written by
 *                  finish(); up to 17 bytes, not NUL-terminated here
 *
 * Returns 0 on success, otherwise:
 *   1  file_lock missing, or no character before its last '.'
 *   2  the character before the extension is not a digit
 *   3  patching backend.dll failed
 *   5  product version could not be read from file_game
 *   6  server_hash missing or not 16-17 characters long
 *   7  server hash shuffle failed
 *   8  one of the four modules could not be loaded
 *   9  the video dump could not be read
 *   10 digest did not fit in 17 bytes
 *   11 backend.dll could not be loaded
 * Every module loaded here is unloaded again before returning.
 */
int WINAPI CheckRevisionEx(const char *file_game, const char *file_strm, const char *file_bttl,
						   const char *server_hash, unsigned long &out_version, unsigned long &out_checksum,
						   char *out_digest, const char *file_lock, const char *file_vdmp) {
	// Unloads a module when it goes out of scope, so every early return
	// below releases whatever has been loaded so far. Never copied.
	struct module_guard {
		HMODULE handle;
		module_guard() : handle(NULL) {}
		~module_guard() { release(); }
		void release() {
			if(handle) {
				FreeLibrary(handle);
				handle = NULL;
			}
		}
	};

	module_guard lockdown;
	lockdown.handle = LoadLibrary("backend.dll");
	if(!lockdown.handle)
		return 11;
	HMODULE hLockdown = lockdown.handle;

	// The lockdown index comes from the digit(s) immediately before the
	// last '.' of file_lock; a '1' in the tens place adds 10 (range 0..19).
	// Use the last '.' so dots in directory names are ignored.
	if(!file_lock)
		return 1;
	const char *dot = strrchr(file_lock, '.');
	if(!dot || dot == file_lock)
		return 1;

	int lock_index = dot[-1] - '0';
	if(lock_index < 0 || lock_index > 9)
		return 2;
	if(dot - file_lock >= 2 && dot[-2] == '1')
		lock_index += 10;

	// Patch backend.dll for the lockdown version requested by the server.
	if(prepare_backend(hLockdown, lock_index))
		return 3;

	// Entry points inside backend.dll for the two routines that are still
	// called in place rather than reimplemented here.
	lcr_shuf shuffle_serverhash = (lcr_shuf)((char*)hLockdown + 0x1A0E);
	lcr_hash hash_gamefile      = (lcr_hash)((char*)hLockdown + 0x24E1);

	// File version sent directly in the client-to-server protocol.
	unsigned long file_version = get_fileversion(file_game);
	if(!file_version)
		return 5;

	// Server hashes seen so far have been 16 or 17 characters long.
	if(!server_hash)
		return 6;
	int hash_length = (int)strlen(server_hash);
	if(hash_length < 16 || hash_length > 17)
		return 6;

	// Derive the 'shuffled' server hash used as the SHA-1 key.
	char shuffled_server_hash[64] = {0};
	if(!shuffle_serverhash(shuffled_server_hash, hash_length, server_hash, hash_length))
		return 7;

	SHA1_CTX context;
	init_context(&context, shuffled_server_hash, hash_length);

	// Load the modules in hashing order, stopping at the first failure.
	const char *file_paths[4] = { file_lock, file_game, file_strm, file_bttl };
	module_guard file_handles[4];
	for(int x = 0; x < 4; x++) {
		file_handles[x].handle = LoadLibrary(file_paths[x]);
		if(!file_handles[x].handle)
			return 8;
	}

	// Hash each module and unload it as soon as it has been hashed.
	for(int x = 0; x < 4; x++) {
		hash_gamefile(&context, file_handles[x].handle);
		file_handles[x].release();
	}

	if(!hash_videodump(&context, file_vdmp))
		return 9;

	// Return values of the battle and starcraft check functions.
	SHA1Update(&context, (const unsigned char *)"\x01\x00\x00\x00", 4);
	SHA1Update(&context, (const unsigned char *)"\x00\x00\x00\x00", 4);

	// The first 4 bytes become the checksum; the rest feed the digest.
	unsigned char dblhash_result[20];
	double_hash(&context, dblhash_result);
	memmove(&out_checksum, dblhash_result, 4);

	unsigned int digest_length = 17;
	if(!finish((unsigned char*)out_digest, &digest_length, dblhash_result + 4, 16))
		return 10;

	out_version = file_version;
	return 0;
}