# -*- coding: utf-8 -*-
"""
conversion d'un fichier musicxml en objet song minwii.

$Id$
$URL$
"""
import sys
from types import StringTypes
from xml.dom.minidom import parse
from optparse import OptionParser
from itertools import cycle
#from Song import Song

# C4 (written "Do4" in French notation) is midi note 60.
# OCTAVE_REF is the octave in which the DIATO_SCALE values below lie.
OCTAVE_REF = 4
# midi number of each natural step in the reference octave (C4 = 60 .. B4 = 71)
DIATO_SCALE = {'C' : 60,
               'D' : 62,
               'E' : 64,
               'F' : 65,
               'G' : 67,
               'A' : 69,
               'B' : 71}

# pitch class (midi % 12) -> (step, alter), used by
# Tone.midi_to_step_alter_octave. Black keys are spelled C#, Eb, F#, G#, Bb
# (alter: +1 for a sharp, -1 for a flat).
CHROM_SCALE = {  0 : ('C',  0),
                 1 : ('C',  1),
                 2 : ('D',  0),
                 3 : ('E', -1),
                 4 : ('E',  0),
                 5 : ('F',  0),
                 6 : ('F',  1),
                 7 : ('G',  0),
                 8 : ('G',  1),
                 9 : ('A',  0),
                10 : ('B', -1),
                11 : ('B',  0)}


# French solfege name of each step (used by Tone.nom)
FR_NOTES = {'C' : u'Do',
            'D' : u'Ré',
            'E' : u'Mi',
            'F' : u'Fa',
            'G' : u'Sol',
            'A' : u'La',
            'B' : u'Si'}

# Sentinel default for _getNodeValue / _getElementsByPath meaning
# "no default was given: re-raise the lookup error".
_marker = []

