#include <string.h>
#include "memoryHeap.h"
#include "printf.h"
#include "irqs.h"

//Allocation granularity: requested sizes are rounded up to a multiple of this and the heap base is aligned to it
#define HEAP_CHUNK_ALIGN		8

//Smallest payload a free chunk may have: it must hold its free-list links and at least one alignment unit
#define MIN_CHUNK_SIZE			((sizeof(struct MemoryHeapNodeFreeData) < HEAP_CHUNK_ALIGN) ? HEAP_CHUNK_ALIGN : sizeof(struct MemoryHeapNodeFreeData))

//set to 1 to log every alloc/free together with heap statistics
#define LOG_HEAP_STATS			0


//Header in front of every chunk, free or used. Chunks are laid out back-to-back in the heap:
//the following chunk starts 'size' bytes after the end of this header (see heapPrvGetNextChunk),
//the preceding one is reached through 'prev'.
struct MemoryHeapNode {
	
	struct MemoryHeapNode			*prev;	//physically preceding chunk, NULL for the first chunk in the heap
	uint32_t						used: 1;	//1 = allocated, 0 = free
	uint32_t						size:31;	//payload size in bytes, not counting this header
	uint8_t							data[];	//payload. NOTE(review): on typical LP64 ABIs offsetof(data) (12) != sizeof(struct MemoryHeapNode) (16); code mixing 'data' with 'node + 1' disagrees there
};

//Free-list links. They are stored in the payload of a free chunk, right after its MemoryHeapNode header,
//which is why a free chunk's payload may never be smaller than MIN_CHUNK_SIZE.
struct MemoryHeapNodeFreeData {
	
	struct MemoryHeapNodeFreeData	*prevFree;	//previous entry in the same bucket, NULL when this entry is the bucket head
	struct MemoryHeapNodeFreeData	*nextFree;	//next entry in the same bucket, NULL at the end of the list
};

//Heap control block, placed at the (aligned) start of the managed region; the chunks follow it directly.
struct MemoryHeap {
	
	struct MemoryHeapNodeFreeData	*freeLists[31];	//free lists bucketed by floor(log2(chunk size)): bucket i holds free chunks with 2^i <= size < 2^(i+1), e.g. 2K..4K-1 bytes in [11], 512K..1M-1 bytes in [19]
	uint32_t 						size;	//total size in bytes of the managed region, including this struct
};

//Return the chunk that physically follows 'thisOne', or NULL if 'thisOne' is the last chunk of the heap.
//A chunk extending beyond the end of the heap means the heap is corrupt and is reported via fatal().
static struct MemoryHeapNode* heapPrvGetNextChunk(struct MemoryHeap *heap, struct MemoryHeapNode *thisOne)
{
	const uintptr_t heapEnd = (uintptr_t)heap + heap->size;
	const uintptr_t candidate = (uintptr_t)(thisOne + 1) + thisOne->size;
	
	if (candidate > heapEnd)
		fatal("chunk goes past end of heap: 0x%08x > 0x%08x\n", candidate, heapEnd);
	
	if (candidate == heapEnd)	//'thisOne' ends exactly at the heap end: nothing follows it
		return NULL;
	
	return (struct MemoryHeapNode*)candidate;
}

static struct MemoryHeapNode* heapPrvChunkHdrFromFreeListItem(struct MemoryHeap *heap, struct MemoryHeapNodeFreeData *fli)
{
	return ((struct MemoryHeapNode*)fli) - 1;
}

static struct MemoryHeapNodeFreeData* heapPrvFreeListItemFromChunkHdr(struct MemoryHeap *heap, struct MemoryHeapNode* hdr)
{
	return (struct MemoryHeapNodeFreeData*)(hdr + 1);
}

