'''Python interface to the wiiuse library for the wii remote

Just a simple wrapper, no attempt to make the api pythonic. I tried to hide ctypes where
necessary.

This software is free for any use. If you or your lawyer are stupid enough to believe I have any
liability for it, then don't use it; otherwise, be my guest.

Gary Bishop, January 2008
hacked for new API and data June 2009
'''

import os
import ctypes
from ctypes import c_char_p, c_int, c_byte, c_uint, c_uint16, c_float, c_short, c_void_p, c_char, c_ubyte, c_ushort
from ctypes import CFUNCTYPE, Structure, POINTER, Union, byref, cdll
from ctypes.util import find_library
import sys

# duplicate the wiiuse data structures

class _Structure(Structure):
    '''ctypes Structure base class that adds a readable repr.

    Subclasses only need to declare _fields_; the repr lists every declared
    field as name=value inside ClassName(...).
    '''

    def __repr__(self):
        '''Return ClassName(field=value,...) for all declared fields.'''
        # spec[0] is the field name; any further items (type, bit width) are ignored
        pairs = ['%s=%r' % (spec[0], getattr(self, spec[0]))
                 for spec in self._fields_]
        return '%s(%s)' % (type(self).__name__, ','.join(pairs))

class vec2b(_Structure):
    '''Two-component vector (x, y) of c_byte values.

    Used for the max/min/center members of the joystick structure.
    '''
    # NOTE(review): c_byte is a signed char; if the C header declares these
    # as unsigned bytes, values above 127 will read back negative - confirm.
    _fields_ = [('x', c_byte),
                ('y', c_byte),
                ]

class vec3b(_Structure):
    '''Three-component vector (x, y, z) of c_byte values.

    Used for the accel member of wiimote_state.
    '''
    # NOTE(review): c_byte is signed; confirm the signedness of the C type.
    _fields_ = [('x', c_byte),
                ('y', c_byte),
                ('z', c_byte),
                ]

class vec3w(_Structure):
    '''Three-component vector (x, y, z) of c_ushort values.

    Used for the raw accel members and the accelerometer calibration vectors.
    '''
    _fields_ = [('x', c_ushort),
                ('y', c_ushort),
                ('z', c_ushort),
                ]

class vec3f(_Structure):
    '''Three-component vector (x, y, z) of c_float values.

    Used for the gforce members of the wiimote and nunchuk structures.
    '''
    _fields_ = [('x', c_float),
                ('y', c_float),
                ('z', c_float),
                ]

class orient(_Structure):
    '''Orientation record: roll, pitch, yaw, a_roll and a_pitch, all c_float.

    Overrides the inherited repr to print each value with %f formatting.
    '''
    _fields_ = [
        ('roll', c_float),
        ('pitch', c_float),
        ('yaw', c_float),
        ('a_roll', c_float),
        ('a_pitch', c_float),
    ]

    def __repr__(self):
        '''Return the five angles as orient(roll=... pitch=... ...).'''
        angles = (self.roll, self.pitch, self.yaw, self.a_roll, self.a_pitch)
        template = 'orient(roll=%f pitch=%f yaw=%f a_roll=%f a_pitch=%f)'
        return template % angles

class accel(_Structure):
    '''Accelerometer calibration record.

    Holds two vec3w calibration vectors (cal_zero, cal_g) followed by three
    float members (st_roll, st_pitch, st_alpha). It is embedded as
    `accel_calib` in both the wiimote and nunchuk structures.
    '''
    # NOTE(review): the st_* members are presumably smoothing state kept by
    # the C library - confirm against the wiiuse header before relying on them.
    _fields_ = [('cal_zero', vec3w),
                ('cal_g', vec3w),
                ('st_roll', c_float),
                ('st_pitch', c_float),
                ('st_alpha', c_float),
                ]

class ir_dot(_Structure):
    '''One tracked IR dot; the ir structure holds an array of four of these.

    Field order and types mirror the C structure and must not be rearranged.
    '''
    # NOTE(review): x/y vs rx/ry are presumably corrected vs raw coordinates -
    # confirm against the wiiuse header.
    _fields_ = [('visible', c_byte),
                ('x', c_uint),
                ('y', c_uint),
                ('rx', c_short),
                ('ry', c_short),
                ('order', c_byte),
                ('size', c_byte),
                ]