class Part(object) :
    """One ``<part>`` of a MusicXML score, flattened into a list of notes.

    Attributes filled by the constructor:

    - ``notes``: the non-rest notes that do not continue a tie, in score
      order. Rests and tie continuations are not stored; their durations
      are added to the preceding kept note.
    - ``repeats``: the ``Barline`` objects that carry a repeat sign.
    - ``distinctNotes``: one note per distinct midi number, sorted by midi.
    - ``quarterNoteDuration``: milliseconds per quarter note, computed from
      the first ``<sound>`` element having a ``tempo`` attribute (120 when
      there is none).
    - ``verses``: ``notes`` cut into consecutive groups; a new group starts
      wherever two consecutive notes carry a different number of lyrics.
    - ``chorus``: with chorus auto-detection, the group removed from
      ``verses`` and used as the chorus (empty otherwise).
    - ``songStartsWithChorus``: True when that chorus was the first group.
    """

    def __init__(self, node, autoDetectChorus=True) :
        """
        :param node: DOM element of the ``<part>`` to read.
        :param autoDetectChorus: when True, pick a chorus among the verse
            groups (see ``_findVersesLoops``).
        """
        self.node = node
        self.notes = []
        self.repeats = []
        self.distinctNotes = []
        self.quarterNoteDuration = 500
        self._parseMusic()
        self.verses = [[]]
        self.chorus = []
        self.songStartsWithChorus = False
        self._findVersesLoops(autoDetectChorus)

    def _parseMusic(self) :
        """Read notes, repeat barlines and tempo from ``self.node``.

        :raises ValueError: when a note closes a tie that no preceding note
            opened.
        """
        divisions = 0
        previous = None
        seenMidi = set()

        for measureNode in self.node.getElementsByTagName('measure') :
            # divisions per quarter note; a measure that does not redefine
            # them keeps the last value seen.
            divisions = int(_getNodeValue(measureNode, 'attributes/divisions', divisions))
            measureNotes = []

            for noteNode in measureNode.getElementsByTagName('note') :
                note = Note(noteNode, divisions, previous)
                if note.tiedStop :
                    # continuation of a tie: lengthen the note that opened it.
                    # An explicit check rather than `assert`, which -O strips.
                    if previous is None or not previous.tiedStart :
                        raise ValueError('tie stop without a matching tie start')
                    previous.addDuration(note)
                    continue
                if note.isRest :
                    # a rest lengthens the preceding note; a rest before the
                    # first note of the part is dropped.
                    if previous is not None :
                        previous.addDuration(note)
                    continue
                measureNotes.append(note)
                if previous is not None :
                    previous.next = note
                previous = note

            self.notes.extend(measureNotes)

            for note in measureNotes :
                if note.midi not in seenMidi :
                    seenMidi.add(note.midi)
                    self.distinctNotes.append(note)

            # repeat barlines: a measure may hold several <barline> elements
            # (e.g. a left and a right one), so examine all of them.
            for barlineNode in measureNode.getElementsByTagName('barline') :
                barline = Barline(barlineNode, measureNotes)
                if barline.repeat :
                    self.repeats.append(barline)

        self.distinctNotes.sort(key=lambda note : note.midi)

        # tempo of the first <sound> element that defines one
        tempo = 120
        for sound in self.node.getElementsByTagName('sound') :
            if sound.hasAttribute('tempo') :
                tempo = float(sound.getAttribute('tempo'))
                break

        self.quarterNoteDuration = int(round(60000/tempo))

    def _findVersesLoops(self, autoDetectChorus) :
        """Group ``self.notes`` into verses and optionally pick the chorus.

        A new group starts wherever two consecutive notes carry a different
        number of lyrics. When ``autoDetectChorus`` is true and there are
        at least two groups, the first group whose first note has exactly
        one lyric is moved from ``self.verses`` to ``self.chorus``.
        """
        if not self.notes :
            # nothing to group: keep verses == [[]] and no chorus
            return

        verse = self.verses[0]
        for note, following in zip(self.notes, self.notes[1:]) :
            verse.append(note)
            if len(note.lyrics) != len(following.lyrics) :
                verse = []
                self.verses.append(verse)
        verse.append(self.notes[-1])

        if autoDetectChorus and len(self.verses) > 1 :
            for i, verse in enumerate(self.verses) :
                if len(verse[0].lyrics) == 1 :
                    self.chorus = self.verses.pop(i)
                    self.songStartsWithChorus = i==0
                    break

    def iterNotes(self) :
        """Yield ``(note, lyricIndex)`` pairs in performance order.

        A group whose first note has several lyrics is played once per
        lyric (``lyricIndex`` 0, 1, ...), each pass followed by the chorus;
        any other group is played once with index 0. When
        ``songStartsWithChorus`` is set, the chorus is also played before
        every group. Chorus notes always use lyric index 0.
        """
        for verse in self.verses :
            if self.songStartsWithChorus :
                for note in self.chorus :
                    yield note, 0

            if not verse :
                # only happens for a part without notes
                continue

            repeats = len(verse[0].lyrics)
            if repeats > 1 :
                for i in range(repeats) :
                    # verse
                    for note in verse :
                        yield note, i
                    # chorus
                    for note in self.chorus :
                        yield note, 0
            else :
                for note in verse :
                    yield note, 0

    @property
    def intervalsHistogram(self) :
        """Map each interval between consecutive performed notes (in midi
        semitones) to its number of occurrences; empty for an empty part."""
        histogram = {}
        previousNote = None
        for note, _ in self.iterNotes() :
            if previousNote is not None :
                interval = note.midi - previousNote.midi
                histogram[interval] = histogram.get(interval, 0) + 1
            previousNote = note
        return histogram

    @property
    def duration(self) :
        'Reference duration of the piece in milliseconds'
        quarters = sum(note.duration for note, _ in self.iterNotes())
        return quarters * self.quarterNoteDuration

    def pprint(self) :
        """Write every performed note and its lyric to stdout (debug aid).

        Notes that have no lyric at the yielded index are printed with an
        empty lyric instead of raising IndexError.
        """
        for note, verseIndex in self.iterNotes() :
            if verseIndex < len(note.lyrics) :
                lyric = note.lyrics[verseIndex]
            else :
                lyric = ''
            sys.stdout.write('%s %s\n' % (note, lyric))

    def assignNotesFromMidiNoteNumbers(self):
        """Append to ``self.notes`` a scale index for each midi number.

        For each value of ``self.midiNoteNumbers``, the appended index is
        that of the highest ``self.scale`` entry not above it (assuming
        ``self.scale`` is sorted ascending): -1 below the first entry, the
        last index above the last entry.

        NOTE(review): ``Part`` never sets ``midiNoteNumbers`` nor ``scale``,
        so this raises AttributeError unless a caller supplies both; it also
        appends integers to ``self.notes``, which otherwise holds Note
        objects. Looks like leftover code — confirm before relying on it.
        """
        # TODO map midi pitch ranges to bands
        for i in range(len(self.midiNoteNumbers)):
            noteInExtendedScale = 0
            while self.midiNoteNumbers[i] > self.scale[noteInExtendedScale] and noteInExtendedScale < len(self.scale)-1:
                noteInExtendedScale += 1
            if self.midiNoteNumbers[i]<self.scale[noteInExtendedScale]:
                noteInExtendedScale -= 1
            self.notes.append(noteInExtendedScale)


