I have a function that (should) flatten an arbitrarily deeply nested list:
from collections.abc import Sequence
from typing import TypeVar, cast

T = TypeVar("T")
type Nested[T] = Sequence[T | Sequence[Nested]]

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, Sequence):
            # Nested sequence: recurse into it
            flattened.extend(flatten(cast(Nested[T], elem)))
        else:
            # Leaf value: keep it as-is
            flattened.append(elem)
    return flattened
Now if I pass in, for example, a list of lists, the result still seems to be a list of lists. Why is that?
test: list[list[str]]
# (parameter) flattened: list[list[str]]
flattened = flatten(test)
It looks like it always returns the same type that is passed in. Why is that? It seems that whatever is inside the outer list is inferred as the generic type T. How can I define this nested (recursive) type so that the flatten function works and shows the type hints correctly, like this?
test: list[list[str]]
test2: list[list[list[list[int]]]]
# (parameter) flattened: list[str]
flattened = flatten(test)
# (parameter) flattened2: list[int]
flattened2 = flatten(test2)
I'm using Python 3.12.
Edit:
Just after posting this, I found a small mistake in my definitions. The function should be defined like this:
from collections.abc import Sequence
from typing import TypeVar, cast

T = TypeVar("T")
type Nested[T] = Sequence[T | Nested[T]]  # <--- Fix here!!

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, Sequence):
            # Nested sequence: recurse into it
            flattened.extend(flatten(cast(Nested[T], elem)))
        else:
            # Leaf value: keep it as-is
            flattened.append(elem)
    return flattened
Another problem was that in my actual code I was assigning the flattened list back to a variable that was already declared as a list of lists (or deeper). Assigning the flattened list to a new variable shows the type correctly, except in the case of bytes. I guess that is because bytes is actually a Sequence[int] under the hood (iterating over it yields ints).
test: list[list[str]]
bytes_test: list[list[bytes]]
# test is already defined as list[list[str]]
# (parameter) test: list[list[str]]
test = flatten(test)
# type for flattened is inferred so it's list[str]
# (parameter) flattened: list[str]
flattened = flatten(test)
# bytes is a Sequence[int], so it is flattened as well
# (parameter) flattened_bytes: list[int]
flattened_bytes = flatten(bytes_test)
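A quick standard-library check of why bytes gets flattened: at runtime it passes the isinstance(..., Sequence) test, and iterating or indexing it yields ints (typeshed likewise declares bytes as a subclass of Sequence[int]), so the recursion splits each bytes value into integers:

```python
from collections.abc import Sequence

data = b"hi"

# bytes counts as a Sequence, so an isinstance(..., Sequence) check treats it as nested.
print(isinstance(data, Sequence))  # True

# Iterating over bytes yields ints, not length-1 bytes objects.
print(list(data))  # [104, 105]

# Indexing also gives an int; only slicing gives bytes back.
print(data[0], data[:1])  # 104 b'h'
```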
Now I wonder how I can keep bytes from being flattened, so that the result is correctly typed as list[bytes].
Edit2:
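There is also an infinite-recursion bug when the leaf type is str: a one-character string is itself a Sequence whose only element is that same string, so the recursion never bottoms out. (bytes does not recurse forever, since iterating over it yields ints, but it still gets flattened into ints.) It looks like the best way to do this kind of generic flattening is to invert the isinstance check and test for the leaf type first, like this: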
from collections.abc import Sequence
from typing import TypeVar

T = TypeVar("T")
type Nested[T] = Sequence[T | Nested[T]]

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, T):  # <-- Doesn't work: T is a TypeVar, not a class, so isinstance() raises TypeError
            flattened.append(elem)
        else:
            flattened.extend(flatten(elem))
    return flattened
But then this would need some kind of custom generic isinstance checker, since the type bound to T is not available at runtime.