diff --git a/lib/livebook/evaluator/io_proxy.ex b/lib/livebook/evaluator/io_proxy.ex index 10457a38a..ecb398169 100644 --- a/lib/livebook/evaluator/io_proxy.ex +++ b/lib/livebook/evaluator/io_proxy.ex @@ -132,11 +132,11 @@ defmodule Livebook.Evaluator.IOProxy do end defp io_request({:get_chars, prompt, count}, state) when count >= 0 do - get_chars(:latin1, prompt, state, count) + get_chars(:latin1, prompt, count, state) end defp io_request({:get_chars, encoding, prompt, count}, state) when count >= 0 do - get_chars(encoding, prompt, state, count) + get_chars(encoding, prompt, count, state) end defp io_request({:get_line, prompt}, state) do @@ -147,12 +147,12 @@ defmodule Livebook.Evaluator.IOProxy do get_line(encoding, prompt, state) end - defp io_request({:get_until, _prompt, _mod, _fun, _args}, state) do - {{:error, :enotsup}, state} + defp io_request({:get_until, prompt, mod, fun, args}, state) do + get_until(:latin1, prompt, mod, fun, args, state) end - defp io_request({:get_until, _encoding, _prompt, _mod, _fun, _args}, state) do - {{:error, :enotsup}, state} + defp io_request({:get_until, encoding, prompt, mod, fun, args}, state) do + get_until(encoding, prompt, mod, fun, args, state) end defp io_request({:get_password, _encoding}, state) do @@ -228,34 +228,29 @@ defmodule Livebook.Evaluator.IOProxy do end defp get_line(encoding, prompt, state) do - prompt = :unicode.characters_to_binary(prompt, encoding, state.encoding) - - case get_input(prompt, state) do - input when is_binary(input) -> - {line, rest} = line_from_input(input) - - line = - if is_binary(line) do - :unicode.characters_to_binary(line, state.encoding, encoding) - else - line - end - - state = put_in(state.input_buffers[prompt], rest) - {line, state} - - error -> - {error, state} - end + get_consume(encoding, prompt, state, fn input -> + line_from_input(input) + end) end - defp get_chars(encoding, prompt, state, count) do + defp get_chars(encoding, prompt, count, state) do + get_consume(encoding, prompt, state, fn input -> + chars_from_input(input, encoding, count) + end) + end + + defp get_until(encoding, prompt, mod, fun, args, state) do + get_consume(encoding, prompt, state, fn input -> + get_until_from_input(input, encoding, mod, fun, args) + end) + end + + defp get_consume(encoding, prompt, state, consume_fun) do prompt = :unicode.characters_to_binary(prompt, encoding, state.encoding) case get_input(prompt, state) do input when is_binary(input) -> - {chars, rest} = chars_from_input(input, encoding, count) - + {chars, rest} = consume_fun.(input) state = put_in(state.input_buffers[prompt], rest) {chars, state} @@ -297,61 +292,83 @@ defmodule Livebook.Evaluator.IOProxy do {input, ""} {pos, len} -> - size = byte_size(input) - line = binary_part(input, 0, pos + len) - rest = binary_part(input, pos + len, size - pos - len) - {line, rest} + :erlang.split_binary(input, pos + len) end end - defp chars_from_input("", _, _count), do: {:eof, ""} + defp chars_from_input("", _encoding, _count), do: {:eof, ""} defp chars_from_input(input, :unicode, count) do - if byte_size_utf8(input) >= count do - chars_part(input, :unicode, count) - else - {input, ""} - end + {:ok, count} = utf8_split_at(input, count) + :erlang.split_binary(input, count) end defp chars_from_input(input, :latin1, count) do - if byte_size(input) >= count do - chars_part(input, :latin1, count) + if byte_size(input) > count do + :erlang.split_binary(input, count) else {input, ""} end end - defp chars_part(chars, _, 0), do: {"", chars} + defp utf8_split_at(input, count), do: utf8_split_at(input, count, 0) - defp chars_part(input, :unicode, count) do - with {:ok, count} <- split_at(input, count, 0) do - <> = input - {chars, rest} + defp utf8_split_at(_, 0, acc), do: {:ok, acc} + + defp utf8_split_at(<>, count, acc), + do: utf8_split_at(t, count - 1, acc + byte_size(<>)) + + defp utf8_split_at(<<_, _::binary>>, _count, _acc), + do: {:error, :invalid_unicode} + + defp utf8_split_at(<<>>, _count, acc), + do: {:ok, acc} + + defp get_until_from_input(input, encoding, mod, fun, args) do + {chars, rest} = get_until_from_input(input, encoding, mod, fun, args, []) + {get_until_result(chars, encoding), rest} + end + + defp get_until_from_input("", encoding, mod, fun, args, continuation) do + case apply(mod, fun, [continuation, :eof | args]) do + {:done, result, :eof} -> + {result, ""} + + {:done, result, rest} -> + {result, list_to_binary(rest, encoding)} + + {:more, next_continuation} -> + get_until_from_input("", encoding, mod, fun, args, next_continuation) end end - defp chars_part(input, :latin1, count) do - <> = input - {chars, rest} + defp get_until_from_input(input, encoding, mod, fun, args, continuation) do + {line, rest} = line_from_input(input) + + case apply(mod, fun, [continuation, binary_to_list(line, encoding) | args]) do + {:done, result, :eof} -> + {result, rest} + + {:done, result, extra} -> + {result, list_to_binary(extra, encoding) <> rest} + + {:more, next_continuation} -> + get_until_from_input(rest, encoding, mod, fun, args, next_continuation) + end end - defp split_at(_, 0, acc), do: {:ok, acc} + defp binary_to_list(data, _) when is_list(data), do: data + defp binary_to_list(data, :unicode) when is_binary(data), do: String.to_charlist(data) + defp binary_to_list(data, :latin1) when is_binary(data), do: :erlang.binary_to_list(data) - defp split_at(<>, count, acc), - do: split_at(t, count - 1, acc + byte_size(<>)) + defp list_to_binary(data, _) when is_binary(data), do: data + defp list_to_binary(data, :unicode) when is_list(data), do: List.to_string(data) + defp list_to_binary(data, :latin1) when is_list(data), do: :erlang.list_to_binary(data) - defp split_at(<<_, _::binary>>, _count, _acc), - do: {:error, :invalid_unicode} - - defp split_at(<<>>, _count, acc), - do: {:ok, acc} - - defp byte_size_utf8(chars), do: byte_size_utf8(chars, 0) - - defp byte_size_utf8(<<>>, size), do: size - - defp byte_size_utf8(<<_h::utf8, t::binary>>, size), do: byte_size_utf8(t, size + 1) + # From https://erlang.org/doc/apps/stdlib/io_protocol.html - result can be any + # Erlang term, but if it is a list(), the I/O server can convert it to a binary(). + defp get_until_result(data, encoding) when is_list(data), do: list_to_binary(data, encoding) + defp get_until_result(data, _), do: data defp io_reply(from, reply_as, reply) do send(from, {:io_reply, reply_as, reply}) diff --git a/test/livebook/evaluator/io_proxy_test.exs b/test/livebook/evaluator/io_proxy_test.exs index 11e6db612..54abf00ca 100644 --- a/test/livebook/evaluator/io_proxy_test.exs +++ b/test/livebook/evaluator/io_proxy_test.exs @@ -37,12 +37,7 @@ defmodule Livebook.Evaluator.IOProxyTest do end test "IO.gets", %{io: io} do - pid = - spawn_link(fn -> - reply_to_input_request(:ref, "name: ", {:ok, "Jake Peralta"}, 1) - end) - - IOProxy.configure(io, pid, :ref) + configure_owner_with_input(io, "name: ", "Jake Peralta") assert IO.gets(io, "name: ") == "Jake Peralta" end @@ -57,15 +52,121 @@ defmodule Livebook.Evaluator.IOProxyTest do assert IO.gets(io, "name: ") == {:error, "no matching Livebook input found"} end + + test "IO.getn with unicode input", %{io: io} do + configure_owner_with_input(io, "name: ", "🐈 test\n") + + assert IO.getn(io, "name: ", 3) == "🐈 t" + end + + test "IO.getn returns the given number of characters", %{io: io} do + configure_owner_with_input(io, "name: ", "Jake Peralta\nAmy Santiago\n") + + assert IO.getn(io, "name: ", 13) == "Jake Peralta\n" + assert IO.getn(io, "name: ", 13) == "Amy Santiago\n" + assert IO.getn(io, "name: ", 13) == :eof + end + + test "IO.getn returns all characters if requested more than available", %{io: io} do + configure_owner_with_input(io, "name: ", "Jake Peralta\nAmy Santiago\n") + + assert IO.getn(io, "name: ", 10_000) == "Jake Peralta\nAmy Santiago\n" + end + end + + # See https://github.com/elixir-lang/elixir/blob/v1.12.1/lib/elixir/test/elixir/string_io_test.exs + defmodule GetUntilCallbacks do + def until_eof(continuation, :eof) do + {:done, continuation, :eof} + end + + def until_eof(continuation, content) do + {:more, continuation ++ content} + end + + def until_eof_then_try_more('magic-stop-prefix' ++ continuation, :eof) do + {:done, continuation, :eof} + end + + def until_eof_then_try_more(continuation, :eof) do + {:more, 'magic-stop-prefix' ++ continuation} + end + + def until_eof_then_try_more(continuation, content) do + {:more, continuation ++ content} + end + + def up_to_3_bytes(continuation, :eof) do + {:done, continuation, :eof} + end + + def up_to_3_bytes(continuation, content) do + case continuation ++ content do + [a, b, c | tail] -> {:done, [a, b, c], tail} + str -> {:more, str} + end + end + + def up_to_3_bytes_discard_rest(continuation, :eof) do + {:done, continuation, :eof} + end + + def up_to_3_bytes_discard_rest(continuation, content) do + case continuation ++ content do + [a, b, c | _tail] -> {:done, [a, b, c], :eof} + str -> {:more, str} + end + end + end + + describe ":get_until" do + test "with up_to_3_bytes", %{io: io} do + configure_owner_with_input(io, "name: ", "abcdefg") + + result = get_until(io, :unicode, "name: ", GetUntilCallbacks, :up_to_3_bytes) + assert result == "abc" + assert IO.gets(io, "name: ") == "defg" + end + + test "with up_to_3_bytes_discard_rest", %{io: io} do + configure_owner_with_input(io, "name: ", "abcdefg") + + result = get_until(io, :unicode, "name: ", GetUntilCallbacks, :up_to_3_bytes_discard_rest) + assert result == "abc" + assert IO.gets(io, "name: ") == :eof + end + + test "with until_eof", %{io: io} do + configure_owner_with_input(io, "name: ", "abc\nd") + + result = get_until(io, :unicode, "name: ", GetUntilCallbacks, :until_eof) + assert result == "abc\nd" + end + + test "with until_eof and \\r\\n", %{io: io} do + configure_owner_with_input(io, "name: ", "abc\r\nd") + + result = get_until(io, :unicode, "name: ", GetUntilCallbacks, :until_eof) + assert result == "abc\r\nd" + end + + test "with until_eof_then_try_more", %{io: io} do + configure_owner_with_input(io, "name: ", "abc\nd") + + result = get_until(io, :unicode, "name: ", GetUntilCallbacks, :until_eof_then_try_more) + assert result == "abc\nd" + end + + test "with raw bytes (latin1)", %{io: io} do + configure_owner_with_input(io, "name: ", <<181, 255, 194, ?\n>>) + + result = get_until(io, :latin1, "name: ", GetUntilCallbacks, :until_eof) + assert result == <<181, 255, 194, ?\n>> + end end test "consumes the given input only once", %{io: io} do - pid = - spawn_link(fn -> - reply_to_input_request(:ref, "name: ", {:ok, "Jake Peralta\nAmy Santiago\n"}, 1) - end) - - IOProxy.configure(io, pid, :ref) + configure_owner_with_input(io, "name: ", "Jake Peralta\nAmy Santiago\n") assert IO.gets(io, "name: ") == "Jake Peralta\n" assert IO.gets(io, "name: ") == "Amy Santiago\n" @@ -121,43 +222,21 @@ defmodule Livebook.Evaluator.IOProxyTest do assert IOProxy.flush_widgets(io) == MapSet.new() end - test "getn/1 return first character", %{io: io} do - pid = - spawn_link(fn -> - reply_to_input_request(:ref, "name: ", {:ok, "🐈 test\n"}, 1) - end) - - IOProxy.configure(io, pid, :ref) - - assert IO.getn(io, "name: ") == "🐈" - end - - test "getn/2 returns the number of defined characters ", %{io: io} do - pid = - spawn_link(fn -> - reply_to_input_request(:ref, "name: ", {:ok, "Jake Peralta\nAmy Santiago\n"}, 1) - end) - - IOProxy.configure(io, pid, :ref) - - assert IO.getn(io, "name: ", 13) == "Jake Peralta\n" - assert IO.getn(io, "name: ", 13) == "Amy Santiago\n" - assert IO.getn(io, "name: ", 13) == :eof - end - - test "getn/2 all characters", %{io: io} do - pid = - spawn_link(fn -> - reply_to_input_request(:ref, "name: ", {:ok, "Jake Peralta\nAmy Santiago\n"}, 1) - end) - - IOProxy.configure(io, pid, :ref) - - assert IO.getn(io, "name: ", 10_000) == "Jake Peralta\nAmy Santiago\n" - end - # Helpers + defp get_until(pid, encoding, prompt, module, function) do + :io.request(pid, {:get_until, encoding, prompt, module, function, []}) + end + + defp configure_owner_with_input(io, prompt, input) do + pid = + spawn_link(fn -> + reply_to_input_request(:ref, prompt, {:ok, input}, 1) + end) + + IOProxy.configure(io, pid, :ref) + end + defp reply_to_input_request(_ref, _prompt, _reply, 0), do: :ok defp reply_to_input_request(ref, prompt, reply, times) do