Fonctions agrĂ©gĂ©es dĂ©finies par l’utilisateur Python¶

Les fonctions agrĂ©gĂ©es dĂ©finies par l’utilisateur (UDAFs) prennent une ou plusieurs lignes en entrĂ©e et produisent une seule ligne en sortie. Elles agissent sur les valeurs de lignes pour effectuer des calculs mathĂ©matiques tels que la somme, la moyenne, le comptage, les valeurs minimale/maximale, l’écart type et l’estimation, ainsi que d’autres opĂ©rations non mathĂ©matiques.

Les UDAFs Python vous permettent d’écrire vos propres fonctions d’agrĂ©gation qui sont similaires aux fonctions d’agrĂ©gation SQL dĂ©finies par le systĂšme Snowflake.

Vous pouvez Ă©galement crĂ©er vos propres UDAFs en utilisant des APIs Snowpark comme dĂ©crit dans CrĂ©ation de fonctions dĂ©finies par l’utilisateur (UDAFs) pour DataFrames dans Python.

Limitations¶

  • aggregate_state a une taille maximale de 8 MB dans une version sĂ©rialisĂ©e, essayez donc de contrĂŽler la taille de l’état agrĂ©gĂ©.

  • Vous ne pouvez pas appeler une UDAF en tant que fonction de fenĂȘtre (en d’autres termes, avec une clause OVER).

  • IMMUTABLE n’est pas pris en charge sur une fonction d’agrĂ©gation (lorsque vous utilisez le paramĂštre AGGREGATE). Par consĂ©quent, toutes les fonctions d’agrĂ©gation sont VOLATILE par dĂ©faut.

  • Les fonctions d’agrĂ©gation dĂ©finies par l’utilisateur ne peuvent pas ĂȘtre utilisĂ©es conjointement avec la clause WITHIN GROUP. Les requĂȘtes ne pourront pas ĂȘtre exĂ©cutĂ©es.

Interface pour le gestionnaire (handler) de la fonction d’agrĂ©gation¶

Une fonction d’agrĂ©gation regroupe les Ă©tats des nƓuds enfants, puis ces Ă©tats agrĂ©gĂ©s sont sĂ©rialisĂ©s et envoyĂ©s au nƓud parent oĂč ils sont fusionnĂ©s et oĂč le rĂ©sultat final est calculĂ©.

Pour dĂ©finir une fonction agrĂ©gĂ©e, vous devez dĂ©finir une classe Python (qui est le gestionnaire (handler) de la fonction) qui comprend des mĂ©thodes que Snowflake appelle au moment de l’exĂ©cution. Ces mĂ©thodes sont dĂ©crites dans le tableau ci-dessous. Voir les exemples ailleurs dans cette rubrique.

Méthode

Exigence

Description

__init__

Obligatoire

Initialise l’état interne d’un agrĂ©gat.

aggregate_state

Obligatoire

Renvoie l’état actuel d’un agrĂ©gat.

accumulate

Obligatoire

Accumule l’état de l’agrĂ©gat sur la base de la nouvelle ligne d’entrĂ©e.

merge

Obligatoire

Combine deux états agrégés intermédiaires.

finish

Obligatoire

Produit le rĂ©sultat final sur la base de l’état agrĂ©gĂ©.

Diagramme montrant des valeurs d'entrĂ©e accumulĂ©es dans des nƓuds enfants, puis envoyĂ©es Ă  un nƓud parent et fusionnĂ©es pour produire un rĂ©sultat final.

Exemple : calculer une somme¶

Le code de l’exemple suivant dĂ©finit une fonction d’agrĂ©gation dĂ©finie par l’utilisateur python_sum (UDAF) pour renvoyer la somme des valeurs numĂ©riques.

  1. CrĂ©ez l’UDAF.

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
      RETURNS INT
      LANGUAGE PYTHON
      RUNTIME_VERSION = 3.9
      HANDLER = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. Créez une table de données de test.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. Appelez l”python_sum UDAF.

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. Comparez les rĂ©sultats avec la sortie de la fonction SQL dĂ©finie par le systĂšme Snowflake, SUM, et constatez que le rĂ©sultat est le mĂȘme.

    SELECT sum(col) FROM sales;
    
    Copy
  5. Regroupez par des valeurs de somme par type d’article dans le tableau des ventes.

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

Exemple : calculer une moyenne¶

Le code de l’exemple suivant dĂ©finit une fonction d’agrĂ©gation dĂ©finie par l’utilisateur python_avg pour renvoyer la moyenne des valeurs numĂ©riques.

  1. Créez la fonction.

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
      RETURNS FLOAT
      LANGUAGE PYTHON
      RUNTIME_VERSION = 3.9
      HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. Créez une table de données de test.

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. Appelez la fonction dĂ©finie par l’utilisateur python_avg.

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. Comparez les rĂ©sultats avec la sortie de la fonction SQL dĂ©finie par le systĂšme Snowflake, AVG, et constatez que le rĂ©sultat est le mĂȘme.

    SELECT avg(price) FROM sales;
    
    Copy
  5. Regroupez les valeurs moyennes par type d’article dans le tableau des ventes.

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

Exemple : ne renvoyer que les valeurs uniques¶

Le code de l’exemple suivant prend un tableau et renvoie un tableau contenant uniquement les valeurs uniques.

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
  RETURNS ARRAY
  LANGUAGE PYTHON
  RUNTIME_VERSION = 3.9
  HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;
Copy
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

Exemple : renvoyer un dĂ©compte de chaĂźnes¶

Le code de l’exemple suivant renvoie le nombre de toutes les instances de chaünes dans un objet.

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
  RETURNS OBJECT
  LANGUAGE PYTHON
  RUNTIME_VERSION = 3.9
  HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;
Copy
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

Exemple : renvoyer les k premiĂšres valeurs les plus Ă©levĂ©es¶

Le code de l’exemple suivant renvoie une liste des plus grandes valeurs pour k. Le code accumule les valeurs d’entrĂ©e nĂ©gatives sur un tas min, puis renvoie les k valeurs les plus importantes.

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
  RETURNS ARRAY
  LANGUAGE PYTHON
  RUNTIME_VERSION = 3.9
  HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;
Copy
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy