diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/SolrLanguageModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/SolrLanguageModel.java new file mode 100644 index 000000000000..78a181d278a7 --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/SolrLanguageModel.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.model; + +import java.util.Map; + +/** + * Abstract base class for Solr-managed wrappers around langchain4j used in {@code language-models} module + */ +public abstract class SolrLanguageModel { + + // common parameters + protected static final String TIMEOUT_PARAM = "timeout"; + protected static final String MAX_RETRIES_PARAM = "maxRetries"; + + protected final String name; + protected final Map params; + + protected SolrLanguageModel(String name, Map params) { + this.name = name; + this.params = params; + } + + public String getName() { + return name; + } + + public Map getParams() { + return params; + } + + /** Returns the class name of the underlying langchain4j model instance. */ + public abstract String getModelClassName(); +} diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/package-info.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/package-info.java new file mode 100644 index 000000000000..f385bff798d3 --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/model/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains model related classes. */ +package org.apache.solr.languagemodels.model; diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelException.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelException.java similarity index 78% rename from solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelException.java rename to solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelException.java index 8709ebf69298..6710ae85903a 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelException.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelException.java @@ -14,17 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.languagemodels.textvectorisation.store; +package org.apache.solr.languagemodels.store; -public class TextToVectorModelException extends RuntimeException { +public class LanguageModelException extends RuntimeException { private static final long serialVersionUID = 1L; - public TextToVectorModelException(String message) { + public LanguageModelException(String message) { super(message); } - public TextToVectorModelException(String message, Exception cause) { + public LanguageModelException(String message, Exception cause) { super(message, cause); } } diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelStore.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelStore.java similarity index 58% rename from solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelStore.java rename to solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelStore.java index 7d24d25f57e3..a8c2aabaefaf 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/TextToVectorModelStore.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/LanguageModelStore.java @@ -14,25 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.languagemodels.textvectorisation.store; +package org.apache.solr.languagemodels.store; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import org.apache.solr.languagemodels.textvectorisation.model.SolrTextToVectorModel; +import org.apache.solr.languagemodels.model.SolrLanguageModel; -/** Simple store to manage CRUD operations on the {@link SolrTextToVectorModel} */ -public class TextToVectorModelStore { +/** Generic store to manage CRUD operations on models that extend {@link SolrLanguageModel} */ +public class LanguageModelStore { - private final Map availableModels; + private final Map availableModels; - public TextToVectorModelStore() { + public LanguageModelStore() { availableModels = Collections.synchronizedMap(new LinkedHashMap<>()); } - public SolrTextToVectorModel getModel(String name) { + public M getModel(String name) { return availableModels.get(name); } @@ -40,27 +40,26 @@ public void clear() { availableModels.clear(); } - public List getModels() { + public List getModels() { synchronized (availableModels) { - final List availableModelsValues = - new ArrayList(availableModels.values()); + final List availableModelsValues = new ArrayList<>(availableModels.values()); return Collections.unmodifiableList(availableModelsValues); } } @Override public String toString() { - return "ModelStore [availableModels=" + availableModels.keySet() + "]"; + return "LanguageModelStore [availableModels=" + availableModels.keySet() + "]"; } - public SolrTextToVectorModel delete(String modelName) { + public M delete(String modelName) { return availableModels.remove(modelName); } - public void addModel(SolrTextToVectorModel modeldata) throws TextToVectorModelException { - final String name = modeldata.getName(); - if (availableModels.putIfAbsent(modeldata.getName(), modeldata) != null) { - throw new TextToVectorModelException( + public void addModel(M modelData) throws LanguageModelException { + final String name = modelData.getName(); + if (availableModels.putIfAbsent(name, modelData) != null) { + throw new LanguageModelException( "model '" + name + "' already exists. Please use a different name"); } } diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/package-info.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/package-info.java similarity index 92% rename from solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/package-info.java rename to solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/package-info.java index 5e79341f9927..7a80ec25fb37 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/package-info.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/package-info.java @@ -16,4 +16,4 @@ */ /** Contains model store related classes. */ -package org.apache.solr.languagemodels.textvectorisation.store; +package org.apache.solr.languagemodels.store; diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/ManagedLanguageModelStore.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/ManagedLanguageModelStore.java new file mode 100644 index 000000000000..4d4316052524 --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/ManagedLanguageModelStore.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.store.rest; + +import java.lang.invoke.MethodHandles; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import net.jcip.annotations.ThreadSafe; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.languagemodels.model.SolrLanguageModel; +import org.apache.solr.languagemodels.store.LanguageModelException; +import org.apache.solr.languagemodels.store.LanguageModelStore; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.rest.BaseSolrResource; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceStorage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Abstract base class for {@link ManagedResource} wrappers that expose a {@link LanguageModelStore} + * via the REST API. Concrete subclasses supply the REST endpoint and the model instantiation logic. + */ +@ThreadSafe +public abstract class ManagedLanguageModelStore extends ManagedResource + implements ManagedResource.ChildResourceSupport { + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static final String MODELS_JSON_FIELD = "models"; + + protected static final String CLASS_KEY = "class"; + protected static final String NAME_KEY = "name"; + protected static final String PARAMS_KEY = "params"; + + private final LanguageModelStore store; + private Object managedData; + + protected ManagedLanguageModelStore( + String resourceId, SolrResourceLoader loader, ManagedResourceStorage.StorageIO storageIO) + throws SolrException { + super(resourceId, loader, storageIO); + store = new LanguageModelStore<>(); + } + + /** + * Creates a model instance from the JSON map persisted in the managed resource storage. + * + * @param loader the resource loader for the current core + * @param modelMap a map containing {@code "class"}, {@code "name"}, and {@code "params"} keys + * @return the instantiated model + */ + protected abstract M fromModelMap(SolrResourceLoader loader, Map modelMap); + + private static LinkedHashMap toModelMap(SolrLanguageModel model) { + final LinkedHashMap modelMap = new LinkedHashMap<>(3, 1.0f); + modelMap.put(NAME_KEY, model.getName()); + modelMap.put(CLASS_KEY, model.getModelClassName()); + modelMap.put(PARAMS_KEY, model.getParams()); + return modelMap; + } + + @Override + protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, Object managedData) + throws SolrException { + store.clear(); + this.managedData = managedData; + } + + public void loadStoredModels() { + log.info("------ managed models ~ loading ------"); + if ((managedData != null) && (managedData instanceof List)) { + @SuppressWarnings("unchecked") + final List> models = (List>) managedData; + for (final Map model : models) { + addModelFromMap(model); + } + } + } + + private void addModelFromMap(Map modelMap) { + try { + addModel(fromModelMap(solrResourceLoader, modelMap)); + } catch (final LanguageModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + + public void addModel(M model) throws SolrException { + try { + if (log.isInfoEnabled()) { + log.info("adding model {}", model.getName()); + } + store.addModel(model); + } catch (final LanguageModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + + @SuppressWarnings("unchecked") + @Override + protected Object applyUpdatesToManagedData(Object updates) { + if (updates instanceof List) { + final List> models = (List>) updates; + for (final Map model : models) { + addModelFromMap(model); + } + } + if (updates instanceof Map) { + addModelFromMap((Map) updates); + } + return modelsAsManagedResources(store.getModels()); + } + + @Override + public void doDeleteChild(BaseSolrResource endpoint, String childId) { + store.delete(childId); + storeManagedData(applyUpdatesToManagedData(null)); + } + + @Override + public void doGet(BaseSolrResource endpoint, String childId) { + final SolrQueryResponse response = endpoint.getSolrResponse(); + response.add(MODELS_JSON_FIELD, modelsAsManagedResources(store.getModels())); + } + + public M getModel(String modelName) { + return store.getModel(modelName); + } + + private static List modelsAsManagedResources(List models) { + return models.stream().map(ManagedLanguageModelStore::toModelMap).collect(Collectors.toList()); + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [store=" + store + "]"; + } +} diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/package-info.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/package-info.java new file mode 100644 index 000000000000..dd4548c93a5d --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/store/rest/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains model store rest related classes. */ +package org.apache.solr.languagemodels.store.rest; diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java index 21f7f8035be7..fc21d81fad63 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java @@ -28,26 +28,22 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.solr.common.SolrException; import org.apache.solr.core.SolrResourceLoader; -import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException; +import org.apache.solr.languagemodels.model.SolrLanguageModel; +import org.apache.solr.languagemodels.store.LanguageModelException; import org.apache.solr.languagemodels.textvectorisation.store.rest.ManagedTextToVectorModelStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * This object wraps a {@link dev.langchain4j.model.embedding.EmbeddingModel} to encode text to - * vector. It's meant to be used as a managed resource with the {@link - * ManagedTextToVectorModelStore} + * This object wraps a {@link EmbeddingModel} to encode text to vector. It's meant to be used as a + * managed resource with the {@link ManagedTextToVectorModelStore} */ -public class SolrTextToVectorModel implements Accountable { +public class SolrTextToVectorModel extends SolrLanguageModel implements Accountable { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(SolrTextToVectorModel.class); - private static final String TIMEOUT_PARAM = "timeout"; private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch"; - private static final String MAX_RETRIES_PARAM = "maxRetries"; - private final String name; - private final Map params; private final EmbeddingModel textToVector; private final int hashCode; @@ -56,7 +52,7 @@ public static SolrTextToVectorModel getInstance( String className, String name, Map params) - throws TextToVectorModelException { + throws LanguageModelException { try { /* * The idea here is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion @@ -125,15 +121,14 @@ public static SolrTextToVectorModel getInstance( textToVector = (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder); return new SolrTextToVectorModel(name, textToVector, params); } catch (final Exception e) { - throw new TextToVectorModelException("Model loading failed for " + className, e); + throw new LanguageModelException("Model loading failed for " + className, e); } } public SolrTextToVectorModel( String name, EmbeddingModel textToVector, Map params) { - this.name = name; + super(name, params); this.textToVector = textToVector; - this.params = params; this.hashCode = calculateHashCode(); } @@ -170,20 +165,12 @@ private int calculateHashCode() { @Override public boolean equals(Object obj) { if (this == obj) return true; - if (!(obj instanceof SolrTextToVectorModel)) return false; - final SolrTextToVectorModel other = (SolrTextToVectorModel) obj; + if (!(obj instanceof SolrTextToVectorModel other)) return false; return Objects.equals(textToVector, other.textToVector) && Objects.equals(name, other.name); } - public String getName() { - return name; - } - - public String getEmbeddingModelClassName() { + @Override + public String getModelClassName() { return textToVector.getClass().getName(); } - - public Map getParams() { - return params; - } } diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/rest/ManagedTextToVectorModelStore.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/rest/ManagedTextToVectorModelStore.java index 70c03ffc47ea..65a656291233 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/rest/ManagedTextToVectorModelStore.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/store/rest/ManagedTextToVectorModelStore.java @@ -16,48 +16,23 @@ */ package org.apache.solr.languagemodels.textvectorisation.store.rest; -import java.lang.invoke.MethodHandles; -import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import net.jcip.annotations.ThreadSafe; import org.apache.solr.common.SolrException; -import org.apache.solr.common.util.NamedList; import org.apache.solr.core.SolrCore; import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.languagemodels.store.rest.ManagedLanguageModelStore; import org.apache.solr.languagemodels.textvectorisation.model.SolrTextToVectorModel; -import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException; -import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelStore; -import org.apache.solr.response.SolrQueryResponse; -import org.apache.solr.rest.BaseSolrResource; -import org.apache.solr.rest.ManagedResource; import org.apache.solr.rest.ManagedResourceObserver; import org.apache.solr.rest.ManagedResourceStorage; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -/** Managed Resource wrapper for the {@link TextToVectorModelStore} to expose it via REST */ +/** Managed Resource wrapper for the text-to-vector model store, exposed via REST */ @ThreadSafe -public class ManagedTextToVectorModelStore extends ManagedResource - implements ManagedResource.ChildResourceSupport { - private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); +public class ManagedTextToVectorModelStore extends ManagedLanguageModelStore { /** the model store rest endpoint */ public static final String REST_END_POINT = "/schema/text-to-vector-model-store"; - /** Managed model store: the name of the attribute containing all the models of a model store */ - private static final String MODELS_JSON_FIELD = "models"; - - /** name of the attribute containing a class */ - static final String CLASS_KEY = "class"; - - /** name of the attribute containing a name */ - static final String NAME_KEY = "name"; - - /** name of the attribute containing parameters */ - static final String PARAMS_KEY = "params"; - public static void registerManagedTextToVectorModelStore( SolrResourceLoader solrResourceLoader, ManagedResourceObserver managedResourceObserver) { solrResourceLoader @@ -70,21 +45,9 @@ public static ManagedTextToVectorModelStore getManagedModelStore(SolrCore core) return (ManagedTextToVectorModelStore) core.getRestManager().getManagedResource(REST_END_POINT); } - /** - * Returns the available models as a list of Maps objects. After an update the managed resources - * needs to return the resources in this format in order to store in json somewhere (zookeeper, - * disk...) - * - * @return the available models as a list of Maps objects - */ - private static List modelsAsManagedResources(List models) { - return models.stream() - .map(ManagedTextToVectorModelStore::toModelMap) - .collect(Collectors.toList()); - } - + @Override @SuppressWarnings("unchecked") - public static SolrTextToVectorModel fromModelMap( + protected SolrTextToVectorModel fromModelMap( SolrResourceLoader solrResourceLoader, Map embeddingModel) { return SolrTextToVectorModel.getInstance( solrResourceLoader, @@ -93,108 +56,9 @@ public static SolrTextToVectorModel fromModelMap( (Map) embeddingModel.get(PARAMS_KEY)); } - private static LinkedHashMap toModelMap(SolrTextToVectorModel model) { - final LinkedHashMap modelMap = new LinkedHashMap<>(5, 1.0f); - modelMap.put(NAME_KEY, model.getName()); - modelMap.put(CLASS_KEY, model.getEmbeddingModelClassName()); - modelMap.put(PARAMS_KEY, model.getParams()); - return modelMap; - } - - private final TextToVectorModelStore store; - private Object managedData; - public ManagedTextToVectorModelStore( String resourceId, SolrResourceLoader loader, ManagedResourceStorage.StorageIO storageIO) throws SolrException { super(resourceId, loader, storageIO); - store = new TextToVectorModelStore(); - } - - @Override - protected ManagedResourceStorage createStorage( - ManagedResourceStorage.StorageIO storageIO, SolrResourceLoader loader) throws SolrException { - return new ManagedResourceStorage.JsonStorage(storageIO, loader, -1); - } - - @Override - protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, Object managedData) - throws SolrException { - store.clear(); - this.managedData = managedData; - } - - public void loadStoredModels() { - log.info("------ managed models ~ loading ------"); - - if ((managedData != null) && (managedData instanceof List)) { - @SuppressWarnings({"unchecked"}) - final List> textToVectorModels = (List>) managedData; - for (final Map textToVectorModel : textToVectorModels) { - addModelFromMap(textToVectorModel); - } - } - } - - private void addModelFromMap(Map modelMap) { - try { - addModel(fromModelMap(solrResourceLoader, modelMap)); - } catch (final TextToVectorModelException e) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); - } - } - - public void addModel(SolrTextToVectorModel model) throws TextToVectorModelException { - try { - if (log.isInfoEnabled()) { - log.info("adding model {}", model.getName()); - } - store.addModel(model); - } catch (final TextToVectorModelException e) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); - } - } - - @SuppressWarnings("unchecked") - @Override - protected Object applyUpdatesToManagedData(Object updates) { - if (updates instanceof List) { - final List> textToVectorModels = (List>) updates; - for (final Map textToVectorModel : textToVectorModels) { - addModelFromMap(textToVectorModel); - } - } - - if (updates instanceof Map) { - final Map map = (Map) updates; - addModelFromMap(map); - } - - return modelsAsManagedResources(store.getModels()); - } - - @Override - public void doDeleteChild(BaseSolrResource endpoint, String childId) { - store.delete(childId); - storeManagedData(applyUpdatesToManagedData(null)); - } - - /** - * Called to retrieve a named part (the given childId) of the resource at the given endpoint. - * Note: since we have a unique child managed store we ignore the childId. - */ - @Override - public void doGet(BaseSolrResource endpoint, String childId) { - final SolrQueryResponse response = endpoint.getSolrResponse(); - response.add(MODELS_JSON_FIELD, modelsAsManagedResources(store.getModels())); - } - - public SolrTextToVectorModel getModel(String modelName) { - return store.getModel(modelName); - } - - @Override - public String toString() { - return "ManagedModelStore [store=" + store + "]"; } } diff --git a/solr/modules/language-models/src/test-files/modelExamples/cohere-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/cohere-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/cohere-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/cohere-model.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/dummy-model-ambiguous.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model-ambiguous.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/dummy-model-ambiguous.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model-ambiguous.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/dummy-model-unsupported.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model-unsupported.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/dummy-model-unsupported.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model-unsupported.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/dummy-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/dummy-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/dummy-model.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/exception-throwing-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/exception-throwing-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/exception-throwing-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/exception-throwing-model.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/huggingface-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/huggingface-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/huggingface-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/huggingface-model.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/mistralai-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/mistralai-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/mistralai-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/mistralai-model.json diff --git a/solr/modules/language-models/src/test-files/modelExamples/openai-model.json b/solr/modules/language-models/src/test-files/embeddingModelExamples/openai-model.json similarity index 100% rename from solr/modules/language-models/src/test-files/modelExamples/openai-model.json rename to solr/modules/language-models/src/test-files/embeddingModelExamples/openai-model.json diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/TestLanguageModelBase.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/TestLanguageModelBase.java index a54e8e1875d5..240a75464f99 100644 --- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/TestLanguageModelBase.java +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/TestLanguageModelBase.java @@ -38,7 +38,7 @@ public class TestLanguageModelBase extends RestTestBase { protected static Path tmpSolrHome; protected static Path tmpConfDir; - public static final String MODEL_FILE_NAME = "_schema_text-to-vector-model-store.json"; + public static final String EMBEDDING_MODEL_FILE_NAME = "_schema_text-to-vector-model-store.json"; protected static final String COLLECTION = "collection1"; protected static final String CONF_DIR = COLLECTION + "/conf"; @@ -61,17 +61,17 @@ protected static void initFolders(boolean isPersistent) throws Exception { tmpSolrHome = createTempDir(); tmpConfDir = tmpSolrHome.resolve(CONF_DIR); PathUtils.copyDirectory(TEST_PATH(), tmpSolrHome.toAbsolutePath()); - final Path modelStore = tmpConfDir.resolve(MODEL_FILE_NAME); + final Path embeddingStore = tmpConfDir.resolve(EMBEDDING_MODEL_FILE_NAME); if (isPersistent) { - embeddingModelStoreFile = modelStore; + embeddingModelStoreFile = embeddingStore; } - if (Files.exists(modelStore)) { + if (Files.exists(embeddingStore)) { if (log.isInfoEnabled()) { - log.info("remove model store config file in {}", modelStore.toAbsolutePath()); + log.info("remove model store config file in {}", embeddingStore.toAbsolutePath()); } - Files.delete(modelStore); + Files.delete(embeddingStore); } System.setProperty("managed.schema.mutable", "true"); @@ -87,7 +87,7 @@ protected static void afterTest() throws Exception { } public static void loadModel(String fileName, String status) throws Exception { - final URL url = TestLanguageModelBase.class.getResource("/modelExamples/" + fileName); + final URL url = TestLanguageModelBase.class.getResource("/embeddingModelExamples/" + fileName); final String multipleModels = Files.readString(Path.of(url.toURI()), StandardCharsets.UTF_8); assertJPut( @@ -97,7 +97,7 @@ public static void loadModel(String fileName, String status) throws Exception { } public static void loadModel(String fileName) throws Exception { - final URL url = TestLanguageModelBase.class.getResource("/modelExamples/" + fileName); + final URL url = TestLanguageModelBase.class.getResource("/embeddingModelExamples/" + fileName); final String multipleModels = Files.readString(Path.of(url.toURI()), StandardCharsets.UTF_8); assertJPut( diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedModelStoreInitialization.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedLanguageModelStoreInitialization.java similarity index 96% rename from solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedModelStoreInitialization.java rename to solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedLanguageModelStoreInitialization.java index 244094b8764e..c384ff50f00d 100644 --- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedModelStoreInitialization.java +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestManagedLanguageModelStoreInitialization.java @@ -20,7 +20,7 @@ import org.junit.After; import org.junit.Test; -public class TestManagedModelStoreInitialization extends TestLanguageModelBase { +public class TestManagedLanguageModelStoreInitialization extends TestLanguageModelBase { @After public void cleanUp() throws Exception { diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestModelManagerPersistence.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestModelManagerPersistence.java index 92e8b68244e6..81988903f065 100644 --- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestModelManagerPersistence.java +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/store/rest/TestModelManagerPersistence.java @@ -38,15 +38,6 @@ public void cleanup() throws Exception { afterTest(); } - @Test - public void testModelAreStoredCompact() throws Exception { - loadModel("cohere-model.json"); - - final String JSONOnDisk = Files.readString(embeddingModelStoreFile, StandardCharsets.UTF_8); - Object objectFromDisk = Utils.fromJSONString(JSONOnDisk); - assertEquals(new String(Utils.toJSON(objectFromDisk, -1), UTF_8), JSONOnDisk); - } - @Test public void testModelStorePersistence() throws Exception { // check models are empty