class ir(_Structure):
    '''IR camera state: four ir_dot entries plus the derived tracking values.

    Field order and types mirror the C structure and must not be rearranged.
    str() gives a multi-line dump with one "(name, value)" line per field.
    '''
    _fields_ = [('dot', ir_dot*4),
                ('num_dots', c_byte),
                ('aspect', c_int),
                ('pos', c_int),
                ('vres', c_uint*2),
                ('offset', c_int*2),
                ('state', c_int),
                ('ax', c_int),
                ('ay', c_int),
                ('x', c_int),
                ('y', c_int),
                ('distance', c_float),
                ('z', c_float),
                ]

    def __str__(self):
        '''Return one "(name, value)" line per declared field, newline-joined.

        The dump is best effort: a field that cannot be read or formatted is
        left out rather than aborting the whole string.
        '''
        lines = []
        for field in self._fields_:
            name = field[0]
            try:
                lines.append('(%s, %s)' % (name, getattr(self, name)))
            except Exception:
                # Skip the unreadable field, but let KeyboardInterrupt and
                # SystemExit propagate instead of being silently swallowed.
                continue
        return '\n'.join(lines)

class joystick(_Structure):
    '''Analog stick record: three vec2b vectors (max, min, center) followed
    by two floats (ang, mag).

    Embedded in the nunchuk, classic_ctrl and guitar_hero_3 structures.
    '''
    _fields_ = [('max', vec2b),
                ('min', vec2b),
                ('center', vec2b),
                ('ang', c_float),
                ('mag', c_float),
                ]

class nunchuk(_Structure):
    '''Nunchuk expansion state: calibration, joystick, button masks,
    thresholds and the accelerometer-derived vectors.

    Field order and types mirror the C structure and must not be rearranged.
    '''
    # NOTE(review): `flags` is a pointer to an int, presumably the owning
    # wiimote's flags word - confirm against the wiiuse header.
    _fields_ = [('accel_calib', accel),
                ('js', joystick),
                ('flags', POINTER(c_int)),
                ('btns', c_byte),
                ('btns_last', c_byte),
                ('btns_held', c_byte),
                ('btns_released', c_byte),
                ('orient_threshold', c_float),
                ('accel_threshold', c_float),
                ('accel', vec3w),
                ('orient', orient),
                ('gforce', vec3f),
                ]

class classic_ctrl(_Structure):
    '''Classic controller expansion state: four c_short button masks, two
    float shoulder values and the left/right joysticks.

    Field order and types mirror the C structure and must not be rearranged.
    '''
    _fields_ = [('btns', c_short),
                ('btns_last', c_short),
                ('btns_held', c_short),
                ('btns_released', c_short),
                ('r_shoulder', c_float),
                ('l_shoulder', c_float),
                ('ljs', joystick),
                ('rjs', joystick),
                ]

class guitar_hero_3(_Structure):
    '''Guitar Hero 3 controller expansion state: four c_short button masks,
    a float whammy_bar value and one joystick.

    Field order and types mirror the C structure and must not be rearranged.
    '''
    _fields_ = [('btns', c_short),
                ('btns_last', c_short),
                ('btns_held', c_short),
                ('btns_released', c_short),
                ('whammy_bar', c_float),
                ('js', joystick),
                ]

class motion_plus(_Structure):
    '''Motion Plus expansion state: three c_short values (rx, ry, rz) and two
    unsigned bytes (status, ext).

    Field order and types mirror the C structure and must not be rearranged.
    '''
    # NOTE(review): rx/ry/rz are presumably raw rotation readings - confirm
    # units and meaning against the wiiuse build this wrapper targets.
    _fields_ = [('rx', c_short),
                ('ry', c_short),
                ('rz', c_short),
                ('status', c_ubyte),
                ('ext', c_ubyte)
                ]

class expansion_union(Union):
    '''Union of the expansion device structures; all members share storage.

    It derives from ctypes.Union directly, so it does not get the field-listing
    repr that _Structure provides.
    '''
    # NOTE(review): which member is valid presumably depends on expansion.type
    # - confirm against the wiiuse header.
    _fields_ = [('nunchuk', nunchuk),
                ('classic', classic_ctrl),
                ('gh3', guitar_hero_3),
                ('mp', motion_plus)
                ]

class expansion(_Structure):
    '''Expansion slot record: an int `type` tag followed by the
    expansion_union `u` holding the device-specific data.
    '''
    # NOTE(review): `type` is presumably one of the EXP_* constants defined
    # later in this module - confirm.
    _fields_ = [('type', c_int),
                ('u', expansion_union),
                ]