//Remove a free chunk from the doubly linked free list it is on, then clear its links.
//A chunk without a predecessor must be the head of the bucket for its size (floor(log2(size)));
//anything else means the free lists are corrupt and is reported via fatal().
static void heapPrvUnlinkFreeChunkFromFreeList(struct MemoryHeap *heap, struct MemoryHeapNode *node)
{
	struct MemoryHeapNodeFreeData *fli = heapPrvFreeListItemFromChunkHdr(heap, node);
	
	if (node->used)
		fatal("Cannot unlink a non-free chunk\n");
	
	if (fli->nextFree)
		fli->nextFree->prevFree = fli->prevFree;
	if (fli->prevFree)
		fli->prevFree->nextFree = fli->nextFree;
	else {
		//no predecessor: this entry must currently be the head of its bucket
		uint32_t idx = 31 - __builtin_clz(node->size);
		
		if (heap->freeLists[idx] == fli)
			heap->freeLists[idx] = fli->nextFree;
		else
			fatal("Node with no prev free not first in its free list\n");
	}
	
	fli->nextFree = NULL;
	fli->prevFree = NULL;
}

//Push a free chunk onto the head of the free list for its bucket, floor(log2(node->size)).
//Used chunks and chunks too small to hold the free-list links are rejected via fatal().
static void heapPrvInsertFreeChunk(struct MemoryHeap *heap, struct MemoryHeapNode *node)	//node->size will never be zero!
{
	struct MemoryHeapNodeFreeData *item, *oldHead;
	uint32_t bucket;
	
	if (node->used || node->size < MIN_CHUNK_SIZE) {
		fatal("inserting invalid node into free list\n");
		return;
	}
	
	bucket = 31 - __builtin_clz(node->size);
	item = heapPrvFreeListItemFromChunkHdr(heap, node);
	oldHead = heap->freeLists[bucket];
	
	item->prevFree = NULL;
	item->nextFree = oldHead;
	if (oldHead)
		oldHead->prevFree = item;
	heap->freeLists[bucket] = item;
}

//Given a chunk and the number of bytes requested from it, round the request up to HEAP_CHUNK_ALIGN.
//If the part left over after the rounded request can hold a chunk header plus MIN_CHUNK_SIZE bytes,
//carve it off as a new free chunk and insert that into a free list; otherwise leave the chunk as it is.
static void heapPrvSplitFreeChunk(struct MemoryHeap *heap, struct MemoryHeapNode *node, uint32_t wantedSize)
{
	struct MemoryHeapNode *newOne, *next;
	
	if (node->size < wantedSize)
		fatal("Chunk too small to even contain the desired data amount\n");
	
	wantedSize = (wantedSize + HEAP_CHUNK_ALIGN - 1) / HEAP_CHUNK_ALIGN * HEAP_CHUNK_ALIGN;
	
	if (node->size < wantedSize)	//chunk too small to round up but big enough for data - let it be
		return;
	
	if (node->size - wantedSize < sizeof(struct MemoryHeapNode) + MIN_CHUNK_SIZE)	//leftover too small to be a chunk of its own
		return;
	
	//The leftover starts right after the rounded payload. The payload start is computed as (node + 1), exactly
	//like heapPrvGetNextChunk() does; node->data is NOT equivalent on ABIs where
	//offsetof(struct MemoryHeapNode, data) != sizeof(struct MemoryHeapNode) (e.g. typical LP64: 12 vs 16),
	//and using it there would desynchronize the chunk chain.
	newOne = (struct MemoryHeapNode*)(((uint8_t*)(node + 1)) + wantedSize);
	next = heapPrvGetNextChunk(heap, node);	//must be looked up before node->size shrinks
	
	newOne->size = node->size - wantedSize - sizeof(struct MemoryHeapNode);
	newOne->used = 0;
	newOne->prev = node;
	if (next)
		next->prev = newOne;
	node->size = wantedSize;
	
	heapPrvInsertFreeChunk(heap, newOne);
}

