es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

La función Batch Generator devuelve únicamente el último elemento al convertirse en lista.

Tengo el siguiente código para generar lotes a partir de una lista de elementos:

def batch_generator(items, batch_size):
    count = 1
    chunk = []

    for item in items:
        if count % batch_size:
            chunk.append(item)
        else:
            chunk.append(item)
            yield chunk
            chunk.clear()
        count += 1

    if len(chunk):
        yield chunk

Iterar uno por uno produce resultados esperados:

for x in batch_generator(range(17), 5):
    print(x)

[0, 1, 2, 3, 4]

[5, 6, 7, 8, 9]

[10, 11, 12, 13, 14]

[15, 16]

Sin embargo, cuando convierto el generador directamente en una lista, ¡solo se devuelve el último elemento varias veces!

list(batch_generator(range(17), 5))

[[15, 16], [15, 16], [15, 16], [15, 16]]

Mientras que un generador simple convertido en lista funciona correctamente:

list(([i, i*2, i*3] for i in range(5)))

[[0, 0, 0], [1, 2, 3], [2, 4, 6], [3, 6, 9], [4, 8, 12]]

¿Por qué está sucediendo esto?

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Esto está sucediendo porque la función batch_generator no está devolviendo una nueva lista cada vez que se produce. En su lugar, está devolviendo la misma lista cada vez y la lista está siendo modificada cada vez que se produce. Cuando conviertes el generador en una lista, la lista se está poblado con la misma lista varias veces.

    Para solucionar esto, necesitas crear una nueva lista cada vez que se genere. Puedes hacer esto usando el constructor list() para crear una nueva lista cada vez que se produzca.

    def batch_generator(items, batch_size):
        count = 1
        chunk = []
    
        for item in items:
            if count % batch_size:
                chunk.append(item)
            else:
                chunk.append(item)
                yield list(chunk)
                chunk.clear()
            count += 1
    
        if len(chunk):
            yield list(chunk)
    

    Ahora, cuando conviertes el generador a una lista, se retornan los resultados esperados:

    list(batch_generator(range(17), 5))
    # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16]]
    

Comments are closed.