class wiimote_state(_Structure):
    '''Flat snapshot of wiimote and expansion values; embedded in the wiimote
    structure as `lstate`.

    Field order and types mirror the C structure and must not be rearranged.
    '''
    # NOTE(review): presumably the last-reported state the C library compares
    # new readings against - confirm against the wiiuse header.
    _fields_ = [('exp_ljs_ang', c_float),
                ('exp_rjs_ang', c_float),
                ('exp_ljs_mag', c_float),
                ('exp_rjs_mag', c_float),
                ('exp_btns', c_ushort),
                ('exp_orient', orient),
                ('exp_accel', vec3w),
                ('exp_r_shoulder', c_float),
                ('exp_l_shoulder', c_float),
                ('drx', c_short),
                ('dry', c_short),
                ('drz', c_short),
                ('ir_ax', c_int),
                ('ir_ay', c_int),
                ('ir_distance', c_float),
                ('orient', orient),
                ('btns', c_ushort),
                ('accel', vec3b),
                ('exp', expansion)
                ]

# JunkSkip lists the platform-specific members that the wiimote structure
# places between `unid` and `state`. They are declared so that the fields
# after them land at the right offsets.
if os.name == 'nt':
    # NOTE(review): five pointer-sized slots stand in for the overlapped-I/O
    # structure; that size presumably only matches 32-bit builds - confirm
    # before using on 64-bit Windows.
    JunkSkip = [('dev_handle', c_void_p),
                ('hid_overlap', c_void_p*5), # skipping over this data structure
                ('stack', c_int),
                ('timeout', c_int),
                ('normal_timeout', c_byte),
                ('exp_timeout', c_byte),
                ]

elif sys.platform == 'darwin' :
    JunkSkip = [('device', c_void_p),
                ('bdaddr_str', c_char*18)
                ]

else:
    # NOTE(review): `bdaddr` is declared pointer-sized here; if the C header
    # stores a 6-byte Bluetooth address instead, the following offsets are
    # only right where pointers are 4 bytes - confirm for 64-bit builds.
    JunkSkip = [('bdaddr', c_void_p),
                ('bdaddr_str', c_char*18),
                ('out_sock', c_int),
                ('in_sock', c_int),
                ]

# Event type codes. The demo at the bottom of this file compares
# wiimote.event against EVENT after each successful poll.
EVENT = 1
STATUS = 2
CONNECT = 3
DISCONNECT = 4
UNEXPECTED_DISCONNECT = 5
READ_DATA = 6
NUNCHUK_INSERTED = 7
NUNCHUK_REMOVED = 8
CLASSIC_CTRL_INSERTED = 9
CLASSIC_CTRL_REMOVED = 10
GUITAR_HERO_3_CTRL_INSERTED = 11
GUITAR_HERO_3_CTRL_REMOVED = 12

class wiimote(_Structure):
    '''Mirror of the C wiimote structure handed back by the library.

    The field list is `unid`, then the platform-specific JunkSkip members,
    then the common fields. Order and types decide every offset, so nothing
    here may be rearranged or retyped without matching the C header.
    '''
    _fields_ = [('unid', c_int),
                ] + JunkSkip + [
                ('state', c_int),
                ('leds', c_byte),
                ('battery_level', c_float),

                ('flags', c_int),

                ('handshake_state', c_byte),
                ('expansion_state', c_ubyte),
                # opaque pointers: declared only to occupy the right space
                ('read_req', c_void_p),
                ('data_req', c_void_p),

                ('cmd_head', c_void_p),
                ('cmd_tail', c_void_p),
                ('accel_calib', accel),
                ('exp', expansion),

                ('accel', vec3w),
                ('orient', orient),
                ('gforce', vec3f),

                ('ir', ir),

                # button masks read by is_pressed / is_held / is_released
                ('btns', c_ushort),
                ('btns_last', c_ushort),
                ('btns_held', c_ushort),
                ('btns_released', c_ushort),
                ('orient_threshold', c_float),
                ('accel_threshold', c_int),

                ('lstate', wiimote_state),

                # compared against the event type constants (EVENT, STATUS, ...)
                ('event', c_int),
                ('event_buf', c_byte*32),
                ('motion_plus_id', c_ubyte*6)
                ]

# Pointer types used in the prototypes below: wiimote_p points at a single
# wiimote structure; wiimote_pp points at wiimote_p values and is the result
# type declared for init.
wiimote_p = POINTER(wiimote)
wiimote_pp = POINTER(wiimote_p)