//Mark a used chunk as free, coalesce it with free physical neighbours on both sides, and insert the
//resulting chunk into the appropriate free list. Freeing an already-free chunk is reported via fatal().
static void heapPrvFree(struct MemoryHeap *heap, struct MemoryHeapNode *node)
{
	struct MemoryHeapNode *prev, *next;
	
	if (!node->used) {
		fatal("cannot free a free chunk\n");
		return;
	}
	
	node->used = 0;
	
	//merge backwards (in theory loop is unnecessary: two free chunks are never left adjacent)
	while ((prev = node->prev) && !prev->used) {
		
		next = heapPrvGetNextChunk(heap, node);
		heapPrvUnlinkFreeChunkFromFreeList(heap, prev);
		prev->size += node->size + sizeof(struct MemoryHeapNode);
		
		if (next)
			next->prev = prev;
		
		node = prev;
	}
	
	//merge forwards (in theory loop is unnecessary)
	while ((next = heapPrvGetNextChunk(heap, node)) && !next->used) {
		
		heapPrvUnlinkFreeChunkFromFreeList(heap, next);
		node->size += next->size + sizeof(struct MemoryHeapNode);
		
		//the chunk after the absorbed one now has 'node' as its physical predecessor
		next = heapPrvGetNextChunk(heap, node);
		if (next)
			next->prev = node;
	}
	
	heapPrvInsertFreeChunk(heap, node);
}

//Scan a single bucket and return the smallest free chunk of at least 'size' bytes, or NULL if there is none.
static struct MemoryHeapNode* heapPrvBestFitInBucket(struct MemoryHeap *heap, uint32_t bucket, uint32_t size)
{
	struct MemoryHeapNodeFreeData *item;
	struct MemoryHeapNode *winner = NULL;
	
	for (item = heap->freeLists[bucket]; item != NULL; item = item->nextFree) {
		struct MemoryHeapNode *cand = heapPrvChunkHdrFromFreeListItem(heap, item);
		
		if (cand->size >= size && (winner == NULL || cand->size < winner->size))
			winner = cand;
	}
	
	return winner;
}

//Find a free chunk of at least 'size' bytes, starting at the bucket 'size' itself falls into and moving to
//larger buckets. The chosen chunk is taken off its free list, marked used and trimmed; NULL if nothing fits.
static struct MemoryHeapNode* heapPrvAlloc(struct MemoryHeap *heap, uint32_t size)	//do not pass zero size!
{
	uint32_t bucket = 31 - __builtin_clz(size);
	
	while (bucket < 31) {
		struct MemoryHeapNode *chosen = heapPrvBestFitInBucket(heap, bucket, size);
		
		if (chosen) {
			heapPrvUnlinkFreeChunkFromFreeList(heap, chosen);
			chosen->used = 1;
			heapPrvSplitFreeChunk(heap, chosen, size);
			return chosen;
		}
		bucket++;
	}
	
	return NULL;
}

//Set up a heap inside the caller-provided region [base, base + size).
//The start is rounded up to HEAP_CHUNK_ALIGN, the MemoryHeap control block is placed there, and the rest becomes
//a single free chunk. Returns NULL if the region cannot hold the control block, one chunk header and a
//MIN_CHUNK_SIZE payload. Regions larger than a chunk's 31-bit size field can describe are clamped; the excess
//at the end is left unused instead of silently truncating the chunk (and heap) size.
struct MemoryHeap* memoryheapInit(uintptr_t base, uintptr_t size)
{
	const uintptr_t overhead = sizeof(struct MemoryHeap) + sizeof(struct MemoryHeapNode);
	const uintptr_t maxChunkSize = 0x7FFFFFFFUL;	//largest value MemoryHeapNode.size (a 31-bit field) can hold
	struct MemoryHeapNode *node;
	struct MemoryHeap *heap;
	uintptr_t misalign;
	
	if (size < HEAP_CHUNK_ALIGN)
		return NULL;
	
	//skip to the first aligned address; at most HEAP_CHUNK_ALIGN - 1 bytes, so 'size' cannot underflow
	misalign = base % HEAP_CHUNK_ALIGN;
	if (misalign) {
		base += HEAP_CHUNK_ALIGN - misalign;
		size -= HEAP_CHUNK_ALIGN - misalign;
	}
	
