#!/usr/bin/env python

import argparse
import codecs
import datetime
import itertools
import json
import os.path
import types

import fitparse


def format_message(num, message, options):
    """Render one FIT message as a human-readable text block.

    Args:
        num: Sequence number printed in front of the message name.
        message: Message object exposing ``name`` and ``type``; when its
            type is ``'data'`` it is iterated for fields exposing ``name``,
            ``value`` and ``units``.
        options: Object whose ``with_defs`` attribute decides whether the
            message type is appended to the header line.

    Returns:
        The header line, one `` * name: value [units]`` line per field of a
        data message, and a terminating blank line, as a single string.
    """
    header = f"{num}. {message.name}"
    if options.with_defs:
        header += f' [{message.type}]'
    chunks = [header + '\n']

    # Only data messages carry field values worth listing.
    if message.type == 'data':
        for field in message:
            entry = f' * {field.name}: {field.value}'
            if field.units:
                entry += f' [{field.units}]'
            chunks.append(entry + '\n')

    # Blank separator line between consecutive messages.
    chunks.append('\n')
    return "".join(chunks)


def parse_args(args=None):
    """Parse the fitdump command line.

    Args:
        args: Sequence of argument strings to parse. ``None`` is passed
            straight to ``ArgumentParser.parse_args``.

    Returns:
        The ``argparse.Namespace`` produced by the parser, post-processed so
        that ``verbose`` is a bool and two derived flags are present:

        * ``with_defs`` - true only for verbose "readable" output.
        * ``as_dict`` - true only for verbose "json" output.
    """
    parser = argparse.ArgumentParser(
        description='Dump .FIT files to various formats',
        epilog=f'python-fitparse version {fitparse.__version__}',
    )
    parser.add_argument('-v', '--verbose', action='count', default=0)
    parser.add_argument(
        '-o', '--output', type=argparse.FileType(mode='w', encoding="utf-8"),
        default="-",
        help='File to output data into (defaults to stdout)',
    )
    parser.add_argument(
        # TODO: csv
        '-t', '--type', choices=('readable', 'json', 'gpx'), default='readable',
        help='File type to output. (DEFAULT: %(default)s)',
    )
    parser.add_argument(
        '-n', '--name', action='append', help='Message name (or number) to filter',
    )
    parser.add_argument(
        'infile', metavar='FITFILE', type=argparse.FileType(mode='rb'),
        help='Input .FIT file (Use - for stdin)',
    )
    parser.add_argument(
        '--ignore-crc', action='store_const', const=True,
        help="Some devices seem to write invalid crc's, ignore these.",
    )

    options = parser.parse_args(args)

    # Collapse the -v counter into a plain on/off switch.
    options.verbose = options.verbose >= 1
    options.with_defs = (options.type == "readable" and options.verbose)
    # Dict-form messages are only meaningful for the JSON writer. The GPX
    # writer reads attributes such as ``message.name``, so enabling as_dict
    # for "gpx" made ``-t gpx -v`` fail instead of producing a track.
    options.as_dict = (options.type == "json" and options.verbose)

    return options


class RecordJSONEncoder(json.JSONEncoder):
    """JSON encoder that understands the extra value types fitdump emits."""

    def default(self, obj):
        """Convert *obj* into something the stock JSON encoder can handle.

        Generators become lists, ``datetime``/``time`` values become ISO 8601
        strings and ``fitparse.DataMessage`` objects become a
        ``{"type": ..., "data": {...}}`` mapping. Anything else is passed to
        the base class, which raises ``TypeError``.
        """
        # Lazily produced sequences are drained into a list.
        if isinstance(obj, types.GeneratorType):
            return [*obj]

        temporal_types = (datetime.datetime, datetime.time)
        if isinstance(obj, temporal_types):
            return obj.isoformat()

        if isinstance(obj, fitparse.DataMessage):
            fields = {field.name: field.value for field in obj}
            return {"type": obj.name, "data": fields}

        # Unknown type: let the base implementation raise the TypeError.
        return super().default(obj)