# make function prototypes a bit easier to declare
def cfunc(name, dll, result, *args):
    '''Look up `name` in `dll` and return it as a ctypes foreign function
    with a full prototype and parameter flags.

    name   -- exported symbol name
    dll    -- loaded ctypes library that exports it
    result -- ctypes result type, or None when the function returns nothing
    args   -- one tuple per parameter: (pyname, ctype, flag[, default]), where
              flag 1 marks an input parameter and flag 2 an output parameter

    Example:

    cvMinMaxLoc = cfunc('cvMinMaxLoc', _cxDLL, None,
                        ('image', POINTER(IplImage), 1),
                        ('min_val', POINTER(double), 2),
                        ('max_val', POINTER(double), 2),
                        ('min_loc', POINTER(CvPoint), 2),
                        ('max_loc', POINTER(CvPoint), 2),
                        ('mask', POINTER(IplImage), 1, None))

    declares cvMinMaxLoc from _cxDLL with no return value: one input image,
    four output parameters, and a trailing input that defaults to None.
    ctypes returns the output parameters to the caller, so a call looks like

    min_val,max_val,min_loc,max_loc = cvMinMaxLoc(img)
    '''
    argtypes = [spec[1] for spec in args]
    # ctypes expects each flag entry as (flag, name[, default])
    paramflags = tuple((spec[2], spec[0]) + spec[3:] for spec in args)
    prototype = CFUNCTYPE(result, *argtypes)
    return prototype((name, dll), paramflags)

# get the shared library
# Locate the wiiuse shared library under either of its usual names.
lib = find_library('wiiuse') or find_library('libwiiuse')
if not lib:
    # Fail here with a clear message. On POSIX systems LoadLibrary(None)
    # returns a handle to the running process instead of raising, so a
    # missing library would otherwise only surface later as a confusing
    # symbol lookup error on the first wiiuse_* function.
    raise ImportError("wiiuse shared library not found "
                      "(looked for 'wiiuse' and 'libwiiuse')")
dll = cdll.LoadLibrary(lib)

#if os.name == 'nt':
#    dll = cdll.LoadLibrary('wiiuse.dll')
#else:
#    dll = cdll.LoadLibrary('libwiiuse.so')

# access the functions
# init goes through cfunc so that its result is typed as wiimote_pp instead
# of ctypes' default int result.
init = cfunc('wiiuse_init', dll, wiimote_pp,
             ('wiimotes', c_int, 1))
# Earlier cfunc-style prototypes for find/connect/poll, kept for reference;
# the plain attribute lookups below are what the module actually exports.
# find = cfunc('wiiuse_find', dll, c_int,
#              ('wm', wiimote_pp, 1),
#              ('max_wiimotes', c_int, 1),
#              ('timeout', c_int, 1))
# connect = cfunc('wiiuse_connect', dll, c_int,
#                 ('wm', wiimote_pp, 1),
#                 ('wiimotes', c_int, 1))
# poll = cfunc('wiiuse_poll', dll, c_int,
#              ('wm', wiimote_pp, 1),
#              ('wiimotes', c_int, 1))
# The remaining entry points are taken straight from the DLL without a
# prototype, so ctypes applies its defaults: arguments are converted from
# their Python types and the result is read as a C int.
find = dll.wiiuse_find
connect = dll.wiiuse_connect
poll = dll.wiiuse_poll
set_leds = dll.wiiuse_set_leds
motion_sensing = dll.wiiuse_motion_sensing
set_accel_threshold = dll.wiiuse_set_accel_threshold
set_orient_threshold = dll.wiiuse_set_orient_threshold
# argtypes is declared here because the threshold is a C float; without a
# prototype ctypes cannot pass a Python float.
set_orient_threshold.argtypes = [wiimote_p, c_float]
set_timeout = dll.wiiuse_set_timeout
set_ir = dll.wiiuse_set_ir
set_ir_position = dll.wiiuse_set_ir_position
set_ir_vres = dll.wiiuse_set_ir_vres

def is_pressed(dev, button):
    '''Return the bits of `button` that are set in dev.btns (0 if none).

    The result is the masked integer, not a bool; use it for its truthiness.
    '''
    current = dev.btns
    return current & button

def is_held(dev, button):
    '''Return the bits of `button` that are set in dev.btns_held (0 if none).

    The result is the masked integer, not a bool; use it for its truthiness.
    '''
    held = dev.btns_held
    return held & button

def is_released(dev, button):
    '''Return the bits of `button` that are set in dev.btns_released (0 if none).

    The result is the masked integer, not a bool; use it for its truthiness.
    '''
    released = dev.btns_released
    return released & button

