Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.
41 changes: 39 additions & 2 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.model_selection import cross_val_predict, cross_val_score
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line appears twice.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extra line is unnecessary.

from ast import literal_eval
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be required.


op = os.path
MLCLASS_MODULES = [
Expand All @@ -40,6 +44,8 @@
'nums': [],
'cats': [],
'target_col': None,
'CV': True,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it lowercase.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to support three cases for the cv option:

  • If the user sets cv: false - then no cross validation happens
  • If the user sets cv: 4 (or some other integer) pass it straight to cross_val_score
  • The default should be cv: None, and in this case, the user should not have to write anything in gramex.yaml

'CVargs': []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have a single argument, cv, that accepts several forms. That is, in gramex.yaml, users should be able to write any of the following:

cv: false   # disable cross validation
cv: 5        # Use 5 folds
cv:
  cv: 8   # Use 8 folds
  n_jobs: -1  # with an optional other parameter.

}
ACTIONS = ['predict', 'score', 'append', 'train', 'retrain']
DEFAULT_TEMPLATE = op.join(op.dirname(__file__), '..', 'apps', 'mlhandler', 'template.html')
Expand Down Expand Up @@ -103,7 +109,6 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):

cls.set_opt('class', model.get('class'))
cls.set_opt('params', model.get('params', {}))

if op.exists(cls.model_path): # If the pkl exists, load it
cls.model = joblib.load(cls.model_path)
elif data is not None:
Expand All @@ -112,14 +117,38 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):
data = cls._filtercols(data)
data = cls._filterrows(data)
cls.model = cls._assemble_pipeline(data, mclass=mclass, params=params)

# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
print('yayyy we are here')
cls.CrossValidation(train,target)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it lowercase.

print('should have printed')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the prints.

gramex.service.threadpool.submit(
_fit, cls.model, train, target, cls.model_path, cls.name)
cls.config_store.flush()

@classmethod
def modelFunction(cls, mclass=''):
    """Build a fresh estimator instance from the stored ``model`` config.

    Parameters
    ----------
    mclass : str, optional
        Name of the estimator class to instantiate. When empty (the
        default), the ``class`` entry of the stored ``model`` configuration
        is used instead.

    Returns
    -------
    object or None
        A new instance of the class looked up through
        ``search_modelclass``, constructed with the stored ``params``;
        ``None`` when no class name is available from either source.
    """
    model_kwargs = cls.config_store.load('model', {})
    # Honour an explicitly supplied class name. The argument used to be
    # overwritten unconditionally by the stored value, which made it dead.
    mclass = mclass or model_kwargs.get('class', False)
    if not mclass:
        # Nothing to instantiate: say so explicitly instead of touching an
        # unbound local.
        return None
    return search_modelclass(mclass)(**model_kwargs.get('params', {}))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not required.


@classmethod
def CrossValidation(cls, train, target):
    """Cross-validate the configured estimator and report the mean score.

    The ``CV`` option switches cross-validation on or off; a falsy value
    disables it. When ``CV`` is a plain integer it is forwarded to
    ``cross_val_score`` as the ``cv`` argument, unless ``CVargs`` already
    sets ``cv``. The ``CVargs`` option supplies any further keyword
    arguments for ``cross_val_score``.

    Parameters
    ----------
    train : features passed to ``cross_val_score`` as ``X``.
    target : targets passed to ``cross_val_score`` as ``y``.

    Returns
    -------
    The mean of the per-fold scores, or ``None`` when cross-validation is
    disabled or no estimator could be built.
    """
    import logging

    log = logging.getLogger(__name__)
    cv = cls.get_opt('CV')
    if not cv:
        # Disabled: do not even build the estimator.
        return None

    estimator = cls.modelFunction()
    if estimator is None:
        log.warning('Cross validation skipped: no model class configured')
        return None

    # Use the configured mapping directly. Round-tripping it through
    # json.dumps + literal_eval fails on booleans and nulls, because JSON
    # spells them true/false/null, which are not Python literals.
    cv_kwargs = dict(cls.get_opt('CVargs') or {})
    # bool is a subclass of int, so exclude it: ``CV: true`` means "use the
    # default splitting", not "use one fold".
    if isinstance(cv, int) and not isinstance(cv, bool):
        cv_kwargs.setdefault('cv', cv)

    scores = cross_val_score(estimator, X=train, y=target, **cv_kwargs)
    mean_score = sum(scores) / len(scores)
    log.info('CV score: %s', mean_score)
    return mean_score

@classmethod
def load_data(cls, default=pd.DataFrame()):
try:
Expand Down Expand Up @@ -268,6 +297,10 @@ def _predict(self, data=None, score_col=''):
self.model = cache.open(self.model_path, joblib.load)
try:
target = data.pop(score_col)
metric = self.get_argument('_metric', False)
if metric:
scorer = get_scorer(metric)
return scorer(self.model, data, target)
return self.model.score(data, target)
except KeyError:
# Set data in the same order as the transformer requests
Expand Down Expand Up @@ -347,6 +380,8 @@ def _train(self, data=None):
target = data[target_col]
train = data[[c for c in data if c != target_col]]
self.model = self._assemble_pipeline(data, force=True)
print('IN TRAIN')
self.CrossValidation(train,target)
_fit(self.model, train, target, self.model_path)
return {'score': self.model.score(train, target)}

Expand All @@ -357,6 +392,8 @@ def _score(self):
self._check_model_path()
data = self._parse_data(False)
target_col = self.get_argument('target_col', self.get_opt('target_col'))
print('IN _SCORE')
#self.CrossValidation(data,target_col)
self.set_opt('target_col', target_col)
return {'score': self._predict(data, target_col)}

Expand Down
5 changes: 5 additions & 0 deletions tests/test_mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def test_get_bulk_score(self):
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)
resp = self.get(
'/mlhandler?_action=score&_metric=f1_weighted', method='post',
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)

def test_get_cache(self):
df = pd.DataFrame.from_records(self.get('/mlhandler?_cache=true').json())
Expand Down