class Barline(object) :
    """A ``<barline>`` element and the repeat sign it may carry.

    ``location`` is the element's ``location`` attribute (``'right'`` when
    absent). ``repeat`` is None unless the barline has a ``<repeat>`` child
    and the measure has at least one note. In that case it is a dict with
    ``direction``, ``times`` (1 when unspecified) and ``note``: the first
    note of the measure for a left barline, the last one for a right one.
    """

    def __init__(self, node, measureNotes) :
        """
        :param node: the ``<barline>`` DOM element.
        :param measureNotes: the notes kept for the enclosing measure.
        :raises ValueError: when ``times`` is not an integer, or when a
            repeat sits on a barline located neither left nor right.
        """
        self.node = node
        self.location = location = node.getAttribute('location') or 'right'

        repeatNodes = node.getElementsByTagName('repeat')
        if not repeatNodes :
            self.repeat = None
            return

        repeatNode = repeatNodes[0]
        direction = repeatNode.getAttribute('direction')
        times = int(repeatNode.getAttribute('times') or 1)

        if location not in ('left', 'right') :
            raise ValueError(location)
        if not measureNotes :
            # no note to anchor the repeat on
            self.repeat = None
            return

        if location == 'left' :
            anchor = measureNotes[0]
        else :
            anchor = measureNotes[-1]
        self.repeat = {'direction' : direction, 'times' : times, 'note' : anchor}

    def __str__(self)  :
        """Return ``|:`` / ``:|`` for a left / right repeat, ``|`` otherwise."""
        if not self.repeat :
            return '|'
        return {'left' : '|:', 'right' : ':|'}.get(self.location, '|')

    __repr__ = __str__


class Tone(object) :
    """A pitch given by its step (letter C..B), alteration and octave.

    ``alter`` counts semitones (negative for flats, positive for sharps).
    With the module tables, ``Tone('C', 0, 4).midi == 60``.
    """

    @staticmethod
    def midi_to_step_alter_octave(midi):
        """Split a midi note number into ``(step, alter, octave)``.

        Black keys are spelled as in CHROM_SCALE (C#, Eb, F#, G#, Bb).
        """
        step, alter = CHROM_SCALE[midi % 12]
        # floor division keeps the octave integral on Python 2 and 3 alike
        # (and under `from __future__ import division`)
        octave = midi // 12 - 1
        return step, alter, octave

    def __init__(self, *args) :
        """Create a tone from ``(step, alter, octave)`` or from ``(midi,)``.

        :raises TypeError: when called with neither 1 nor 3 arguments,
            rather than producing a tone without pitch attributes.
        """
        if len(args) == 3 :
            self.step, self.alter, self.octave = args
        elif len(args) == 1 :
            midi = args[0]
            self.step, self.alter, self.octave = Tone.midi_to_step_alter_octave(midi)
        else :
            raise TypeError('Tone() takes (step, alter, octave) or (midi,), '
                            'got %d arguments' % len(args))

    @property
    def midi(self) :
        """Midi note number: reference value of the step, shifted by whole
        octaves from OCTAVE_REF, plus the alteration."""
        mid = DIATO_SCALE[self.step]
        mid = mid + (self.octave - OCTAVE_REF) * 12
        mid = mid + self.alter
        return mid

    @property
    def name(self) :
        """English name: step, octave, then one 'b' per flat or one '#' per
        sharp (e.g. ``D5#``, ``B3b``)."""
        name = u'%s%d' % (self.step, self.octave)
        if self.alter < 0 :
            alterext = 'b'
        else :
            alterext = '#'
        name = '%s%s' % (name, abs(self.alter) * alterext)
        return name

    @property
    def nom(self) :
        """French solfege name with its accidentals, without the octave
        (e.g. ``Ré#``, ``Si♭``)."""
        name = FR_NOTES[self.step]
        if self.alter < 0 :
            alterext = u'♭'
        else :
            alterext = u'#'
        name = u'%s%s' % (name, abs(self.alter) * alterext)
        return name
        
        