	if (size < overhead + MIN_CHUNK_SIZE)
		return NULL;
	
	//the initial chunk's size must fit the 31-bit field; this also keeps heap->size within 32 bits
	if (size - overhead > maxChunkSize)
		size = overhead + maxChunkSize;
	
	heap = (struct MemoryHeap*)base;
	memset(heap, 0, sizeof(*heap));
	heap->size = (uint32_t)size;
	
	//create a single free chunk covering everything after the control block
	node = (struct MemoryHeapNode*)(heap + 1);
	node->used = 0;
	node->prev = NULL;
	node->size = (uint32_t)(size - overhead);
	
	heapPrvInsertFreeChunk(heap, node);
	
	return heap;
}

static void heapStats(struct MemoryHeap *heap)
{
	uint32_t free = 0, largest = 0, i;
	struct MemoryHeapNodeFreeData* node;
	
	for (i = 0; i < sizeof(heap->freeLists) / sizeof(*heap->freeLists); i++) {
		for (node = heap->freeLists[i]; node; node = node->nextFree) {
			
			struct MemoryHeapNode *ch = heapPrvChunkHdrFromFreeListItem(heap, node);

			free += ch->size;
			if (ch->size > largest)
				largest = ch->size;
		}
	}
	
	logw("HEAP 0x%08x is %lu bytes large, has %lu bytes free, largest chunk is %lu bytes\n", heap, heap->size, free, largest);
}

//Allocate 'sz' bytes from 'heap'. Returns NULL for a zero-sized request or when no free chunk is large enough.
//The heap is manipulated with IRQs disabled via irqsAllOff()/irqsRestoreState().
void* memoryheapAlloc(struct MemoryHeap *heap, uint32_t sz)
{
	struct MemoryHeapNode* node;
	void *ret = NULL;
	irq_state_t sta;
	
	if (!sz)
		return NULL;
	
	sta = irqsAllOff();
	if (LOG_HEAP_STATS)
		loge("Allocing %u bytes from heap at 0x%08x\n", sz, heap);
	node = heapPrvAlloc(heap, sz);
	
	//The payload begins right after the header, at (node + 1). This must mirror the "ptr - 1" that
	//memoryheapFree()/memoryheapGetChunkActualSize() use to find the header again. node->data is not
	//equivalent where offsetof(data) != sizeof(struct MemoryHeapNode) (e.g. typical LP64: 12 vs 16), and there
	//it would also be misaligned.
	if (node)
		ret = node + 1;
	
	if (LOG_HEAP_STATS)
		heapStats(heap);
	irqsRestoreState(sta);
	
	return ret;
}

//Return a block obtained from memoryheapAlloc() to 'heap'. NULL is accepted and ignored.
//The chunk header is found directly in front of the payload; the work runs with IRQs disabled.
void memoryheapFree(struct MemoryHeap *heap, void* ptr)
{
	struct MemoryHeapNode *hdr;
	irq_state_t savedIrqs;
	
	if (ptr == NULL)
		return;
	
	hdr = ((struct MemoryHeapNode*)ptr) - 1;
	
	savedIrqs = irqsAllOff();
	if (LOG_HEAP_STATS)
		loge("Freeing %u bytes from heap at 0x%08x\n", hdr->size, heap);
	heapPrvFree(heap, hdr);
	if (LOG_HEAP_STATS)
		heapStats(heap);
	irqsRestoreState(savedIrqs);
}

uint32_t memoryheapGetChunkActualSize(struct MemoryHeap *heap, const void* ptr)
{
	struct MemoryHeapNode *node = ((struct MemoryHeapNode*)ptr) - 1;
	
	return node->size;
}

//True if 'ptr' lies inside the region managed by 'heap' ([heap, heap + heap->size)), control block included.
bool memoryheapIsThisInThere(struct MemoryHeap *heap, const void* ptr)
{
	const char *lo = (const char*)heap;
	const char *p = (const char*)ptr;
	
	return !(p < lo || p >= lo + heap->size);
}
