diff --git a/firmware_tools/ghidra/vxhunter_analysis.py b/firmware_tools/ghidra/vxhunter_analysis.py index b9b5862..ad9d56d 100755 --- a/firmware_tools/ghidra/vxhunter_analysis.py +++ b/firmware_tools/ghidra/vxhunter_analysis.py @@ -5,357 +5,293 @@ from ghidra.program.model.symbol import RefType, SourceType -def analyze_bss(): - print('{:-^60}'.format('analyze bss info')) - target_function = getFunction("bzero") - if not target_function: - target_function = getFunction("_bzero") - if target_function: - parms_data = dump_call_parm_value(call_address=target_function.getEntryPoint(), search_functions=['sysStart', - 'usrInit', - '_sysStart', - '_usrInit', - ]) - for call_addr in parms_data: - call_parms = parms_data[call_addr] - # print(call_parms) - bss_start_address = call_parms['parms']['parm_1']['parm_value'] - print("bss_start_address: {}".format(hex(bss_start_address))) - bss_length = call_parms['parms']['parm_2']['parm_value'] - if not bss_length: - print("Can't calculate bss length.") - return - print("bss_end_address: {}".format(hex(bss_start_address + bss_length - 1))) - print("bss_length: {}".format(hex(bss_length))) - if not is_address_in_current_program(toAddr(bss_start_address)): - print("bss block not in current program, adding...") - if create_initialized_block(block_name=".bss", start_address=toAddr(bss_start_address), - length=bss_length): - print("bss block created") - else: - print("Can't create bss block, you can create it manually") - - else: - print("Can't find bzero function in firmware") - - print('{}\r\n'.format("-" * 60)) - - -def analyze_login_accouts(): - hard_coded_accounts = {} - print("{:-^60}".format("analyze loginUserAdd function")) - target_function = getFunction("loginUserAdd") - if not target_function: - target_function = getFunction("_loginUserAdd") - if target_function: - parms_data = dump_call_parm_value(target_function.getEntryPoint()) - for call_addr in parms_data: - call_parms = parms_data[call_addr] - parm_data_string = "" - user_name = call_parms["parms"]["parm_1"]["parm_data"] - if isinstance(user_name, DataDB): - user_name = user_name.value - pass_hash = call_parms["parms"]["parm_2"]["parm_data"] - if isinstance(pass_hash, DataDB): - pass_hash = pass_hash.value - if user_name or pass_hash: - hard_coded_accounts[call_parms["call_addr"]] = { - "user_name": user_name, - "pass_hash": pass_hash - } - - for parm in sorted(call_parms['parms'].keys()): - parm_value = call_parms['parms'][parm]['parm_value'] - parm_data = call_parms['parms'][parm]['parm_data'] - if parm_value: - parm_data_string += "{}({:#010x}), ".format(parm_data, parm_value) - else: - # Handle None type - parm_data_string += "{}({}), ".format(parm_data, parm_value) - # remove end ', ' - parm_data_string = parm_data_string.strip(', ') - logger.debug("{}({}) at {:#010x} in {}({:#010x})".format(target_function.name, parm_data_string, - call_parms['call_addr'].offset, - call_parms['refrence_function_name'], - call_parms['refrence_function_addr'].offset - )) - else: - print("Can't find loginUserAdd function in firmware") - - print("Found {} hard coded accounts".format(len(hard_coded_accounts))) - for account in hard_coded_accounts: - print("user_name: {}, pass_hash: {}, added at address: {}".format( - hard_coded_accounts[account]['user_name'], - hard_coded_accounts[account]['pass_hash'], - hex(account.offset) - )) - - print('{}\r\n'.format("-" * 60)) - - -def analyze_service(): - service_status = {} - print('{:-^60}'.format('analyze services')) - for service in sorted(vxworks_service_keyword.keys()): - service_status[service] = "Not available" - for service_function in vxworks_service_keyword[service]: - target_function = getFunction(service_function) - if not target_function: - target_function = getFunction("_{}".format(service_function)) - if target_function: - # print("Found {} in firmware, service {} might available".format(service_function, service)) - service_status[service] = "available" - - for service in sorted(service_status.items(), key=lambda x: x[1], reverse=True): - print('{}: {}'.format(service[0], service[1])) - print('{}\r\n'.format("-" * 60)) - - -def add_symbol(symbol_name, symbol_name_address, symbol_address, symbol_type): - symbol_name_address = toAddr(symbol_name_address) - symbol_address = toAddr(symbol_address) - - # Get symbol_name - if getDataAt(symbol_name_address): - logger.debug("removeDataAt: %s" % symbol_name_address) - removeDataAt(symbol_name_address) - - if getInstructionAt(symbol_address): - logger.debug("removeInstructionAt: %s" % symbol_address) - removeInstructionAt(symbol_address) - - try: - symbol_name_string = createAsciiString(symbol_name_address).getValue() - logger.debug("symbol_name_string: %s" % symbol_name_string) - - except CodeUnitInsertionException as err: - logger.debug("Got CodeUnitInsertionException: {}".format(err)) - symbol_name_string = symbol_name - - except: - return - - # Demangle symName - try: - # Demangle symName - sym_demangled_name = None - if can_demangle: - try: - sym_demangled = demangler.demangle(symbol_name_string, True) - - if not sym_demangled: - # some mangled function name didn't start with mangled prefix - sym_demangled = demangler.demangle(symbol_name_string, False) - - if not sym_demangled: - # Temp fix to handle _ prefix function name by remove _ prefix before demangle - sym_demangled = demangler.demangle(symbol_name_string[1:], False) - - if sym_demangled: - sym_demangled_name = sym_demangled.getSignature(False) - - except DemangledException as err: - sym_demangled_name = None - - if sym_demangled_name: - logger.debug("sym_demangled_name: %s" % sym_demangled_name) - - if symbol_name_string and (symbol_type in need_create_function): - logger.debug("Start disassemble function %s at address %s" % (symbol_name_string, symbol_address.toString())) - disassemble(symbol_address) - # TODO: find out why createFunction didn't set the function name. - function = createFunction(symbol_address, symbol_name_string) - # use createLabel to rename function for now. - createLabel(symbol_address, symbol_name_string, True) - if function and sym_demangled_name: - # Add demangled string to comment - codeUnit = listing.getCodeUnitAt(symbol_address) - codeUnit.setComment(codeUnit.PLATE_COMMENT, sym_demangled_name) - # Rename function - function_return, function_name, function_parameters = demangle_function(sym_demangled_name) - logger.debug("Demangled function name is: %s" % function_name) - logger.debug("Demangled function return is: %s" % function_return) - logger.debug("Demangled function parameters is: %s" % function_parameters) - function.setName(function_name, SourceType.USER_DEFINED) - # Todo: Add parameters later +class VxAnalyzer(object): + def __init__(self, logger=None): + self._vx_version = None + + if logger is None: + self.logger = logging.getLogger('target') + self.logger.setLevel(logging.INFO) + consolehandler = logging.StreamHandler() + console_format = logging.Formatter('[%(levelname)-8s][%(module)s.%(funcName)s] %(message)s') + consolehandler.setFormatter(console_format) + self.logger.addHandler(consolehandler) else: - createLabel(symbol_address, symbol_name_string, True) - if sym_demangled_name: - codeUnit = listing.getCodeUnitAt(symbol_address) - codeUnit.setComment(codeUnit.PLATE_COMMENT, sym_demangled_name) + self.logger = logger + + def analyze_bss(self): + print('{:-^60}'.format('analyze bss info')) + target_function = getFunction("bzero") + if not target_function: + target_function = getFunction("_bzero") + if target_function: + parms_data = dump_call_parm_value(call_address=target_function.getEntryPoint(), search_functions=['sysStart', + 'usrInit', + '_sysStart', + '_usrInit', + ]) + for call_addr in parms_data: + call_parms = parms_data[call_addr] + # print(call_parms) + bss_start_address = call_parms['parms']['parm_1']['parm_value'] + print("bss_start_address: {}".format(hex(bss_start_address))) + bss_length = call_parms['parms']['parm_2']['parm_value'] + if not bss_length: + print("Can't calculate bss length.") + return + print("bss_end_address: {}".format(hex(bss_start_address + bss_length - 1))) + print("bss_length: {}".format(hex(bss_length))) + if not is_address_in_current_program(toAddr(bss_start_address)): + print("bss block not in current program, adding...") + if create_initialized_block(block_name=".bss", start_address=toAddr(bss_start_address), + length=bss_length): + print("bss block created") + else: + print("Can't create bss block, you can create it manually") - except Exception as err: - logger.debug("Create function Failed: %s" % err) + else: + print("Can't find bzero function in firmware") - except: - logger.debug("Create function Failed: Java error") + print('{}\r\n'.format("-" * 60)) + def analyze_login_accouts(self): + hard_coded_accounts = {} + print("{:-^60}".format("analyze loginUserAdd function")) + target_function = getFunction("loginUserAdd") + if not target_function: + target_function = getFunction("_loginUserAdd") + if target_function: + parms_data = dump_call_parm_value(target_function.getEntryPoint()) + for call_addr in parms_data: + call_parms = parms_data[call_addr] + parm_data_string = "" + user_name = call_parms["parms"]["parm_1"]["parm_data"] + if isinstance(user_name, DataDB): + user_name = user_name.value + pass_hash = call_parms["parms"]["parm_2"]["parm_data"] + if isinstance(pass_hash, DataDB): + pass_hash = pass_hash.value + if user_name or pass_hash: + hard_coded_accounts[call_parms["call_addr"]] = { + "user_name": user_name, + "pass_hash": pass_hash + } + + for parm in sorted(call_parms['parms'].keys()): + parm_value = call_parms['parms'][parm]['parm_value'] + parm_data = call_parms['parms'][parm]['parm_data'] + if parm_value: + parm_data_string += "{}({:#010x}), ".format(parm_data, parm_value) + else: + # Handle None type + parm_data_string += "{}({}), ".format(parm_data, parm_value) + # remove end ', ' + parm_data_string = parm_data_string.strip(', ') + logger.debug("{}({}) at {:#010x} in {}({:#010x})".format(target_function.name, parm_data_string, + call_parms['call_addr'].offset, + call_parms['refrence_function_name'], + call_parms['refrence_function_addr'].offset + )) + else: + print("Can't find loginUserAdd function in firmware") -def fix_symbol_by_chains(head, tail, vx_version): - symbol_interval = 0x10 - dt = vx_5_symtbl_dt - if vx_version == 6: - symbol_interval = 20 - dt = vx_6_symtbl_dt - ea = head - while True: - prev_symbol_addr = toAddr(getInt(ea)) - symbol_name_address = getInt(ea.add(0x04)) - symbol_dest_address = getInt(ea.add(0x08)) - symbol_type = getByte(ea.add(symbol_interval - 2)) + print("Found {} hard coded accounts".format(len(hard_coded_accounts))) + for account in hard_coded_accounts: + print("user_name: {}, pass_hash: {}, added at address: {}".format( + hard_coded_accounts[account]['user_name'], + hard_coded_accounts[account]['pass_hash'], + hex(account.offset) + )) - for i in range(dt.getLength()): - removeDataAt(ea.add(i)) + print('{}\r\n'.format("-" * 60)) + + def analyze_service(self): + service_status = {} + print('{:-^60}'.format('analyze services')) + for service in sorted(vxworks_service_keyword.keys()): + service_status[service] = "Not available" + for service_function in vxworks_service_keyword[service]: + target_function = getFunction(service_function) + if not target_function: + target_function = getFunction("_{}".format(service_function)) + if target_function: + # print("Found {} in firmware, service {} might available".format(service_function, service)) + service_status[service] = "available" + + for service in sorted(service_status.items(), key=lambda x: x[1], reverse=True): + print('{}: {}'.format(service[0], service[1])) + print('{}\r\n'.format("-" * 60)) - createData(ea, dt) - # Using symbol_address as default symbol_name. - symbol_name = "0x{:08X}".format(symbol_dest_address) - add_symbol(symbol_name, symbol_name_address, symbol_dest_address, symbol_type) + def analyze_symbols(self): + print('{:-^60}'.format('analyze symbols using sysSymTbl')) + function_manager = currentProgram.getFunctionManager() + functions_count_before = function_manager.getFunctionCount() + sys_sym_tbl = get_symbol('sysSymTbl') - if getInt(ea) == 0 or ea == tail: - break + if not sys_sym_tbl: + print('{}\r\n'.format("-" * 60)) + return - ea = prev_symbol_addr + if not is_address_in_current_program(sys_sym_tbl.getAddress()): + print('{}\r\n'.format("-" * 60)) + return - return + sys_sym_addr = toAddr(getInt(sys_sym_tbl.getAddress())) + if not is_address_in_current_program(sys_sym_addr): + print("sys_sym_addr({:#010x}) is not in current_program".format(sys_sym_addr.getOffset())) + print('{}\r\n'.format("-" * 60)) + return -def analyze_symbols(): - print('{:-^60}'.format('analyze symbols using sysSymTbl')) - function_manager = currentProgram.getFunctionManager() - functions_count_before = function_manager.getFunctionCount() - sys_sym_tbl = getSymbol('sysSymTbl', currentProgram.getGlobalNamespace()) - if not sys_sym_tbl: - sys_sym_tbl = getSymbol('_sysSymTbl', currentProgram.getGlobalNamespace()) + if sys_sym_addr.getOffset() == 0: + print('{}\r\n'.format("-" * 60)) + return - if not sys_sym_tbl: - print('{}\r\n'.format("-" * 60)) - return + else: + try: + if not self._vx_version: + vx_version = askChoice("Choice", "Please choose VxWorks main Version ", ["5.x", "6.x"], "5.x") + if vx_version == u"5.x": + self._vx_version = 5 + + elif vx_version == u"6.x": + self._vx_version = 6 + print("VxHunter didn't support symbols analyze for VxWorks version 6.x") + + if self._vx_version == 5: + print("Functions count: {}(Before analyze) ".format(functions_count_before)) + # for i in range(vx_5_sys_symtab.getLength()): + # removeDataAt(sys_sym_addr.add(i)) + create_struct(sys_sym_addr, vx_5_sys_symtab) + hash_tbl_addr = toAddr(getInt(sys_sym_addr.add(0x04))) + # for i in range(vx_5_hash_tbl.getLength()): + # removeDataAt(hash_tbl_addr.add(i)) + create_struct(hash_tbl_addr, vx_5_hash_tbl) + hash_tbl_length = getInt(hash_tbl_addr.add(0x04)) + hash_tbl_array_addr = toAddr(getInt(hash_tbl_addr.add(0x14))) + hash_tbl_array_data_type = ArrayDataType(vx_5_sl_list, hash_tbl_length, vx_5_sl_list.getLength()) + create_struct(hash_tbl_array_addr, hash_tbl_array_data_type) + for i in range(0, hash_tbl_length): + list_head = toAddr(getInt(hash_tbl_array_addr.add(i * 8))) + list_tail = toAddr(getInt(hash_tbl_array_addr.add((i * 8) + 0x04))) + if is_address_in_current_program(list_head) and is_address_in_current_program(list_tail): + fix_symbol_by_chains(list_head, list_tail, vx_version) + functions_count_after = function_manager.getFunctionCount() + print("Functions count: {}(After analyze) ".format(functions_count_after)) + print("VxHunter found {} new functions".format(functions_count_after - functions_count_before)) + except Exception as err: + print(err) - if not is_address_in_current_program(sys_sym_tbl.getAddress()): print('{}\r\n'.format("-" * 60)) - return - sys_sym_addr = toAddr(getInt(sys_sym_tbl.getAddress())) + def analyze_function_xref_by_symbol_get(self): + print('{:-^60}'.format('analyze symFindByName function call')) + + # symFindByName analyze + target_function = getFunction("symFindByName") + + if not target_function: + target_function = getFunction("_symFindByName") + + if target_function: + parms_data = dump_call_parm_value(call_address=target_function.getEntryPoint()) + logger.debug("Found {} symFindByName call".format(len(parms_data))) + logger.debug("parms_data.keys(): {}".format(parms_data.keys())) + currentReferenceManager = currentProgram.getReferenceManager() + for call_addr in parms_data: + try: + call_parms = parms_data[call_addr] + logger.debug("call_parms: {}".format(call_parms)) + if 'parm_2' not in call_parms['parms'].keys(): + continue + + searched_symbol_name_ptr = call_parms['parms']['parm_2']['parm_data'] + if isinstance(searched_symbol_name_ptr, DataDB): + searched_symbol_name = searched_symbol_name_ptr.value + if isinstance(searched_symbol_name, GenericAddress): + if is_address_in_current_program(searched_symbol_name): + searched_symbol_name = getDataAt(searched_symbol_name) + logger.debug("type(searched_symbol_name): {}".format(type(searched_symbol_name))) + logger.debug("searched_symbol_name: {}".format(searched_symbol_name)) + if isinstance(searched_symbol_name, unicode) is False: + searched_symbol_name = searched_symbol_name.value + print("Found symFindByName({}) call at {:#010x}".format(searched_symbol_name, + call_parms['call_addr'].offset)) + + to_function = getFunction(searched_symbol_name) + + if to_function: + ref_to = to_function.getEntryPoint() + ref_from = call_parms['call_addr'] + currentReferenceManager.addMemoryReference(ref_from, ref_to, RefType.READ, + SourceType.USER_DEFINED, 0) + print("Add Reference for {}( {:#010x} ) function call at {:#010x} in {}( {:#010x} )".format( + to_function, + ref_to.offset, + call_parms['call_addr'].offset, + call_parms['refrence_function_name'], + call_parms['refrence_function_addr'].offset + ) + ) + + else: + print("Can't find {} symbol in firmware".format(searched_symbol_name)) + + logger.debug("{}({}) at {:#010x} in {}({:#010x})".format(target_function.name, searched_symbol_name, + call_parms['call_addr'].offset, + call_parms['refrence_function_name'], + call_parms['refrence_function_addr'].offset + )) + except Exception as err: + print(err) - if not is_address_in_current_program(sys_sym_addr): - print("sys_sym_addr({:#010x}) is not in current_program".format(sys_sym_addr.getOffset())) - print('{}\r\n'.format("-" * 60)) - return + else: + print("Can't find {} function in firmware".format(target_function)) - if sys_sym_addr.getOffset() == 0: print('{}\r\n'.format("-" * 60)) - return - else: - try: - vx_version = askChoice("Choice", "Please choose VxWorks main Version ", ["5.x", "6.x"], "5.x") - if vx_version == u"5.x": - vx_version = 5 - - elif vx_version == u"6.x": - vx_version = 6 - print("VxHunter didn't support symbols analyze for VxWorks version 6.x") - - if vx_version == 5: - print("Functions count: {}(Before analyze) ".format(functions_count_before)) - for i in range(vx_5_sys_symtab.getLength()): - removeDataAt(sys_sym_addr.add(i)) - createData(sys_sym_addr, vx_5_sys_symtab) - hash_tbl_addr = toAddr(getInt(sys_sym_addr.add(0x04))) - for i in range(vx_5_hash_tbl.getLength()): - removeDataAt(hash_tbl_addr.add(i)) - createData(hash_tbl_addr, vx_5_hash_tbl) - hash_tbl_length = getInt(hash_tbl_addr.add(0x04)) - hash_tbl_array_addr = toAddr(getInt(hash_tbl_addr.add(0x14))) - hash_tbl_array_data_type = ArrayDataType(vx_5_sl_list, hash_tbl_length, vx_5_sl_list.getLength()) - for i in range(hash_tbl_array_data_type.getLength()): - removeDataAt(hash_tbl_array_addr.add(i)) - createData(hash_tbl_array_addr, hash_tbl_array_data_type) - for i in range(0, hash_tbl_length): - list_head = toAddr(getInt(hash_tbl_array_addr.add(i * 8))) - list_tail = toAddr(getInt(hash_tbl_array_addr.add((i * 8) + 0x04))) - if is_address_in_current_program(list_head) and is_address_in_current_program(list_tail): - fix_symbol_by_chains(list_head, list_tail, vx_version) - functions_count_after = function_manager.getFunctionCount() - print("Functions count: {}(After analyze) ".format(functions_count_after)) - print("VxHunter found {} new functions".format(functions_count_after - functions_count_before)) - except Exception as err: - print(err) + def analyze_netpool(self): + print('{:-^60}'.format('analyze netpool')) + a = ["_pNetDpool", "_pNetPoolFuncTbl", "_pNetSysPool"] + net_dpool = get_symbol('_pNetDpool') + net_dpool_addr = toAddr(getInt(net_dpool.getAddress())) - print('{}\r\n'.format("-" * 60)) + if not is_address_in_current_program(net_dpool_addr): + print("net_dpool_addr({:#010x}) is not in current_program".format(net_dpool_addr.getOffset())) + elif net_dpool_addr.getOffset() == 0: + pass -def analyze_function_xref_by_symbol_get(): - print('{:-^60}'.format('analyze symFindByName function call')) + print("Found net_dpool_addr at {:#010x}".format(net_dpool_addr.getOffset())) - # symFindByName analyze - target_function = getFunction("symFindByName") + try: + if not self._vx_version: + vx_version = askChoice("Choice", "Please choose VxWorks main Version ", ["5.x", "6.x"], "5.x") - if not target_function: - target_function = getFunction("_symFindByName") + if vx_version == u"5.x": + self._vx_version = 5 - if target_function: - parms_data = dump_call_parm_value(call_address=target_function.getEntryPoint()) - logger.debug("Found {} symFindByName call".format(len(parms_data))) - logger.debug("parms_data.keys(): {}".format(parms_data.keys())) - currentReferenceManager = currentProgram.getReferenceManager() - for call_addr in parms_data: - try: - call_parms = parms_data[call_addr] - logger.debug("call_parms: {}".format(call_parms)) - if 'parm_2' not in call_parms['parms'].keys(): - continue - - searched_symbol_name_ptr = call_parms['parms']['parm_2']['parm_data'] - if isinstance(searched_symbol_name_ptr, DataDB): - searched_symbol_name = searched_symbol_name_ptr.value - if isinstance(searched_symbol_name, GenericAddress): - if is_address_in_current_program(searched_symbol_name): - searched_symbol_name = getDataAt(searched_symbol_name) - logger.debug("type(searched_symbol_name): {}".format(type(searched_symbol_name))) - logger.debug("searched_symbol_name: {}".format(searched_symbol_name)) - if isinstance(searched_symbol_name, unicode) is False: - searched_symbol_name = searched_symbol_name.value - print("Found symFindByName({}) call at {:#010x}".format(searched_symbol_name, - call_parms['call_addr'].offset)) - - to_function = getFunction(searched_symbol_name) - - if to_function: - ref_to = to_function.getEntryPoint() - ref_from = call_parms['call_addr'] - currentReferenceManager.addMemoryReference(ref_from, ref_to, RefType.READ, - SourceType.USER_DEFINED, 0) - print("Add Reference for {}( {:#010x} ) function call at {:#010x} in {}( {:#010x} )".format( - to_function, - ref_to.offset, - call_parms['call_addr'].offset, - call_parms['refrence_function_name'], - call_parms['refrence_function_addr'].offset - ) - ) + elif vx_version == u"6.x": + self._vx_version = 6 + print("VxHunter didn't support netpool analyze for VxWorks version 6.x") - else: - print("Can't find {} symbol in firmware".format(searched_symbol_name)) + if self._vx_version == 5: + fix_netpool(net_dpool_addr, 5) - logger.debug("{}({}) at {:#010x} in {}({:#010x})".format(target_function.name, searched_symbol_name, - call_parms['call_addr'].offset, - call_parms['refrence_function_name'], - call_parms['refrence_function_addr'].offset - )) - except Exception as err: - print(err) + except Exception as err: + print(err) - else: - print("Can't find {} function in firmware".format(target_function)) + print('{}\r\n'.format("-" * 60)) - print('{}\r\n'.format("-" * 60)) + def start_analyzer(self): + self.analyze_bss() + self.analyze_login_accouts() + self.analyze_service() + self.analyze_symbols() + self.analyze_function_xref_by_symbol_get() + self.analyze_netpool() if __name__ == '__main__': - analyze_bss() - analyze_login_accouts() - analyze_service() - analyze_symbols() - analyze_function_xref_by_symbol_get() + analyzer = VxAnalyzer() + analyzer.start_analyzer() diff --git a/firmware_tools/ghidra/vxhunter_utility/symbol.py b/firmware_tools/ghidra/vxhunter_utility/symbol.py index f6ea37f..9f52c5d 100644 --- a/firmware_tools/ghidra/vxhunter_utility/symbol.py +++ b/firmware_tools/ghidra/vxhunter_utility/symbol.py @@ -4,6 +4,7 @@ CharDataType, UnsignedIntegerDataType, IntegerDataType, + UnsignedLongDataType, ShortDataType, PointerDataType, VoidDataType, @@ -73,6 +74,8 @@ char_data_type = CharDataType() void_data_type = VoidDataType() unsigned_int_type = UnsignedIntegerDataType() +int_type = IntegerDataType() +unsigned_long_type = UnsignedLongDataType() short_data_type = ShortDataType() char_ptr_type = ptr_data_type.getPointer(char_data_type, 4) void_ptr_type = ptr_data_type.getPointer(void_data_type, 4) @@ -122,6 +125,139 @@ vx_5_sl_list.replaceAtOffset(0x00, void_ptr_type, 4, "head", "header of list") vx_5_sl_list.replaceAtOffset(0x04, void_ptr_type, 4, "tail", "tail of list") +''' +typedef struct clPool + { + int clSize; /* cluster size */ + int clLg2; /* cluster log 2 size */ + int clNum; /* number of clusters */ + int clNumFree; /* number of clusters free */ + int clUsage; /* number of times used */ + CL_BUF_ID pClHead; /* pointer to the cluster head */ + struct netPool * pNetPool; /* pointer to the netPool */ + } CL_POOL; + +typedef CL_POOL * CL_POOL_ID; +''' +vx_5_clPool = StructureDataType("VX_5_clPool", 0x1c) +vx_5_clPool.replaceAtOffset(0x00, int_type, 4, "clSize", "cluster size") +vx_5_clPool.replaceAtOffset(0x04, int_type, 4, "clLg2", "cluster log 2 size") +vx_5_clPool.replaceAtOffset(0x08, int_type, 4, "clNum", "number of clusters") +vx_5_clPool.replaceAtOffset(0x0c, int_type, 4, "clNumFree", "number of clusters free") +vx_5_clPool.replaceAtOffset(0x10, int_type, 4, "clUsage", "number of times used") +vx_5_clPool.replaceAtOffset(0x14, void_ptr_type, 4, "pClHead", "pointer to the cluster head") +vx_5_clPool.replaceAtOffset(0x18, void_ptr_type, 4, "pNetPool", "pointer to the netPool") + +''' +typedef struct mbstat + { + ULONG mNum; /* mBlks obtained from page pool */ + ULONG mDrops; /* times failed to find space */ + ULONG mWait; /* times waited for space */ + ULONG mDrain; /* times drained protocols for space */ + ULONG mTypes[256]; /* type specific mBlk allocations */ + } M_STAT; +''' +VX_5_M_TYPES_SIZE = 256 +vx_5_mTypes_array_data_type = ArrayDataType(unsigned_long_type, VX_5_M_TYPES_SIZE, unsigned_long_type.getLength()) +vx_5_pool_stat = StructureDataType("VX_5_PoolStat", 0x10 + VX_5_M_TYPES_SIZE * 4) +vx_5_pool_stat.replaceAtOffset(0x00, unsigned_long_type, 4, "mNum", "mBlks obtained from page pool") +vx_5_pool_stat.replaceAtOffset(0x04, int_type, 4, "mDrops", "times failed to find space") +vx_5_pool_stat.replaceAtOffset(0x08, int_type, 4, "mWait", "times waited for space") +vx_5_pool_stat.replaceAtOffset(0x0c, int_type, 4, "mDrain", "times drained protocols for space") +vx_5_pool_stat.replaceAtOffset(0x10, vx_5_mTypes_array_data_type, vx_5_mTypes_array_data_type.getLength(), + "mTypes", "type specific mBlk allocations") + + +''' +struct poolFunc /* POOL_FUNC */ + { + /* pointer to the pool initialization routine */ + STATUS (*pInitRtn) (NET_POOL_ID pNetPool, M_CL_CONFIG * pMclBlkConfig, CL_DESC * pClDescTbl, + int clDescTblNumEnt, BOOL fromKheap); + + /* pointer to mBlk free routine */ + void (*pMblkFreeRtn) (NET_POOL_ID pNetPool, M_BLK_ID pMblk); + + /* pointer to cluster Blk free routine */ + void (*pClBlkFreeRtn) (CL_BLK_ID pClBlk); + + /* pointer to cluster free routine */ + void (*pClFreeRtn) (NET_POOL_ID pNetPool, char * pClBuf); + + /* pointer to mBlk/cluster pair free routine */ + M_BLK_ID (*pMblkClFreeRtn) (NET_POOL_ID pNetPool, M_BLK_ID pMblk); + + /* pointer to mBlk get routine */ + M_BLK_ID (*pMblkGetRtn) (NET_POOL_ID pNetPool, int canWait, UCHAR type); + + /* pointer to cluster Blk get routine */ + CL_BLK_ID (*pClBlkGetRtn) (NET_POOL_ID pNetPool, int canWait); + + /* pointer to a cluster buffer get routine */ + char * (*pClGetRtn) (NET_POOL_ID pNetPool, CL_POOL_ID pClPool); + + /* pointer to mBlk/cluster pair get routine */ + STATUS (*pMblkClGetRtn) (NET_POOL_ID pNetPool, M_BLK_ID pMblk, int bufSize, int canWait, BOOL bestFit); + + /* pointer to cluster pool Id get routine */ + CL_POOL_ID (*pClPoolIdGetRtn) (NET_POOL_ID pNetPool, int bufSize, BOOL bestFit); + }; +''' +vx_5_pool_func_dict = { + "pInitRtn": "pointer to the pool initialization routine", + "pMblkFreeRtn": "pointer to mBlk free routine", + "pClBlkFreeRtn": "pointer to cluster Blk free routine", + "pClFreeRtn": "pointer to cluster free routine", + "pMblkClFreeRtn": "pointer to mBlk/cluster pair free routine", + "pMblkGetRtn": "pointer to mBlk get routine", + "pClBlkGetRtn": "pointer to cluster Blk get routine", + "pClGetRtn": "pointer to a cluster buffer get routine", + "pMblkClGetRtn": "pointer to mBlk/cluster pair get routine", + "pClPoolIdGetRtn": "pointer to cluster pool Id get routine", +} +vx_5_pool_func_tbl = StructureDataType("VX_5_pFuncTbl", 0x28) +func_offset = 0 +for func_name in vx_5_pool_func_dict: + func_desc = vx_5_pool_func_dict[func_name] + vx_5_pool_func_tbl.replaceAtOffset(func_offset, void_ptr_type, 4, "*{}".format(func_name), func_desc) + func_offset += 0x04 + + +''' +struct netPool /* NET_POOL */ + { + M_BLK_ID pmBlkHead; /* head of mBlks */ + CL_BLK_ID pClBlkHead; /* head of cluster Blocks */ + int mBlkCnt; /* number of mblks */ + int mBlkFree; /* number of free mblks */ + int clMask; /* cluster availability mask */ + int clLg2Max; /* cluster log2 maximum size */ + int clSizeMax; /* maximum cluster size */ + int clLg2Min; /* cluster log2 minimum size */ + int clSizeMin; /* minimum cluster size */ + CL_POOL * clTbl [CL_TBL_SIZE]; /* pool table */ + M_STAT * pPoolStat; /* pool statistics */ + POOL_FUNC * pFuncTbl; /* ptr to function ptr table */ + }; +''' +VX_5_CL_TBL_SIZE = 11 +vx_5_clTbl_array_data_type = ArrayDataType(void_ptr_type, VX_5_CL_TBL_SIZE, void_ptr_type.getLength()) +vx_5_netPool = StructureDataType("VX_5_netPool", 0x58) +vx_5_netPool.replaceAtOffset(0x00, void_ptr_type, 4, "pmBlkHead", "head of mBlks") +vx_5_netPool.replaceAtOffset(0x04, void_ptr_type, 4, "pClBlkHead", "head of cluster Blocks") +vx_5_netPool.replaceAtOffset(0x08, int_type, 4, "mBlkCnt", "number of mblks") +vx_5_netPool.replaceAtOffset(0x0C, int_type, 4, "mBlkFree", "number of free mblks") +vx_5_netPool.replaceAtOffset(0x10, int_type, 4, "clMask", "ncluster availability mask") +vx_5_netPool.replaceAtOffset(0x14, int_type, 4, "clLg2Max", "cluster log2 maximum size") +vx_5_netPool.replaceAtOffset(0x18, int_type, 4, "clSizeMax", "maximum cluster size") +vx_5_netPool.replaceAtOffset(0x1C, int_type, 4, "clLg2Min", "cluster log2 minimum size") +vx_5_netPool.replaceAtOffset(0x20, int_type, 4, "clSizeMin", "minimum cluster size") +vx_5_netPool.replaceAtOffset(0x24, vx_5_clTbl_array_data_type, vx_5_clTbl_array_data_type.getLength(), + "clTbl", "pool table") +vx_5_netPool.replaceAtOffset(0x50, void_ptr_type, 4, "pPoolStat", "pool statistics") +vx_5_netPool.replaceAtOffset(0x54, void_ptr_type, 4, "pFuncTbl", "ptr to function ptr table") + function_name_chaset = string.letters function_name_chaset += string.digits @@ -215,7 +351,7 @@ def demangle_function(demangle_string): function_name = None function_return = None function_parameters = None - function_name_end = len(demangle_string) + function_name_end = len(demangle_string) - 1 # get parameters index = len(demangle_string) - 1 @@ -274,6 +410,37 @@ def demangle_function(demangle_string): return function_return, function_name, function_parameters +def demangled_symbol(symbol_string): + sym_demangled_name = None + sym_demangled = None + if can_demangle: + try: + sym_demangled = demangler.demangle(symbol_string, True) + + if not sym_demangled: + # some mangled function name didn't start with mangled prefix + sym_demangled = demangler.demangle(symbol_string, False) + + except DemangledException as err: + logger.debug("DemangledException: symbol_string: {}, reason:{}".format(symbol_string, err)) + + try: + if not sym_demangled: + # Temp fix to handle _ prefix function name by remove _ prefix before demangle + sym_demangled = demangler.demangle(symbol_string[1:], False) + + except DemangledException as err: + logger.debug("DemangledException: symbol_string: {}, reason:{}".format(symbol_string, err)) + + if sym_demangled: + sym_demangled_name = sym_demangled.getSignature(False) + + if sym_demangled_name: + logger.debug("sym_demangled_name: {}".format(sym_demangled_name)) + + return sym_demangled_name + + def add_symbol(symbol_name, symbol_name_address, symbol_address, symbol_type): symbol_address = toAddr(symbol_address) symbol_name_string = symbol_name @@ -282,50 +449,31 @@ def add_symbol(symbol_name, symbol_name_address, symbol_address, symbol_type): if symbol_name_address: symbol_name_address = toAddr(symbol_name_address) if getDataAt(symbol_name_address): - print("removeDataAt: %s" % symbol_name_address) + logger.debug("removeDataAt: {}".format(symbol_name_address)) removeDataAt(symbol_name_address) try: symbol_name_string = createAsciiString(symbol_name_address).getValue() - print("symbol_name_string: %s" % symbol_name_string) + logger.debug("symbol_name_string: {}".format(symbol_name_string)) except CodeUnitInsertionException as err: - print("Got CodeUnitInsertionException: {}".format(err)) + logger.error("Got CodeUnitInsertionException: {}".format(err)) except: return if getInstructionAt(symbol_address): - print("removeInstructionAt: %s" % symbol_address) + logger.debug("removeInstructionAt: {}".format(symbol_address)) removeInstructionAt(symbol_address) # Demangle symName try: # Demangle symName - sym_demangled_name = None - if can_demangle: - try: - sym_demangled = demangler.demangle(symbol_name_string, True) - - if not sym_demangled: - # some mangled function name didn't start with mangled prefix - sym_demangled = demangler.demangle(symbol_name_string, False) - - if not sym_demangled: - # Temp fix to handle _ prefix function name by remove _ prefix before demangle - sym_demangled = demangler.demangle(symbol_name_string[1:], False) - - if sym_demangled: - sym_demangled_name = sym_demangled.getSignature(False) - - except DemangledException as err: - sym_demangled_name = None - - if sym_demangled_name: - print("sym_demangled_name: %s" % sym_demangled_name) + sym_demangled_name = demangled_symbol(symbol_name_string) if symbol_name_string and (symbol_type in need_create_function): - print("Start disassemble function %s at address %s" % (symbol_name_string, symbol_address.toString())) + logger.debug("Start disassemble function {} at address {}".format(symbol_name_string, + symbol_address.toString())) disassemble(symbol_address) function = createFunction(symbol_address, symbol_name_string) if function: @@ -341,9 +489,10 @@ def add_symbol(symbol_name, symbol_name_address, symbol_address, symbol_type): codeUnit.setComment(codeUnit.PLATE_COMMENT, sym_demangled_name) # Rename function function_return, function_name, function_parameters = demangle_function(sym_demangled_name) - print("Demangled function name is: %s" % function_name) - print("Demangled function return is: %s" % function_return) - print("Demangled function parameters is: %s" % function_parameters) + logger.debug("Demangled function name is: {}".format(function_name)) + logger.debug("Demangled function return is: {}".format(function_return)) + logger.debug("Demangled function parameters is: {}".format(function_parameters)) + if function_name: function.setName(function_name, SourceType.USER_DEFINED) # Todo: Add parameters later @@ -357,10 +506,18 @@ def add_symbol(symbol_name, symbol_name_address, symbol_address, symbol_type): codeUnit.setComment(codeUnit.PLATE_COMMENT, sym_demangled_name) except Exception as err: - print("Create function Failed: %s" % err) + logger.error("Create symbol failed: symbol_name:{}, symbol_name_address:{}, " + "symbol_address:{}, symbol_type:{} reason: {}".format(symbol_name_string, + symbol_name_address, + symbol_address, + symbol_type, err)) except: - print("Create function Failed: Java error") + logger.debug("Create symbol failed: symbol_name:{}, symbol_name_address:{}, " + "symbol_address{}, symbol_type{} with Unknown error".format(symbol_name_string, + symbol_name_address, + symbol_address, + symbol_type)) def fix_symbol_table_structs(symbol_table_start, symbol_table_end, vx_version): @@ -386,11 +543,113 @@ def is_vx_symbol_file(file_data, is_big_endian=True): # Check key function names for key_function in function_name_key_words: if key_function not in file_data: - print("key function not found") + logger.debug("key function not found") return False if is_big_endian: return struct.unpack('>I', file_data[:4])[0] == len(file_data) else: - return struct.unpack(',__default_alloc_template>::operator", function_name) self.assertEqual("[](unsigned int)", function_parameters) + + def test_demangle_function_13(self): + demangle_sting = "___tf36CServiceRequestSerialPollActiveState" + function_return, function_name, function_parameters = demangle_function(demangle_sting) + self.assertEqual(None, function_return) + self.assertEqual("___tf36CServiceRequestSerialPollActiveState", function_name) + self.assertEqual("", function_parameters)