class Note(Tone) :
    """A ``<note>`` element: pitch (or rest), tie flags, duration and lyrics.

    ``previous`` is the note passed at construction; ``next`` starts as
    None and is filled in by ``Part`` when the following note is kept.
    """

    # midi numbers indexed by ``column`` (G3 .. C5, natural notes only)
    scale = [55, 57, 59, 60, 62, 64, 65, 67, 69, 71, 72]

    def __init__(self, node, divisions, previous) :
        """
        :param node: the ``<note>`` DOM element.
        :param divisions: divisions per quarter note in effect for the note.
        :param previous: the previously kept note, or None.
        :raises NotImplementedError: for a note that has neither a
            ``<pitch>``/``<step>`` nor a ``<rest>``.
        """
        self.node = node
        self.isRest = False
        self.tiedStart = False
        self.tiedStop = False

        # a note in the middle of a tie chain carries both a stop and a start
        tieds = _getElementsByPath(node, 'notations/tied', [])
        for tied in tieds :
            if tied.getAttribute('type') == 'start' :
                self.tiedStart = True
            elif tied.getAttribute('type') == 'stop' :
                self.tiedStop = True

        self.step = _getNodeValue(node, 'pitch/step', None)
        if self.step is not None :
            self.octave = int(_getNodeValue(node, 'pitch/octave'))
            self.alter = int(_getNodeValue(node, 'pitch/alter', 0))
        elif self.node.getElementsByTagName('rest') :
            self.isRest = True
        else :
            # Neither pitched nor a rest: fail here with the offending XML
            # instead of letting a half-built note reach `midi` later.
            raise NotImplementedError(self.node.toxml('utf-8'))

        self._duration = float(_getNodeValue(node, 'duration'))
        self.lyrics = []
        for ly in node.getElementsByTagName('lyric') :
            self.lyrics.append(Lyric(ly))

        self.divisions = divisions
        self.previous = previous
        self.next = None

    def __str__(self) :
        return (u'%5s %2s %2d %4s' % (self.nom, self.name, self.midi, round(self.duration, 2))).encode('utf-8')

    def __repr__(self) :
        return self.name.encode('utf-8')

    def addDuration(self, note) :
        """Absorb ``note``'s duration into this note.

        The stored value becomes a quarter-note count, hence ``divisions``
        is reset to 1.
        """
        self._duration = self.duration + note.duration
        self.divisions = 1

    @property
    def duration(self) :
        """Length in quarter notes (``<duration>`` divided by divisions)."""
        return self._duration / self.divisions

    @property
    def column(self):
        """Index of this note's midi number in ``scale`` (ValueError when
        the pitch is not in it)."""
        return self.scale.index(self.midi)
    

class Lyric(object) :
    """A ``<lyric>`` element: one sung syllable attached to a note."""

    # display template for each <syllabic> value; the hyphens show on which
    # side(s) the word continues
    _syllabicModifiers = {
        'single' : u'%s',
        'begin'  : u'%s -',
        'middle' : u'- %s -',
        'end'    : u'- %s',
        }

    def __init__(self, node) :
        """
        :param node: the ``<lyric>`` DOM element. ``syllabic`` falls back to
            ``'single'`` when absent; a missing ``<text>`` raises.
        """
        self.node = node
        self.syllabic = _getNodeValue(node, 'syllabic', 'single')
        self.text = _getNodeValue(node, 'text')

    def syllabus(self):
        """Return the syllable decorated with hyphens per ``syllabic``."""
        template = self._syllabicModifiers[self.syllabic]
        return template % (self.text,)

    def __str__(self) :
        """UTF-8 encoded ``syllabus()``."""
        return self.syllabus().encode('utf-8')

    __repr__ = __str__
        
        


