Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@
from weaviate.exceptions import UnexpectedStatusCodeError


_DEFAULT_HOST: str = os.environ.get("WV_TEST_HOST", "localhost")
_DEFAULT_VECTOR_HOST: str = os.environ.get("WV_TEST_VECTOR_HOST", "localhost")
_DEFAULT_CLUSTER_HOST: str = os.environ.get("WV_TEST_CLUSTER_HOST", "localhost")
_DEFAULT_RBAC_HOST: str = os.environ.get("WV_TEST_RBAC_HOST", "localhost")
_DEFAULT_BROKEN_HOST: str = os.environ.get("WV_TEST_BROKEN_HOST", "localhost")
_DEFAULT_PRIMARY_PORTS: Tuple[int, int] = (
int(os.environ.get("WV_TEST_REST_PORT", "8080")),
int(os.environ.get("WV_TEST_GRPC_PORT", "50051")),
)
_DEFAULT_VECTOR_PORTS: Tuple[int, int] = (
int(os.environ.get("WV_TEST_VECTOR_REST_PORT", "8086")),
int(os.environ.get("WV_TEST_VECTOR_GRPC_PORT", "50057")),
)
_DEFAULT_CLUSTER_PORTS: Tuple[int, int] = (
int(os.environ.get("WV_TEST_CLUSTER_REST_PORT", "8087")),
int(os.environ.get("WV_TEST_CLUSTER_GRPC_PORT", "50058")),
)
_DEFAULT_RBAC_PORTS: Tuple[int, int] = (
int(os.environ.get("WV_TEST_RBAC_REST_PORT", "8092")),
int(os.environ.get("WV_TEST_RBAC_GRPC_PORT", "50063")),
)
_DEFAULT_BROKEN_PORTS: Tuple[int, int] = (
int(os.environ.get("WV_TEST_BROKEN_REST_PORT", "8888")),
int(os.environ.get("WV_TEST_BROKEN_GRPC_PORT", "55555")),
)


class CollectionFactory(Protocol):
"""Typing for fixture."""

