#include <Python.h>
#include <frameobject.h>
#include <sys/types.h>
#include <pthread.h>
#include <stdio.h>
#include <stdbool.h>
#include <unistd.h>

#include "mpack/mpack.h"

// Python C extensions are, by default, compiled with asserts disabled (NDEBUG),
// which turns the standard <assert.h> assert() into a no-op, and that is not
// easy to override from the build.  Use this always-on replacement instead.

// Report a failed assertion on stderr and terminate the process with exit
// status -1 (255 on POSIX).
//
// lineno: source line of the failing assertion (__LINE__ at the call site).
// expr:   the stringified assertion expression.
void assert_helper(int lineno, const char* expr) {
    fprintf(stderr, "Assertion failed on line %d: %s\n", lineno, expr);
    exit(-1);
}

#undef assert
// Always-enabled assert.  The do { } while (0) wrapper makes the expansion a
// single statement, so `if (a) assert(x); else y;` keeps its `else` attached
// to the outer `if` instead of the macro's internal one.
#define assert(expr)                        \
    do {                                    \
        if (!(expr)) {                      \
            assert_helper(__LINE__, #expr); \
        }                                   \
    } while (0)

////////////////////////////////////////////////////////////////////////////////
// Types
////////////////////////////////////////////////////////////////////////////////

// MessagePack encoder used to serialize profile events.  Each thread has its
// own instance, stored under the thread-specific key `Tss_Key` (see
// Fprofile_Writer).
typedef mpack_writer_t Writer;

////////////////////////////////////////////////////////////////////////////////
// Prototypes
////////////////////////////////////////////////////////////////////////////////
// Return the calling thread's Writer, or NULL when the current event should be
// dropped (tracing not started, or no writer registered for this thread yet).
static Writer* Fprofile_Writer(void);

// Profile callback (Py_tracefunc signature) that encodes call/return events.
// It is handed to the Rust extension through `set_config` during module
// initialization.
static int     Fprofile_FunctionTrace(PyObject* obj, PyFrameObject* frame, int what, PyObject* arg);

////////////////////////////////////////////////////////////////////////////////
// Globals
////////////////////////////////////////////////////////////////////////////////
// Thread-specific-storage key used to fetch the calling thread's `Writer`.
// It is assigned from the Rust extension's `set_config` during module
// initialization; until then it is 0, which Fprofile_Writer asserts against.
// NOTE(review): POSIX does not reserve 0 as an invalid pthread_key_t, so using
// 0 as the "unset" sentinel assumes the Rust side never hands back key 0 —
// confirm.
static pthread_key_t Tss_Key = 0;

// True iff we've started tracing and haven't been marked as terminated.  When
// this is set, we're allowed to send messages to the profile generation
// server.
// This file only reads the flag; its address is shared with Rust through
// `set_config`, so presumably the Rust side is what flips it.
// NOTE(review): this is a plain `bool` read from every traced thread without
// atomics — confirm the Rust side's synchronization makes that safe.
static bool started;

// The C API exposed by the Rust extension (`_functiontrace_rs`).  Points at the
// table whose address `_functiontrace_rs.c_api()` returns during module
// initialization; the member layout must match what the Rust side exposes.
static struct {
    // Receives the address of `started` and the profile callback, and returns
    // the key under which per-thread Writers are stored.
    pthread_key_t (*set_config)(bool* config, int (*functiontrace)(PyObject*, PyFrameObject*, int, PyObject*));
}* rust;

////////////////////////////////////////////////////////////////////////////////
// Server Communication
////////////////////////////////////////////////////////////////////////////////

static Writer* Fprofile_Writer(void) {
    assert(Tss_Key != 0);

    if (!started) {
        // We aren't running (either starting up or shutting down), so
        // shouldn't write anything.
        return NULL;
    }

    // Fetch the Writerr for the current thread.  It's created as late as
    // the first function call for the thread, so we can plausibly get messages
    // before having anywhere to store/send them to.  Drop them if so - they
    // weren't too important if we aren't running code yet.
    Writer* state = pthread_getspecific(Tss_Key);
    if (state == NULL) {
        // The thread hasn't officially started yet, so drop this message.
        return NULL;
    }

    return state;
}

////////////////////////////////////////////////////////////////////////////////
// Misc
////////////////////////////////////////////////////////////////////////////////

// Return a UTF-8 representation of `obj` (typically a code object's
// `co_name`, `co_qualname` or `co_filename`) suitable for writing to mpack.
//
// Never returns NULL, and never leaves a Python exception pending: it is
// called from the profile hook, which must hand control back to the
// interpreter without an error set.  Inputs that are not `str` map to fixed
// placeholder strings.
//
// The returned buffer is either a string literal or the UTF-8 cache owned by
// `obj`, so it is only valid while `obj` is alive.
static inline const char* Fprofile_UnicodeToUtf8(PyObject* obj) {
    if (obj == NULL) {
        return "<NULL>";
    }
    if (obj == Py_None) {
        return "<NONE>";
    }
    if (!PyUnicode_Check(obj)) {
        return "<UNKNOWN>";
    }

    const char* utf8 = PyUnicode_AsUTF8(obj);
    if (utf8 == NULL) {
        // Encoding fails for strings containing lone surrogates (e.g. a
        // filename with undecodable bytes) and raises UnicodeEncodeError.
        // Swallow it so the error does not escape the profile hook.
        PyErr_Clear();
        return "<DECODE ERROR>";
    }
    return utf8;
}

////////////////////////////////////////////////////////////////////////////////
// Callbacks
////////////////////////////////////////////////////////////////////////////////
static int Fprofile_FunctionTrace(PyObject* obj, PyFrameObject* frame, int what, PyObject* arg) {
    PyCFunctionObject* fn  = (PyCFunctionObject*) arg;
    struct timespec    tsc = { 0 };

    Writer* writer = Fprofile_Writer();
    if (writer == NULL) {
        return 0;
    }

    clock_gettime(CLOCK_MONOTONIC, &tsc);

    switch (what) {
        case PyTrace_CALL:
            mpack_start_array(writer, 5);
            mpack_write_cstr(writer, "Call");

            mpack_start_array(writer, 2);
            mpack_write_u32(writer, tsc.tv_sec);
            mpack_write_u32(writer, tsc.tv_nsec);
            mpack_finish_array(writer);

            {
                // TODO: Check if frame is none or similar
#if PY_VERSION_HEX >= 0x030a00a0
                // Python >= 3.10 should use the standard accessors
                PyCodeObject* code = PyFrame_GetCode(frame);
                int lineno = PyFrame_GetLineNumber(frame);
#else
                PyCodeObject* code = frame->f_code;
                int lineno = frame->f_lineno;
#endif

#if PY_VERSION_HEX >= 0x030b00a0
                mpack_write_cstr(writer, Fprofile_UnicodeToUtf8(code->co_qualname));
#else
                mpack_write_cstr(writer, Fprofile_UnicodeToUtf8(code->co_name));
#endif
                mpack_write_cstr(writer, Fprofile_UnicodeToUtf8(code->co_filename));
                mpack_write_u32(writer, lineno);

#if PY_VERSION_HEX >= 0x030a00a0
                Py_DECREF(code);
#endif
            }
            mpack_finish_array(writer);
            break;
        case PyTrace_RETURN:
            mpack_start_array(writer, 3);
            mpack_write_cstr(writer, "Return");

            mpack_start_array(writer, 2);
            mpack_write_u32(writer, tsc.tv_sec);
            mpack_write_u32(writer, tsc.tv_nsec);
            mpack_finish_array(writer);

            {
#if PY_VERSION_HEX >= 0x030a00a0
                PyCodeObject* code = PyFrame_GetCode(frame);
#else
                PyCodeObject* code = frame->f_code;
#endif
#if PY_VERSION_HEX >= 0x030b00a0
                mpack_write_cstr(writer, Fprofile_UnicodeToUtf8(code->co_qualname));
#else
                mpack_write_cstr(writer, Fprofile_UnicodeToUtf8(code->co_name));
#endif

#if PY_VERSION_HEX >= 0x030a00a0
                Py_DECREF(code);
#endif
            }
            mpack_finish_array(writer);
            break;
        case PyTrace_C_CALL: {
            mpack_start_array(writer, 4);
            mpack_write_cstr(writer, "NativeCall");

            mpack_start_array(writer, 2);
            mpack_write_u32(writer, tsc.tv_sec);
            mpack_write_u32(writer, tsc.tv_nsec);
            mpack_finish_array(writer);

            {
                // Attempt to determine what module/class this function belongs
                // to.
                PyObject*   self        = fn->m_self;
                PyObject*   module      = fn->m_module;
                const char* name        = fn->m_ml->ml_name;
                const char* module_name = NULL;

                // Check if we belong to a module, and if not we must be a
                // method.  We do this order to avoid finding that the object
                // we belong to is of type module.
                if (module != NULL) {
                    if (PyModule_Check(module)) {
                        module_name = PyModule_GetName(module);
                    } else if (PyUnicode_Check(module)) {
                        module_name = PyUnicode_AsUTF8(module);
                    }
                } else if (self != NULL) {
                    // This is a method call on a class.
                    module_name = self->ob_type->tp_name;
                }

                mpack_write_cstr(writer, name != NULL ? name : "NULL");
                mpack_write_cstr(writer, module_name != NULL ? module_name : "NULL");
            }
            mpack_finish_array(writer);
            break;
        }
        case PyTrace_C_RETURN:
            mpack_start_array(writer, 3);
            mpack_write_cstr(writer, "NativeReturn");

            mpack_start_array(writer, 2);
            mpack_write_u32(writer, tsc.tv_sec);
            mpack_write_u32(writer, tsc.tv_nsec);
            mpack_finish_array(writer);

            {
                const char* name = fn->m_ml->ml_name;
                mpack_write_cstr(writer, name != NULL ? name : "NULL");
            }
            mpack_finish_array(writer);
            break;
        default:
            // TODO: We should handle exceptions here (or somewhere similar).
            break;
    }

    return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Module Initialization
////////////////////////////////////////////////////////////////////////////////

// The set of methods exposed by this module to Python.
static PyMethodDef methods[] = {
    [0] = {NULL, NULL, 0, NULL}
};

static PyModuleDef module = {
    PyModuleDef_HEAD_INIT,
    "_functiontrace",
    "",
    -1,
    methods
};

PyMODINIT_FUNC PyInit__functiontrace(void) {
    // Setup our janky Rust dlopen alternative so we can use Rust code.
    PyObject* rust_mod = PyImport_ImportModule("_functiontrace_rs");

    // Verify that the rust extension loads if we're using it.
    if (rust_mod == NULL) {
        perror("Failed to load internal Rust extension");
        exit(-1);
    }

    // This returns a pointer to the C API exposed by the Rust module.
    PyObject* py_c_api = PyObject_GetAttrString(rust_mod, "c_api");
    assert(py_c_api != NULL);
    PyObject* c_api = PyObject_CallFunctionObjArgs(py_c_api, NULL);

    // Save the Rust vtable so everyone can access it
    rust = PyLong_AsVoidPtr(c_api);

    // Share our globals with Rust
    Tss_Key = rust->set_config(&started, Fprofile_FunctionTrace);

    return PyModule_Create(&module);
}