def _getNodeValue(node, path, default=_marker) :
    """Return the text of the element reached by a '/'-separated tag path.

    Each component selects the first descendant element with that tag name
    (``getElementsByTagName`` searches the whole subtree, not only direct
    children). The result is the ``nodeValue`` of the reached element's
    first child.

    :param default: value returned when the path cannot be followed or the
        reached element has no child; when omitted the error propagates.
    :raises IndexError: a path component matches no element (no default).
    :raises AttributeError: the reached element is empty (no default).
    """
    try :
        for name in path.split('/') :
            node = node.getElementsByTagName(name)[0]
        return node.firstChild.nodeValue
    except (IndexError, AttributeError) :
        # Only "element missing" / "no text" is recoverable; a bare except
        # would also swallow KeyboardInterrupt, SystemExit and real bugs.
        if default is _marker :
            raise
        else :
            return default

def _getElementsByPath(node, path, default=_marker) :
    """Return the elements named by the last component of a '/'-separated path.

    Every earlier component descends into the first descendant element with
    that tag name. If one of them is missing, return ``default`` — or
    re-raise the IndexError when no default was given.
    """
    names = path.split('/')
    current = node
    try :
        for tag in names[:-1] :
            current = current.getElementsByTagName(tag)[0]
    except IndexError :
        if default is _marker :
            raise
        return default
    return current.getElementsByTagName(names[-1])

def musicXml2Song(input, partIndex=0, autoDetectChorus=True, printNotes=False) :
    """Parse a score-partwise MusicXML document and return one of its parts.

    :param input: path of the MusicXML file, or an open file object.
    :param partIndex: index of the ``<part>`` to convert; an int or a
        string of digits (as read from a command line).
    :param autoDetectChorus: forwarded to ``Part``.
    :param printNotes: when True, dump the notes with ``Part.pprint``.
    :returns: the ``Part`` built from the selected ``<part>``.
    :raises ValueError: when the root element is not ``score-partwise``.
    :raises IndexError: when ``partIndex`` is out of range.
    """
    # minidom.parse accepts a path or a file object; given a path it opens
    # and closes the file itself, so no file handle is leaked.
    d = parse(input)
    doc = d.documentElement

    # TODO convert score-timewise to score-partwise beforehand
    if doc.nodeName != u'score-partwise' :
        raise ValueError('not a musicxml file')

    parts = doc.getElementsByTagName('part')
    leadPart = parts[int(partIndex)]

    part = Part(leadPart, autoDetectChorus=autoDetectChorus)

    if printNotes :
        part.pprint()

    return part

    
    
def main() :
    """Command-line entry point.

    Converts the MusicXML file given as the single positional argument,
    then prints the part's interval histogram and its duration in
    milliseconds.
    """
    usage = "%prog musicXmlFile.xml [options]"
    op = OptionParser(usage)
    # type="int": without it optparse returns the command-line value as a
    # string, which cannot index the score's list of parts.
    op.add_option("-i", "--part-index", dest="partIndex"
                 , type="int"
                 , default = 0
                 , help = "Index de la partie qui contient le champ.")

    op.add_option("-p", '--print', dest='printNotes'
                  , action="store_true"
                  , default = False
                  , help = "Affiche les notes sur la sortie standard (debug)")

    op.add_option("-c", '--no-chorus', dest='autoDetectChorus'
                , action="store_false"
                , default = True
                , help = "désactive la détection du refrain")

    options, args = op.parse_args()

    if len(args) != 1 :
        raise SystemExit(op.format_help())

    song = musicXml2Song(args[0],
                         partIndex=options.partIndex,
                         autoDetectChorus=options.autoDetectChorus,
                         printNotes=options.printNotes)
    from pprint import pprint
    pprint(song.intervalsHistogram)
    # print with a single parenthesized argument behaves the same on
    # Python 2 and Python 3
    print(song.duration)


if __name__ == '__main__' :
    # run the command line; main()'s return value becomes the exit status
    raise SystemExit(main())
