# centroids.py
import logging
from collections import defaultdict

from senpy.plugins import EmotionConversionPlugin
from senpy.models import EmotionSet, Emotion, Error

logger = logging.getLogger(__name__)


class CentroidConversion(EmotionConversionPlugin):
    """Convert emotions between two models using per-category centroids.

    The plugin definition (``info``) must carry a ``centroids`` mapping of
    ``category -> {dimension: value}``.  The forward direction collapses a
    set of categorised emotions into one dimensional emotion; the backwards
    direction maps a dimensional emotion to the category whose centroid is
    nearest.
    """

    def __init__(self, info):
        """Validate the plugin definition and normalise it in place.

        ``info`` is modified before being handed to the parent constructor:

        * ``onyx:doesConversion`` is filled in (both directions) from
          ``centroids_direction`` when it is not already present.
        * When ``aliases`` is present, every category name and dimension
          name in ``centroids`` is replaced by its alias (names without an
          alias are kept unchanged).

        Raises:
            Error: if ``centroids`` is missing, or if neither
                ``onyx:doesConversion`` nor ``centroids_direction`` is given.
        """
        if 'centroids' not in info:
            raise Error('Centroid conversion plugins should provide '
                        'the centroids in their senpy file')
        if 'onyx:doesConversion' not in info:
            if 'centroids_direction' not in info:
                raise Error('Please, provide centroids direction')

            cf, ct = info['centroids_direction']
            info['onyx:doesConversion'] = [{
                'onyx:conversionFrom': cf,
                'onyx:conversionTo': ct
            }, {
                'onyx:conversionFrom': ct,
                'onyx:conversionTo': cf
            }]

        if 'aliases' in info:
            aliases = info['aliases']
            info['centroids'] = {
                aliases.get(category, category): {
                    aliases.get(dim, dim): value
                    for dim, value in centroid.items()
                }
                for category, centroid in info['centroids'].items()
            }
        super(CentroidConversion, self).__init__(info)

    def _forward_conversion(self, original):
        """Sum the centroids of all categories found, weighted by intensity.

        Each emotion's intensity is divided by the set's
        ``onyx__maxIntensityValue`` (1 when absent); an emotion without an
        explicit intensity counts as full intensity.  Emotions with zero
        intensity, or whose category has no (non-empty) centroid, contribute
        nothing.  When the plugin defines an ``origin``, centroids are taken
        relative to it and the origin is added back once to every
        accumulated dimension.

        Returns:
            Emotion: a new emotion holding one value per accumulated
            dimension (no dimension keys are set if nothing contributed).
        """
        res = Emotion()
        maxIntensity = float(original.get("onyx__maxIntensityValue", 1))
        neutralPoint = self.get("origin", None)

        # Accumulate in a plain dict so that only centroid dimensions are
        # ever summed and later shifted back by the origin, independently of
        # whatever other keys the Emotion model itself may carry.
        totals = {}
        for e in original.onyx__hasEmotion:
            category = e.onyx__hasEmotionCategory
            intensity = e.get("onyx__hasEmotionIntensity",
                              maxIntensity) / maxIntensity
            if intensity == 0:
                continue
            centroid = self.centroids.get(category, None)
            if not centroid:
                continue
            for dim, value in centroid.items():
                if neutralPoint:
                    value -= neutralPoint[dim]
                totals[dim] = totals.get(dim, 0) + value * intensity

        for dim, total in totals.items():
            if neutralPoint:
                total += neutralPoint[dim]
            res[dim] = total
        return res

    def _backwards_conversion(self, original):
        """Find the closest category.

        Compares ``original`` with every centroid using the squared
        Euclidean distance and returns the nearest category; on a tie the
        category that comes first when iterating ``self.centroids`` wins.

        Returns:
            Emotion: an emotion whose ``onyx__hasEmotionCategory`` is the
            selected category.
        """
        centroids = self.centroids
        # The keys of the first centroid define which dimensions are compared.
        dimensions = list(centroids.values())[0]

        def distance(centroid):
            # Squared differences keep opposite offsets from cancelling out;
            # a dimension missing from ``original`` is treated as 0.
            return sum((centroid[k] - original.get(k, 0)) ** 2
                       for k in dimensions)

        emotion = min(centroids, key=lambda state: distance(centroids[state]))
        return Emotion(onyx__hasEmotionCategory=emotion)

    def convert(self, emotionSet, fromModel, toModel, params):
        """Convert ``emotionSet`` from ``fromModel`` to ``toModel``.

        This is a generator that yields exactly one ``EmotionSet``.  In the
        forward direction the whole set is collapsed into a single emotion;
        in the backwards direction every emotion in the set is converted
        individually.

        Raises:
            Error: if the (fromModel, toModel) pair matches neither direction
                of ``centroids_direction``.  Because this is a generator, the
                error surfaces when the result is iterated.
        """
        # NOTE(review): __init__ accepts definitions that provide
        # 'onyx:doesConversion' without 'centroids_direction'; presumably this
        # lookup fails for such plugins -- confirm whether that is supported.
        cf, ct = self.centroids_direction
        logger.debug('%s\n%s\n%s\n%s', emotionSet, fromModel, toModel, params)
        e = EmotionSet()
        if fromModel == cf and toModel == ct:
            e.onyx__hasEmotion.append(self._forward_conversion(emotionSet))
        elif fromModel == ct and toModel == cf:
            for i in emotionSet.onyx__hasEmotion:
                e.onyx__hasEmotion.append(self._backwards_conversion(i))
        else:
            raise Error('EMOTION MODEL NOT KNOWN')
        yield e