def is_just_pressed(dev, button):
    '''Tell whether `button` is pressed now without also being held.

    Returns the falsy is_pressed result when the button is not pressed;
    otherwise returns True or False depending on whether is_held is empty.
    '''
    pressed = is_pressed(dev, button)
    if not pressed:
        return pressed
    # the held mask is only consulted once the button is known to be down
    return not is_held(dev, button)

def using_acc(wm):
    '''Return wm.state masked with 0x10 (non-zero when that bit is set).'''
    # NOTE(review): presumably the "accelerometer enabled" state bit - confirm
    # the mask against the WIIMOTE_STATE_* values of the wiiuse build in use.
    acc_bit = 0x10
    return wm.state & acc_bit

def using_exp(wm):
    '''Return wm.state masked with 0x20 (non-zero when that bit is set).'''
    # NOTE(review): presumably the "expansion attached" state bit - confirm
    # the mask against the WIIMOTE_STATE_* values of the wiiuse build in use.
    exp_bit = 0x20
    return wm.state & exp_bit

def using_ir(wm):
    '''Return wm.state masked with 0x40 (non-zero when that bit is set).'''
    # NOTE(review): presumably the "IR tracking enabled" state bit - confirm
    # the mask against the WIIMOTE_STATE_* values of the wiiuse build in use.
    ir_bit = 0x40
    return wm.state & ir_bit

# LED bit masks; the demo below passes one of these values to set_leds.
LED_NONE = 0
LED_1 = 0x10
LED_2 = 0x20
LED_3 = 0x40
LED_4 = 0x80

# the four LED masks in order, so LED[i] is the mask for LED number i+1
LED = [LED_1, LED_2, LED_3, LED_4]

# NOTE(review): presumably the values of expansion.type - confirm.
EXP_NONE = 0
EXP_NUNCHUK = 1
EXP_CLASSIC = 2

# Option flag bits; INIT_FLAGS combines SMOOTHING and ORIENT_THRESH.
# NOTE(review): presumably matched against wiimote.flags - confirm.
SMOOTHING = 0x01
CONTINUOUS = 0x02
ORIENT_THRESH = 0x04
INIT_FLAGS = SMOOTHING | ORIENT_THRESH

# NOTE(review): presumably the position argument of set_ir_position - confirm.
IR_ABOVE = 0
IR_BELOW = 1

# NOTE(review): presumably the values of ir.aspect - confirm.
ASPECT_4_3 = 0
ASPECT_16_9 = 1

# Wiimote button name -> bit mask, for use with is_pressed / is_held /
# is_released / is_just_pressed.
button = { '2':0x0001,
           '1':0x0002,
           'B':0x0004,
           'A':0x0008,
           '-':0x0010,
           'Home':0x0080,
           'Left':0x0100,
           'Right':0x0200,
           'Down':0x0400,
           'Up':0x0800,
           '+':0x1000,
           }

# Nunchuk button name -> bit mask.
nunchuk_button = { 'Z':0x01,
                   'C':0x02,
                   }


if __name__ == '__main__':
    # Demo: find and connect one wiimote, switch on LED 2, motion sensing and
    # IR, then print every button/IR event until interrupted with Ctrl-C.
    #
    # Every print call takes a single pre-formatted argument, so it produces
    # the same text as a Python 2 print statement and as the Python 3 print
    # function. With print statements the module could not even be imported
    # on Python 3.
    def handle_event(wm):
        '''Print the id, button mask and IR state of a wiimote with an event.'''
        print('EVENT %s %s' % (wm.unid, wm.btns))
        #print wm.gforce.x, wm.gforce.y, wm.gforce.z
        print(wm.ir)

    nmotes = 1
    wiimotes = init(nmotes)
    print('press 1&2')
    found = find(wiimotes, nmotes, 2)
    if not found:
        print('no wiimotes found')
        sys.exit(1)

    connected = connect(wiimotes, nmotes)
    if not connected:
        print('failed to connect to any wiimote.')
        sys.exit(1)
    print('connected to %d wiimotes (of %d found)' % (connected, found))

    first = wiimotes[0]
    set_leds(first, LED_2)
    motion_sensing(first, 1)
    set_ir(first, 1)

    while True:
        try:
            if poll(wiimotes, nmotes):
                print('.')
                for i in range(nmotes):
                    # dereference the i-th pointer once and reuse the structure
                    mote = wiimotes[i][0]
                    if mote.event == EVENT:
                        handle_event(mote)
        except KeyboardInterrupt:
            # Ctrl-C ends the demo loop
            break
 