def generate_gpx(records, filename=None):
    """Yield a GPX 1.1 document as a sequence of text fragments.

    Args:
        records: Iterable of message objects. Each must expose ``name`` and
            be iterable over field objects exposing ``name`` and ``value``.
            Only "file_id" (for the creation time) and "record" messages
            are used; everything else is skipped.
        filename: Optional source file name written into ``<src>`` and
            ``<name>``. XML special characters in it are escaped.

    Yields:
        str: Consecutive pieces of the document, each ending in a newline.
    """
    # Imported here so the function does not rely on a file-level import.
    from xml.sax.saxutils import escape

    # TODO: Use xml.etree.ElementTree ?

    GPX_TIME_FMT = "%Y-%m-%dT%H:%M:%SZ"  # ISO 8601 format

    records = iter(records)

    # Names such as "<stdin>" or "a&b.fit" must not end up as raw markup.
    safe_filename = escape(str(filename)) if filename else None

    # header + open tags
    yield '<?xml version="1.0"?>\n'
    yield '<gpx xmlns="http://www.topografix.com/GPX/1/1" version="1.1" creator="python-fitparse (fitdump)" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.topografix.com/GPX/1/1 http://www.topografix.com/GPX/1/1/gpx.xsd">\n'
    yield ' <metadata>\n'

    # File creation time (if a file_id message with a usable time shows up
    # before the first track record). A "record" consumed while searching is
    # kept aside so it is not lost from the track.
    first_record = []
    for message in records:
        if message.name == "file_id":
            created = next(
                (
                    field_data.value
                    for field_data in message
                    if field_data.name == "time_created"
                    and isinstance(field_data.value, datetime.datetime)
                ),
                None,
            )
            if created is None:
                # No time found in the fields, check next record
                continue
            yield f'  <time>{created.strftime(GPX_TIME_FMT)}</time>\n'
            break
        if message.name == "record":
            first_record.append(message)
            break

    if safe_filename:
        yield f'  <src>{safe_filename}</src>\n'

    yield ' </metadata>\n'
    yield ' <trk>\n'

    if safe_filename:
        yield f'  <name>{safe_filename}</name>\n'

    yield '  <trkseg>\n'

    # track points
    for message in itertools.chain(first_record, records):
        if message.name != "record":
            continue

        trkpt = {}

        # TODO: support more data types (heart rate, cadence, etc)
        for field_data in message:
            value = field_data.value
            if value is None:
                # A missing value (e.g. no GPS fix) would otherwise be
                # written out literally as the text "None".
                continue

            name = field_data.name
            if name == "position_lat":
                # Expected in decimal degrees
                trkpt["lat"] = value
            elif name == "position_long":
                # Expected in decimal degrees
                trkpt["lon"] = value
            elif name == "enhanced_altitude":
                # Expected in m
                trkpt["ele"] = value
            elif name == "timestamp" and isinstance(value, datetime.datetime):
                trkpt["time"] = value.strftime(GPX_TIME_FMT)
            elif name == "enhanced_speed" and isinstance(value, float):
                # convert from km/h to m/s
                trkpt["speed"] = value / 3.6

        # Add trackpoint; a point without both coordinates is not emitted.
        if "lat" in trkpt and "lon" in trkpt:
            yield '   <trkpt lat="{lat}" lon="{lon}">\n'.format(**trkpt)
            if "ele" in trkpt:
                yield '    <ele>{ele}</ele>\n'.format(**trkpt)
            if "time" in trkpt:
                yield '    <time>{time}</time>\n'.format(**trkpt)
            if "speed" in trkpt:
                yield '    <speed>{speed}</speed>\n'.format(**trkpt)
            yield '   </trkpt>\n'

    # close tags
    yield '  </trkseg>\n'
    yield ' </trk>\n'
    yield '</gpx>\n'


def main(args=None):
    """Run fitdump: read a FIT file and write it out in the chosen format.

    Args:
        args: Argument strings forwarded to ``parse_args``; ``None`` is
            passed through unchanged.

    The output stream is always closed on the way out, including when the
    FIT file cannot be opened or parsed.
    """
    options = parse_args(args)

    try:
        # Built inside the try block so that the output stream is still
        # closed when constructing the reader raises.
        fitfile = fitparse.UncachedFitFile(
            options.infile,
            data_processor=fitparse.StandardUnitsDataProcessor(),
            check_crc=not options.ignore_crc,
        )
        records = fitfile.get_messages(
            name=options.name,
            with_definitions=options.with_defs,
            as_dict=options.as_dict,
        )

        if options.type == "json":
            json.dump(records, fp=options.output, cls=RecordJSONEncoder)
        elif options.type == "readable":
            options.output.writelines(
                format_message(n, record, options) for n, record in enumerate(records, 1)
            )
        elif options.type == "gpx":
            # Not every file-like object has a name; fall back to None
            # rather than raising AttributeError.
            filename = getattr(options.infile, "name", None)
            if filename:
                filename = os.path.basename(filename)
            options.output.writelines(generate_gpx(records, filename))
    finally:
        # Best effort: closing can itself fail (e.g. on a broken pipe), and
        # that must not mask an exception raised above.
        try:
            options.output.close()
        except OSError:
            pass

if __name__ == '__main__':
    # Run the command-line tool only when this file is executed as a script.
    try:
        main()
    except BrokenPipeError:
        # The reader of our output went away before we finished writing
        # (e.g. output piped into a command that exits early); end quietly
        # instead of showing a traceback.
        pass
