import numpy as np
import file_constants as f_constants
import header_constants as h_constants

def get_last_block_in_file(byte_data, block_size=None):
    """Return the index of the last block in ``byte_data`` that contains written data.

    The last file in a recording may not be filled until the end; its unused
    tail is padded either with 0x00 bytes or with 0xFF bytes. The last written
    byte is therefore the earlier of (a) the last byte that is not 0x00 and
    (b) the last byte that is not 0xFF: whichever padding value was used, the
    search for the *other* value runs through the padding to the end of the
    buffer, so it does not win the ``min``.

    Args:
        byte_data: Raw file contents, one byte value per element
            (e.g. a uint8 numpy array).
        block_size: Block length in bytes. Defaults to
            ``f_constants.BLOCK_SIZE``.

    Returns:
        int: Zero-based index of the block that holds the last written byte.

    Raises:
        ValueError: If ``byte_data`` is empty or consists entirely of padding.
    """
    if block_size is None:
        block_size = f_constants.BLOCK_SIZE
    data = np.asarray(byte_data)
    not_zero = np.flatnonzero(data != 0)    # candidates if padding is 0x00
    not_ff = np.flatnonzero(data != 255)    # candidates if padding is 0xFF
    if not_zero.size == 0 or not_ff.size == 0:
        raise ValueError("byte_data contains no written bytes (empty or all padding)")
    last_written = min(int(not_zero[-1]), int(not_ff[-1]))
    # Block b spans bytes [b*block_size, (b+1)*block_size), so integer division
    # gives the block containing last_written -- including the case where that
    # byte is the very first byte of its block.
    return last_written // block_size

def extract_data_segments(data_type, data_types_present, block_start_indices, byte_data, data_start_indices, data_segment_engths):
    """Collect the bytes of one data type from every block, in block order.

    Args:
        data_type: Data-type identifier to extract; must occur in
            ``data_types_present`` (``list.index`` raises ValueError otherwise).
        data_types_present: Data-type identifiers, one per segment.
        block_start_indices: Byte offset of each block within ``byte_data``.
        byte_data: Raw file contents.
        data_start_indices: Per-segment start offset within a block.
        data_segment_engths: Per-segment length in bytes (the parameter name is
            misspelled but kept so keyword callers keep working).

    Returns:
        list: The selected segment's bytes from each block, concatenated.
    """
    segment_idx = data_types_present.index(data_type)
    seg_begin = data_start_indices[segment_idx]
    seg_end = seg_begin + data_segment_engths[segment_idx]

    collected = []
    for block_start in block_start_indices:
        # Cut the block out first so the segment slice is clipped at the block end.
        block = byte_data[block_start:block_start + f_constants.BLOCK_SIZE]
        collected.extend(block[seg_begin:seg_end])
    return collected

def convert_neural_bytes(neural_bytes, voltage_resolution, offset):
    """Convert raw neural-data bytes to voltage values.

    Consecutive byte pairs are reinterpreted as native-endian uint16 samples.
    The result is ``voltage_resolution * (sample - offset)``.

    Args:
        neural_bytes: Sequence of byte values (0-255), e.g. the list returned by
            ``extract_data_segments`` or a uint8 array. Its length must be even.
        voltage_resolution: Multiplier applied to each offset-corrected sample.
        offset: Value subtracted from each raw sample before scaling.

    Returns:
        numpy.ndarray: One float value per uint16 sample.

    Raises:
        ValueError: If the number of bytes is odd (not viewable as uint16).
    """
    # Force uint8 before reinterpreting. np.array on a plain list of Python ints
    # yields int64, and viewing that as uint16 would split every value into four
    # bogus samples. np.array (not asarray) also guarantees a contiguous copy,
    # which .view to a wider dtype requires.
    raw = np.array(neural_bytes, dtype=np.uint8)
    neural_ints = raw.view(np.uint16)

    neural_data = voltage_resolution * (neural_ints - float(offset))
    return neural_data


