# R2R/tests/test_llms.py
import pytest

from r2r import LLMConfig
from r2r.base.abstractions.llm import GenerationConfig
from r2r.providers.llms import LiteLLM


@pytest.fixture
def lite_llm():
    config = LLMConfig(provider="litellm")
    return LiteLLM(config)

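# NOTE: these tests exercise live model backends rather than mocks. The guard
# below is a sketch and not part of the original file: it assumes the OpenAI
# key is supplied via the OPENAI_API_KEY environment variable (litellm's
# default) and could be applied to the OpenAI test with `@requires_openai`.
import os

requires_openai = pytest.mark.skipif(
    "OPENAI_API_KEY" not in os.environ,
    reason="OPENAI_API_KEY is not set",
)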

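# Requires a local Ollama server with the `llama2` model pulled
# (`ollama pull llama2`); litellm routes `ollama/...` models to it.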
@pytest.mark.parametrize("llm_fixture", ["lite_llm"])
def test_get_completion_ollama(request, llm_fixture):
    llm = request.getfixturevalue(llm_fixture)

    messages = [
        {
            "role": "user",
            "content": "This is a test, return only the word `True`",
        }
    ]
    generation_config = GenerationConfig(
        model="ollama/llama2",
        temperature=0.0,
        top_p=0.9,
        max_tokens_to_sample=50,
        stream=False,
    )

    completion = llm.get_completion(messages, generation_config)
    # assert isinstance(completion, LLMChatCompletion)
    assert completion.choices[0].message.role == "assistant"
    assert completion.choices[0].message.content.strip() == "True"


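# Requires OpenAI API access; this is where the hedged `requires_openai`
# marker sketched above could be applied.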
@pytest.mark.parametrize("llm_fixture", ["lite_llm"])
def test_get_completion_openai(request, llm_fixture):
    llm = request.getfixturevalue(llm_fixture)

    messages = [
        {
            "role": "user",
            "content": "This is a test, return only the word `True`",
        }
    ]
    generation_config = GenerationConfig(
        model="gpt-3.5-turbo",
        temperature=0.0,
        top_p=0.9,
        max_tokens_to_sample=50,
        stream=False,
    )

    completion = llm.get_completion(messages, generation_config)
    # assert isinstance(completion, LLMChatCompletion)
    assert completion.choices[0].message.role == "assistant"
    assert completion.choices[0].message.content.strip() == "True"
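

# A streaming variant is sketched below as an assumption, not original code:
# it presumes the provider exposes a `get_completion_stream` method that
# mirrors `get_completion` and yields OpenAI-style chunks whose text lives in
# `choices[0].delta.content`.
@pytest.mark.parametrize("llm_fixture", ["lite_llm"])
def test_get_completion_stream_openai(request, llm_fixture):
    llm = request.getfixturevalue(llm_fixture)

    messages = [
        {
            "role": "user",
            "content": "This is a test, return only the word `True`",
        }
    ]
    generation_config = GenerationConfig(
        model="gpt-3.5-turbo",
        temperature=0.0,
        top_p=0.9,
        max_tokens_to_sample=50,
        stream=True,
    )

    # Reassemble the streamed deltas and check the final text.
    text = ""
    for chunk in llm.get_completion_stream(messages, generation_config):
        delta = chunk.choices[0].delta.content
        if delta:
            text += delta
    assert text.strip() == "True"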