Expand All @@ -59,7 +86,7 @@ def __call__(
multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None,
generative_config: Optional[_GenerativeProvider] = None,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
data_model_properties: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
replication_config: Optional[_ReplicationConfigCreate] = None,
Expand All @@ -81,7 +108,7 @@ class ClientFactory(Protocol):
def __call__(
self,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
auth_credentials: Optional[weaviate.auth.AuthCredentials] = None,
) -> weaviate.WeaviateClient:
"""Typing for fixture."""
Expand All @@ -94,12 +121,13 @@ def client_factory() -> Generator[ClientFactory, None, None]:

def _factory(
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
auth_credentials: Optional[weaviate.auth.AuthCredentials] = None,
) -> weaviate.WeaviateClient:
nonlocal client_fixture
if client_fixture is None:
client_fixture = weaviate.connect_to_local(
host=_DEFAULT_HOST,
headers=headers,
grpc_port=ports[1],
port=ports[0],
Expand Down Expand Up @@ -134,7 +162,7 @@ def _factory(
multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None,
generative_config: Optional[_GenerativeProvider] = None,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
data_model_properties: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
replication_config: Optional[_ReplicationConfigCreate] = None,
Expand Down Expand Up @@ -210,7 +238,7 @@ async def __call__(
multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None,
generative_config: Optional[_GenerativeProvider] = None,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
data_model_properties: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
replication_config: Optional[_ReplicationConfigCreate] = None,
Expand All @@ -228,7 +256,7 @@ class AsyncClientFactory(Protocol):
async def __call__(
self,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
) -> weaviate.WeaviateAsyncClient:
"""Typing for fixture."""
...
Expand All @@ -240,11 +268,12 @@ async def async_client_factory() -> AsyncGenerator[AsyncClientFactory, None]:

async def _factory(
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
) -> weaviate.WeaviateAsyncClient:
nonlocal client_fixture
if client_fixture is None:
client_fixture = weaviate.use_async_with_local(
host=_DEFAULT_HOST,
headers=headers,
grpc_port=ports[1],
port=ports[0],
Expand Down Expand Up @@ -278,7 +307,7 @@ async def _factory(
multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None,
generative_config: Optional[_GenerativeProvider] = None,
headers: Optional[Dict[str, str]] = None,
ports: Tuple[int, int] = (8080, 50051),
ports: Tuple[int, int] = _DEFAULT_PRIMARY_PORTS,
data_model_properties: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
replication_config: Optional[_ReplicationConfigCreate] = None,
Expand Down Expand Up @@ -369,7 +398,7 @@ def _factory(
Property(name="extra", data_type=DataType.TEXT),
],
generative_config=Configure.Generative.openai(),
ports=(8086, 50057),
ports=_DEFAULT_VECTOR_PORTS,
headers={"X-OpenAI-Api-Key": api_key},
)

Expand Down Expand Up @@ -418,7 +447,7 @@ async def _factory(
Property(name="extra", data_type=DataType.TEXT),
],
generative_config=Configure.Generative.openai(),
ports=(8086, 50057),
ports=_DEFAULT_VECTOR_PORTS,
headers={"X-OpenAI-Api-Key": api_key},
)

Expand Down
87 changes: 55 additions & 32 deletions integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
from weaviate.collections.classes.filters import Filter
from weaviate.exceptions import AuthenticationFailedError, UnexpectedStatusCodeError

ANON_PORT = 8080
OKTA_PORT_CC = 8082
OKTA_PORT_USERS = 8083
WCS_PORT = 8085
WCS_PORT_GRPC = 50056
ANON_PORT = int(os.environ.get("WV_TEST_REST_PORT", "8080"))
OKTA_PORT_CC = int(os.environ.get("WV_TEST_OIDC_OKTA_CC_PORT", "8082"))
OKTA_PORT_USERS = int(os.environ.get("WV_TEST_OIDC_OKTA_USERS_PORT", "8083"))
WCS_PORT = int(os.environ.get("WV_TEST_AUTH_REST_PORT", "8085"))
WCS_PORT_GRPC = int(os.environ.get("WV_TEST_AUTH_GRPC_PORT", "50056"))

ANON_HOST = os.environ.get("WV_TEST_HOST", "localhost")
OIDC_HOST = os.environ.get("WV_TEST_OIDC_HOST", "localhost")
WCS_HOST = os.environ.get("WV_TEST_AUTH_HOST", "localhost")


def is_auth_enabled(url: str) -> bool:
Expand All @@ -29,41 +33,43 @@ def is_auth_enabled(url: str) -> bool:

def test_no_auth_provided() -> None:
"""Test exception when trying to access a weaviate that requires authentication."""
assert is_auth_enabled(f"localhost:{OKTA_PORT_CC}")
assert is_auth_enabled(f"{OIDC_HOST}:{OKTA_PORT_CC}")
with pytest.raises(AuthenticationFailedError):
weaviate.connect_to_local(port=OKTA_PORT_CC)
weaviate.connect_to_local(host=OIDC_HOST, port=OKTA_PORT_CC)


@pytest.mark.parametrize(
"name,env_variable_name,port,scope",
"name,env_variable_name,host,port,scope",
[
("okta", "OKTA_CLIENT_SECRET", OKTA_PORT_CC, "some_scope"),
("okta", "OKTA_CLIENT_SECRET", OIDC_HOST, OKTA_PORT_CC, "some_scope"),
],
)
def test_authentication_client_credentials(
name: str, env_variable_name: str, port: int, scope: Optional[str]
name: str, env_variable_name: str, host: str, port: int, scope: Optional[str]
) -> None:
"""Test client credential flow with various providers."""
client_secret = os.environ.get(env_variable_name)
if client_secret is None:
pytest.skip(f"No {name} login data found.")

assert is_auth_enabled(f"localhost:{port}")
assert is_auth_enabled(f"{host}:{port}")

with weaviate.connect_to_local(
host=host,
port=port,
auth_credentials=wvc.init.Auth.client_credentials(client_secret=client_secret, scope=scope),
) as client:
client.collections.list_all() # no exception


@pytest.mark.parametrize(
"name,user,env_variable_name,port,scope,warning",
"name,user,env_variable_name,host,port,scope,warning",
[
( # WCS keycloak times out too often
"WCS",
"oidc-test-user@weaviate.io",
"WCS_DUMMY_CI_PW",
WCS_HOST,
WCS_PORT,
None,
False,
Expand All @@ -72,6 +78,7 @@ def test_authentication_client_credentials(
"okta",
"test@test.de",
"OKTA_DUMMY_CI_PW",
OIDC_HOST,
OKTA_PORT_USERS,
"some_scope offline_access",
False,
Expand All @@ -80,6 +87,7 @@ def test_authentication_client_credentials(
"okta - no refresh",
"test@test.de",
"OKTA_DUMMY_CI_PW",
OIDC_HOST,
OKTA_PORT_USERS,
"some_scope",
True,
Expand All @@ -91,12 +99,13 @@ def test_authentication_user_pw(
name: str,
user: str,
env_variable_name: str,
host: str,
port: int,
scope: str,
warning: bool,
) -> None:
"""Test authentication using Resource Owner Password Credentials Grant (User + PW)."""
assert is_auth_enabled(f"localhost:{port}")
assert is_auth_enabled(f"{host}:{port}")

pw = os.environ.get(env_variable_name)
if pw is None:
Expand All @@ -107,7 +116,7 @@ def test_authentication_user_pw(
else:
auth = wvc.init.Auth.client_password(username=user, password=pw)

with weaviate.connect_to_local(port=port, auth_credentials=auth) as client:
with weaviate.connect_to_local(host=host, port=port, auth_credentials=auth) as client:
client.collections.list_all() # no exception

if warning:
Expand All @@ -124,12 +133,14 @@ def test_authentication_user_pw(

def test_client_with_authentication_with_anon_weaviate() -> None:
"""Test that we warn users when their client has auth enabled, but weaviate has only anon access."""
assert not is_auth_enabled(f"localhost:{ANON_PORT}")
assert not is_auth_enabled(f"{ANON_HOST}:{ANON_PORT}")

auth = wvc.init.Auth.client_password(username="someUser", password="SomePw")
with pytest.warns(UserWarning) as recwarn:
warnings.filterwarnings(action="ignore", message=r"datetime.datetime.utcnow")
with weaviate.connect_to_local(auth_credentials=auth) as client:
with weaviate.connect_to_local(
host=ANON_HOST, port=ANON_PORT, auth_credentials=auth
) as client:
client.collections.list_all()
if len(recwarn) > 1:
for rwarning in recwarn.list:
Expand Down Expand Up @@ -163,27 +174,29 @@ def _get_access_token(url: str, user: str, pw: str) -> Dict[str, str]:


@pytest.mark.parametrize(
"name,user,env_variable_name,port",
"name,user,env_variable_name,host,port",
[
( # WCS keycloak times out too often
"WCS",
"oidc-test-user@weaviate.io",
"WCS_DUMMY_CI_PW",
WCS_HOST,
WCS_PORT,
),
(
"okta",
"test@test.de",
"OKTA_DUMMY_CI_PW",
OIDC_HOST,
OKTA_PORT_USERS,
),
],
)
def test_authentication_with_bearer_token(
name: str, user: str, env_variable_name: str, port: int
name: str, user: str, env_variable_name: str, host: str, port: int
) -> None:
"""Test authentication using existing bearer token."""
url = f"localhost:{port}"
url = f"{host}:{port}"
assert is_auth_enabled(url)
pw = os.environ.get(env_variable_name)
if pw is None:
Expand All @@ -196,13 +209,13 @@ def test_authentication_with_bearer_token(
expires_in=int(token["expires_in"]),
refresh_token=token["refresh_token"],
)
with weaviate.connect_to_local(port=port, auth_credentials=auth) as client:
with weaviate.connect_to_local(host=host, port=port, auth_credentials=auth) as client:
client.collections.exists("something")


def test_authentication_with_bearer_token_no_refresh() -> None:
"""Test authentication using existing bearer token."""
url = f"localhost:{OKTA_PORT_USERS}"
url = f"{OIDC_HOST}:{OKTA_PORT_USERS}"
assert is_auth_enabled(url)
pw = os.environ.get("OKTA_DUMMY_CI_PW")
if pw is None:
Expand All @@ -215,30 +228,37 @@ def test_authentication_with_bearer_token_no_refresh() -> None:
expires_in=int(token["expires_in"]),
)
with pytest.warns(UserWarning) as recwarn:
with weaviate.connect_to_local(port=OKTA_PORT_USERS, auth_credentials=auth) as client:
with weaviate.connect_to_local(
host=OIDC_HOST, port=OKTA_PORT_USERS, auth_credentials=auth
) as client:
client.collections.list_all()
assert len(recwarn) == 1
assert str(recwarn.list[0].message).startswith("Auth002")


def test_api_key_string() -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
with weaviate.connect_to_local(port=WCS_PORT, auth_credentials="my-secret-key") as client:
assert is_auth_enabled(f"{WCS_HOST}:{WCS_PORT}")
with weaviate.connect_to_local(
host=WCS_HOST, port=WCS_PORT, auth_credentials="my-secret-key"
) as client:
client.collections.list_all()


def test_api_key() -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
assert is_auth_enabled(f"{WCS_HOST}:{WCS_PORT}")
with weaviate.connect_to_local(
port=WCS_PORT, auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key")
host=WCS_HOST,
port=WCS_PORT,
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"),
) as client:
client.collections.list_all()


@pytest.mark.parametrize("creds", [None, grpc.ssl_channel_credentials()])
def test_custom_grpc_credentials(creds: Optional[grpc.ChannelCredentials]) -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
assert is_auth_enabled(f"{WCS_HOST}:{WCS_PORT}")
with weaviate.connect_to_local(
host=WCS_HOST,
port=WCS_PORT,
grpc_port=WCS_PORT_GRPC,
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"),
Expand All @@ -251,29 +271,32 @@ def test_custom_grpc_credentials(creds: Optional[grpc.ChannelCredentials]) -> No

@pytest.mark.parametrize("header_name", ["Authorization", "authorization"])
def test_api_key_in_header(header_name: str) -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
assert is_auth_enabled(f"{WCS_HOST}:{WCS_PORT}")
with weaviate.connect_to_local(
port=WCS_PORT, headers={header_name: "Bearer my-secret-key"}
host=WCS_HOST, port=WCS_PORT, headers={header_name: "Bearer my-secret-key"}
) as client:
client.collections.list_all()


def test_api_key_wrong_key() -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
assert is_auth_enabled(f"{WCS_HOST}:{WCS_PORT}")

with pytest.raises(UnexpectedStatusCodeError) as e:
weaviate.connect_to_local(
port=WCS_PORT, auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key-wrong")
host=WCS_HOST,
port=WCS_PORT,
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key-wrong"),
)
assert e.value.status_code == 401


def test_auth_e2e(request: SubRequest) -> None:
name = _sanitize_collection_name(request.node.name)
url = f"localhost:{WCS_PORT}"
url = f"{WCS_HOST}:{WCS_PORT}"
assert is_auth_enabled(url)

with weaviate.connect_to_local(
host=WCS_HOST,
port=WCS_PORT,
grpc_port=WCS_PORT_GRPC,
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"),
Expand Down
Loading
Loading