def get_block_start_indices(byte_data):
    """Get data indices where blocks start; check for dropped blocks or irregularities.

    Walks ``byte_data`` in steps of ``f_constants.BLOCK_SIZE`` up to and
    including the block returned by ``get_last_block_in_file``. The leading
    ``len(h_constants.HEX_CONST_ID)`` bytes of each block are compared with the
    header constant:

    * equal to ``HEX_CONST_ID``: the block's byte offset is recorded;
    * equal to ``f_constants.EMPTY_BYTES_00`` or ``EMPTY_BYTES_FF``: the block
      number is recorded as dropped;
    * anything else: a corruption warning is printed and the block is skipped.

    Every dropped block number is printed before returning.

    Returns:
        tuple: ``(block_start_indices, dropped_blocks)``. The first holds the
        byte offsets of valid blocks; the second holds integer block numbers of
        dropped blocks.
    """
    block_size = f_constants.BLOCK_SIZE
    const_len = len(h_constants.HEX_CONST_ID)
    last_filled_block = get_last_block_in_file(byte_data)

    block_start_indices = []
    dropped_blocks = []
    for i in range(0, len(byte_data), block_size):
        # Integer division: '/' would make block numbers floats (e.g. 3.0).
        this_block_idx = i // block_size
        if this_block_idx > last_filled_block:
            break
        const_id = byte_data[i:i + const_len]
        if np.array_equal(const_id, h_constants.HEX_CONST_ID):
            block_start_indices.append(i)
        elif (np.array_equal(const_id, f_constants.EMPTY_BYTES_00)
              or np.array_equal(const_id, f_constants.EMPTY_BYTES_FF)):
            dropped_blocks.append(this_block_idx)
        else:
            print("Warning: Unexpected values for header constant - file may be corrupted")

    for block_number in dropped_blocks:
        print("Dropped block at block number {}".format(block_number))

    return block_start_indices, dropped_blocks

def get_timestamps(block_start_indices, byte_data):
    """Read the timestamp stored in the header of each block.

    For every block start offset, the header (``HEADER_TOTAL_BYTES`` long) is
    sliced out. The timestamp field at ``TIME_STAMP_POSITION`` is then
    reinterpreted, and its first uint32 is taken. Timestamps change in every
    header, so each block is read individually.

    Args:
        block_start_indices: Byte offset of each block within ``byte_data``.
        byte_data: Raw file contents.

    Returns:
        list: One numpy uint32 timestamp per block, in input order.
    """
    header_len = h_constants.HEADER_TOTAL_BYTES
    ts_begin = h_constants.TIME_STAMP_POSITION
    ts_end = ts_begin + h_constants.TIME_STAMP_BYTES

    def _read_timestamp(block_start):
        header = byte_data[block_start:block_start + header_len]
        return np.array(header[ts_begin:ts_end]).view(np.uint32)[0]

    return [_read_timestamp(start) for start in block_start_indices]

def get_partition_data(byte_data):
    """Describe how each block is partitioned into data types.

    Only the first header is read, since every block shares the same
    structure. Its partition table is parsed as consecutive uint32 triplets
    ``(data type, start byte within the block, segment length)``.

    Args:
        byte_data: Raw file contents beginning with a block header.

    Returns:
        tuple: ``(data_types_present, data_start_indices, data_segment_lengths)``,
        three lists of numpy uint32 scalars, one entry per partition.
    """
    header = byte_data[0:h_constants.HEADER_TOTAL_BYTES]
    part_begin = h_constants.PARTITION_START_POSITION
    part_bytes = np.array(header[part_begin:part_begin + h_constants.PARTITION_BYTES])

    word = 4  # each field is stored as a uint32

    def _word_at(pos):
        return part_bytes[pos:pos + word].view(np.uint32)[0]

    types_found, start_offsets, segment_sizes = [], [], []
    for entry in range(0, len(part_bytes), 3 * word):
        types_found.append(_word_at(entry))
        start_offsets.append(_word_at(entry + word))
        segment_sizes.append(_word_at(entry + 2 * word))
    return types_found, start_offsets, segment_sizes